fix organize agent's history messages without recalculating tokens (#4324)
Co-authored-by: chenyongzhao <chenyz@mama.cn>
pull/4771/head
parent 74f38eacda
commit afed3610fc
@@ -0,0 +1,82 @@
from typing import Optional, cast

from core.app.entities.app_invoke_entities import (
    ModelConfigWithCredentialsEntity,
)
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.entities.message_entities import (
    PromptMessage,
    SystemPromptMessage,
    UserPromptMessage,
)
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.prompt.prompt_transform import PromptTransform


class AgentHistoryPromptTransform(PromptTransform):
    """
    History Prompt Transform for Agent App
    """
    def __init__(self,
                 model_config: ModelConfigWithCredentialsEntity,
                 prompt_messages: list[PromptMessage],
                 history_messages: list[PromptMessage],
                 memory: Optional[TokenBufferMemory] = None,
                 ):
        self.model_config = model_config
        self.prompt_messages = prompt_messages
        self.history_messages = history_messages
        self.memory = memory

    def get_prompt(self) -> list[PromptMessage]:
        prompt_messages = []
        num_system = 0
        # collect system prompts from the history first
        for prompt_message in self.history_messages:
            if isinstance(prompt_message, SystemPromptMessage):
                prompt_messages.append(prompt_message)
                num_system += 1

        if not self.memory:
            return prompt_messages

        max_token_limit = self._calculate_rest_token(self.prompt_messages, self.model_config)

        model_type_instance = self.model_config.provider_model_bundle.model_type_instance
        model_type_instance = cast(LargeLanguageModel, model_type_instance)

        # if the whole history already fits within the remaining budget, keep it as-is
        curr_message_tokens = model_type_instance.get_num_tokens(
            self.memory.model_instance.model,
            self.memory.model_instance.credentials,
            self.history_messages
        )
        if curr_message_tokens <= max_token_limit:
            return self.history_messages

        # number of prompts appended so far for the current message
        num_prompt = 0
        # append prompt messages in descending (newest-first) order
        for prompt_message in self.history_messages[::-1]:
            if isinstance(prompt_message, SystemPromptMessage):
                continue
            prompt_messages.append(prompt_message)
            num_prompt += 1
            # a complete message starts with a UserPromptMessage
            if isinstance(prompt_message, UserPromptMessage):
                curr_message_tokens = model_type_instance.get_num_tokens(
                    self.memory.model_instance.model,
                    self.memory.model_instance.credentials,
                    prompt_messages
                )
                # if the current message overflows the token limit, drop all prompts of this message and stop
                if curr_message_tokens > max_token_limit:
                    prompt_messages = prompt_messages[:-num_prompt]
                    break
                num_prompt = 0
        # restore ascending (chronological) order
        message_prompts = prompt_messages[num_system:]
        message_prompts.reverse()

        # merge system prompts and message prompts
        prompt_messages = prompt_messages[:num_system]
        prompt_messages.extend(message_prompts)
        return prompt_messages
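For context, a minimal sketch of how the new transform might be wired into an agent run. Only AgentHistoryPromptTransform and its get_prompt() call come from the diff above; the surrounding variable names (model_config, prompt_messages, history_messages, memory) are assumptions standing in for whatever the caller already holds.

# Hypothetical call site (not part of this commit): organize the recalled
# history once per run instead of re-counting tokens message by message.
prompt_transform = AgentHistoryPromptTransform(
    model_config=model_config,          # ModelConfigWithCredentialsEntity of the running app
    prompt_messages=prompt_messages,    # system template + current user query
    history_messages=history_messages,  # messages recalled from TokenBufferMemory
    memory=memory,                      # TokenBufferMemory bound to the conversation
)
history = prompt_transform.get_prompt()  # system prompts first, then the newest turns that fit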
@@ -0,0 +1,77 @@
from unittest.mock import MagicMock

from core.app.entities.app_invoke_entities import (
    ModelConfigWithCredentialsEntity,
)
from core.entities.provider_configuration import ProviderModelBundle
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.entities.message_entities import (
    AssistantPromptMessage,
    SystemPromptMessage,
    ToolPromptMessage,
    UserPromptMessage,
)
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
from models.model import Conversation


def test_get_prompt():
    prompt_messages = [
        SystemPromptMessage(content='System Template'),
        UserPromptMessage(content='User Query'),
    ]
    history_messages = [
        SystemPromptMessage(content='System Prompt 1'),
        UserPromptMessage(content='User Prompt 1'),
        AssistantPromptMessage(content='Assistant Thought 1'),
        ToolPromptMessage(content='Tool 1-1', name='Tool 1-1', tool_call_id='1'),
        ToolPromptMessage(content='Tool 1-2', name='Tool 1-2', tool_call_id='2'),
        SystemPromptMessage(content='System Prompt 2'),
        UserPromptMessage(content='User Prompt 2'),
        AssistantPromptMessage(content='Assistant Thought 2'),
        ToolPromptMessage(content='Tool 2-1', name='Tool 2-1', tool_call_id='3'),
        ToolPromptMessage(content='Tool 2-2', name='Tool 2-2', tool_call_id='4'),
        UserPromptMessage(content='User Prompt 3'),
        AssistantPromptMessage(content='Assistant Thought 3'),
    ]

    # use the message count in place of real token counting for testing
    def side_effect_get_num_tokens(*args):
        return len(args[2])
    large_language_model_mock = MagicMock(spec=LargeLanguageModel)
    large_language_model_mock.get_num_tokens = MagicMock(side_effect=side_effect_get_num_tokens)

    provider_model_bundle_mock = MagicMock(spec=ProviderModelBundle)
    provider_model_bundle_mock.model_type_instance = large_language_model_mock

    model_config_mock = MagicMock(spec=ModelConfigWithCredentialsEntity)
    model_config_mock.model = 'openai'
    model_config_mock.credentials = {}
    model_config_mock.provider_model_bundle = provider_model_bundle_mock

    memory = TokenBufferMemory(
        conversation=Conversation(),
        model_instance=model_config_mock
    )

    transform = AgentHistoryPromptTransform(
        model_config=model_config_mock,
        prompt_messages=prompt_messages,
        history_messages=history_messages,
        memory=memory
    )

    # with a budget of 5, only the 2 system prompts plus the latest complete turn
    # (User Prompt 3, Assistant Thought 3) fit
    max_token_limit = 5
    transform._calculate_rest_token = MagicMock(return_value=max_token_limit)
    result = transform.get_prompt()

    assert len(result) <= max_token_limit
    assert len(result) == 4

    # with a budget of 20, the full 12-message history fits and is returned unchanged
    max_token_limit = 20
    transform._calculate_rest_token = MagicMock(return_value=max_token_limit)
    result = transform.get_prompt()

    assert len(result) <= max_token_limit
    assert len(result) == 12