Refactor/react agent (#3355)
parent
509c640a80
commit
cea107b165
@ -0,0 +1,71 @@
|
|||||||
|
import json
|
||||||
|
|
||||||
|
from core.agent.cot_agent_runner import CotAgentRunner
|
||||||
|
from core.model_runtime.entities.message_entities import (
|
||||||
|
AssistantPromptMessage,
|
||||||
|
PromptMessage,
|
||||||
|
SystemPromptMessage,
|
||||||
|
UserPromptMessage,
|
||||||
|
)
|
||||||
|
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||||
|
|
||||||
|
|
||||||
|
class CotChatAgentRunner(CotAgentRunner):
|
||||||
|
def _organize_system_prompt(self) -> SystemPromptMessage:
|
||||||
|
"""
|
||||||
|
Organize system prompt
|
||||||
|
"""
|
||||||
|
prompt_entity = self.app_config.agent.prompt
|
||||||
|
first_prompt = prompt_entity.first_prompt
|
||||||
|
|
||||||
|
system_prompt = first_prompt \
|
||||||
|
.replace("{{instruction}}", self._instruction) \
|
||||||
|
.replace("{{tools}}", json.dumps(jsonable_encoder(self._prompt_messages_tools))) \
|
||||||
|
.replace("{{tool_names}}", ', '.join([tool.name for tool in self._prompt_messages_tools]))
|
||||||
|
|
||||||
|
return SystemPromptMessage(content=system_prompt)
|
||||||
|
|
||||||
|
def _organize_prompt_messages(self) -> list[PromptMessage]:
|
||||||
|
"""
|
||||||
|
Organize
|
||||||
|
"""
|
||||||
|
# organize system prompt
|
||||||
|
system_message = self._organize_system_prompt()
|
||||||
|
|
||||||
|
# organize historic prompt messages
|
||||||
|
historic_messages = self._historic_prompt_messages
|
||||||
|
|
||||||
|
# organize current assistant messages
|
||||||
|
agent_scratchpad = self._agent_scratchpad
|
||||||
|
if not agent_scratchpad:
|
||||||
|
assistant_messages = []
|
||||||
|
else:
|
||||||
|
assistant_message = AssistantPromptMessage(content='')
|
||||||
|
for unit in agent_scratchpad:
|
||||||
|
if unit.is_final():
|
||||||
|
assistant_message.content += f"Final Answer: {unit.agent_response}"
|
||||||
|
else:
|
||||||
|
assistant_message.content += f"Thought: {unit.thought}\n\n"
|
||||||
|
if unit.action_str:
|
||||||
|
assistant_message.content += f"Action: {unit.action_str}\n\n"
|
||||||
|
if unit.observation:
|
||||||
|
assistant_message.content += f"Observation: {unit.observation}\n\n"
|
||||||
|
|
||||||
|
assistant_messages = [assistant_message]
|
||||||
|
|
||||||
|
# query messages
|
||||||
|
query_messages = UserPromptMessage(content=self._query)
|
||||||
|
|
||||||
|
if assistant_messages:
|
||||||
|
messages = [
|
||||||
|
system_message,
|
||||||
|
*historic_messages,
|
||||||
|
query_messages,
|
||||||
|
*assistant_messages,
|
||||||
|
UserPromptMessage(content='continue')
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
messages = [system_message, *historic_messages, query_messages]
|
||||||
|
|
||||||
|
# join all messages
|
||||||
|
return messages
|
||||||
@ -0,0 +1,69 @@
|
|||||||
|
import json
|
||||||
|
|
||||||
|
from core.agent.cot_agent_runner import CotAgentRunner
|
||||||
|
from core.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessage, UserPromptMessage
|
||||||
|
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||||
|
|
||||||
|
|
||||||
|
class CotCompletionAgentRunner(CotAgentRunner):
|
||||||
|
def _organize_instruction_prompt(self) -> str:
|
||||||
|
"""
|
||||||
|
Organize instruction prompt
|
||||||
|
"""
|
||||||
|
prompt_entity = self.app_config.agent.prompt
|
||||||
|
first_prompt = prompt_entity.first_prompt
|
||||||
|
|
||||||
|
system_prompt = first_prompt.replace("{{instruction}}", self._instruction) \
|
||||||
|
.replace("{{tools}}", json.dumps(jsonable_encoder(self._prompt_messages_tools))) \
|
||||||
|
.replace("{{tool_names}}", ', '.join([tool.name for tool in self._prompt_messages_tools]))
|
||||||
|
|
||||||
|
return system_prompt
|
||||||
|
|
||||||
|
def _organize_historic_prompt(self) -> str:
|
||||||
|
"""
|
||||||
|
Organize historic prompt
|
||||||
|
"""
|
||||||
|
historic_prompt_messages = self._historic_prompt_messages
|
||||||
|
historic_prompt = ""
|
||||||
|
|
||||||
|
for message in historic_prompt_messages:
|
||||||
|
if isinstance(message, UserPromptMessage):
|
||||||
|
historic_prompt += f"Question: {message.content}\n\n"
|
||||||
|
elif isinstance(message, AssistantPromptMessage):
|
||||||
|
historic_prompt += message.content + "\n\n"
|
||||||
|
|
||||||
|
return historic_prompt
|
||||||
|
|
||||||
|
def _organize_prompt_messages(self) -> list[PromptMessage]:
|
||||||
|
"""
|
||||||
|
Organize prompt messages
|
||||||
|
"""
|
||||||
|
# organize system prompt
|
||||||
|
system_prompt = self._organize_instruction_prompt()
|
||||||
|
|
||||||
|
# organize historic prompt messages
|
||||||
|
historic_prompt = self._organize_historic_prompt()
|
||||||
|
|
||||||
|
# organize current assistant messages
|
||||||
|
agent_scratchpad = self._agent_scratchpad
|
||||||
|
assistant_prompt = ''
|
||||||
|
for unit in agent_scratchpad:
|
||||||
|
if unit.is_final():
|
||||||
|
assistant_prompt += f"Final Answer: {unit.agent_response}"
|
||||||
|
else:
|
||||||
|
assistant_prompt += f"Thought: {unit.thought}\n\n"
|
||||||
|
if unit.action_str:
|
||||||
|
assistant_prompt += f"Action: {unit.action_str}\n\n"
|
||||||
|
if unit.observation:
|
||||||
|
assistant_prompt += f"Observation: {unit.observation}\n\n"
|
||||||
|
|
||||||
|
# query messages
|
||||||
|
query_prompt = f"Question: {self._query}"
|
||||||
|
|
||||||
|
# join all messages
|
||||||
|
prompt = system_prompt \
|
||||||
|
.replace("{{historic_messages}}", historic_prompt) \
|
||||||
|
.replace("{{agent_scratchpad}}", assistant_prompt) \
|
||||||
|
.replace("{{query}}", query_prompt)
|
||||||
|
|
||||||
|
return [UserPromptMessage(content=prompt)]
|
||||||
@ -0,0 +1,183 @@
|
|||||||
|
import json
|
||||||
|
import re
|
||||||
|
from collections.abc import Generator
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
|
from core.agent.entities import AgentScratchpadUnit
|
||||||
|
from core.model_runtime.entities.llm_entities import LLMResultChunk
|
||||||
|
|
||||||
|
|
||||||
|
class CotAgentOutputParser:
|
||||||
|
@classmethod
|
||||||
|
def handle_react_stream_output(cls, llm_response: Generator[LLMResultChunk, None, None]) -> \
|
||||||
|
Generator[Union[str, AgentScratchpadUnit.Action], None, None]:
|
||||||
|
def parse_action(json_str):
|
||||||
|
try:
|
||||||
|
action = json.loads(json_str)
|
||||||
|
action_name = None
|
||||||
|
action_input = None
|
||||||
|
|
||||||
|
for key, value in action.items():
|
||||||
|
if 'input' in key.lower():
|
||||||
|
action_input = value
|
||||||
|
else:
|
||||||
|
action_name = value
|
||||||
|
|
||||||
|
if action_name is not None and action_input is not None:
|
||||||
|
return AgentScratchpadUnit.Action(
|
||||||
|
action_name=action_name,
|
||||||
|
action_input=action_input,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return json_str or ''
|
||||||
|
except:
|
||||||
|
return json_str or ''
|
||||||
|
|
||||||
|
def extra_json_from_code_block(code_block) -> Generator[Union[dict, str], None, None]:
|
||||||
|
code_blocks = re.findall(r'```(.*?)```', code_block, re.DOTALL)
|
||||||
|
if not code_blocks:
|
||||||
|
return
|
||||||
|
for block in code_blocks:
|
||||||
|
json_text = re.sub(r'^[a-zA-Z]+\n', '', block.strip(), flags=re.MULTILINE)
|
||||||
|
yield parse_action(json_text)
|
||||||
|
|
||||||
|
code_block_cache = ''
|
||||||
|
code_block_delimiter_count = 0
|
||||||
|
in_code_block = False
|
||||||
|
json_cache = ''
|
||||||
|
json_quote_count = 0
|
||||||
|
in_json = False
|
||||||
|
got_json = False
|
||||||
|
|
||||||
|
action_cache = ''
|
||||||
|
action_str = 'action:'
|
||||||
|
action_idx = 0
|
||||||
|
|
||||||
|
thought_cache = ''
|
||||||
|
thought_str = 'thought:'
|
||||||
|
thought_idx = 0
|
||||||
|
|
||||||
|
for response in llm_response:
|
||||||
|
response = response.delta.message.content
|
||||||
|
if not isinstance(response, str):
|
||||||
|
continue
|
||||||
|
|
||||||
|
# stream
|
||||||
|
index = 0
|
||||||
|
while index < len(response):
|
||||||
|
steps = 1
|
||||||
|
delta = response[index:index+steps]
|
||||||
|
last_character = response[index-1] if index > 0 else ''
|
||||||
|
|
||||||
|
if delta == '`':
|
||||||
|
code_block_cache += delta
|
||||||
|
code_block_delimiter_count += 1
|
||||||
|
else:
|
||||||
|
if not in_code_block:
|
||||||
|
if code_block_delimiter_count > 0:
|
||||||
|
yield code_block_cache
|
||||||
|
code_block_cache = ''
|
||||||
|
else:
|
||||||
|
code_block_cache += delta
|
||||||
|
code_block_delimiter_count = 0
|
||||||
|
|
||||||
|
if not in_code_block and not in_json:
|
||||||
|
if delta.lower() == action_str[action_idx] and action_idx == 0:
|
||||||
|
if last_character not in ['\n', ' ', '']:
|
||||||
|
index += steps
|
||||||
|
yield delta
|
||||||
|
continue
|
||||||
|
|
||||||
|
action_cache += delta
|
||||||
|
action_idx += 1
|
||||||
|
if action_idx == len(action_str):
|
||||||
|
action_cache = ''
|
||||||
|
action_idx = 0
|
||||||
|
index += steps
|
||||||
|
continue
|
||||||
|
elif delta.lower() == action_str[action_idx] and action_idx > 0:
|
||||||
|
action_cache += delta
|
||||||
|
action_idx += 1
|
||||||
|
if action_idx == len(action_str):
|
||||||
|
action_cache = ''
|
||||||
|
action_idx = 0
|
||||||
|
index += steps
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
if action_cache:
|
||||||
|
yield action_cache
|
||||||
|
action_cache = ''
|
||||||
|
action_idx = 0
|
||||||
|
|
||||||
|
if delta.lower() == thought_str[thought_idx] and thought_idx == 0:
|
||||||
|
if last_character not in ['\n', ' ', '']:
|
||||||
|
index += steps
|
||||||
|
yield delta
|
||||||
|
continue
|
||||||
|
|
||||||
|
thought_cache += delta
|
||||||
|
thought_idx += 1
|
||||||
|
if thought_idx == len(thought_str):
|
||||||
|
thought_cache = ''
|
||||||
|
thought_idx = 0
|
||||||
|
index += steps
|
||||||
|
continue
|
||||||
|
elif delta.lower() == thought_str[thought_idx] and thought_idx > 0:
|
||||||
|
thought_cache += delta
|
||||||
|
thought_idx += 1
|
||||||
|
if thought_idx == len(thought_str):
|
||||||
|
thought_cache = ''
|
||||||
|
thought_idx = 0
|
||||||
|
index += steps
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
if thought_cache:
|
||||||
|
yield thought_cache
|
||||||
|
thought_cache = ''
|
||||||
|
thought_idx = 0
|
||||||
|
|
||||||
|
if code_block_delimiter_count == 3:
|
||||||
|
if in_code_block:
|
||||||
|
yield from extra_json_from_code_block(code_block_cache)
|
||||||
|
code_block_cache = ''
|
||||||
|
|
||||||
|
in_code_block = not in_code_block
|
||||||
|
code_block_delimiter_count = 0
|
||||||
|
|
||||||
|
if not in_code_block:
|
||||||
|
# handle single json
|
||||||
|
if delta == '{':
|
||||||
|
json_quote_count += 1
|
||||||
|
in_json = True
|
||||||
|
json_cache += delta
|
||||||
|
elif delta == '}':
|
||||||
|
json_cache += delta
|
||||||
|
if json_quote_count > 0:
|
||||||
|
json_quote_count -= 1
|
||||||
|
if json_quote_count == 0:
|
||||||
|
in_json = False
|
||||||
|
got_json = True
|
||||||
|
index += steps
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
if in_json:
|
||||||
|
json_cache += delta
|
||||||
|
|
||||||
|
if got_json:
|
||||||
|
got_json = False
|
||||||
|
yield parse_action(json_cache)
|
||||||
|
json_cache = ''
|
||||||
|
json_quote_count = 0
|
||||||
|
in_json = False
|
||||||
|
|
||||||
|
if not in_code_block and not in_json:
|
||||||
|
yield delta.replace('`', '')
|
||||||
|
|
||||||
|
index += steps
|
||||||
|
|
||||||
|
if code_block_cache:
|
||||||
|
yield code_block_cache
|
||||||
|
|
||||||
|
if json_cache:
|
||||||
|
yield parse_action(json_cache)
|
||||||
|
|
||||||
Loading…
Reference in New Issue