Remove langchain dataset retrival agent logic (#3311)
parent
8cefa6b82e
commit
b6de97ad53
@ -1,59 +0,0 @@
|
|||||||
import time
|
|
||||||
from collections.abc import Mapping
|
|
||||||
from typing import Any, Optional
|
|
||||||
|
|
||||||
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
|
||||||
from langchain.chat_models.base import SimpleChatModel
|
|
||||||
from langchain.schema import AIMessage, BaseMessage, ChatGeneration, ChatResult
|
|
||||||
|
|
||||||
|
|
||||||
class FakeLLM(SimpleChatModel):
|
|
||||||
"""Fake ChatModel for testing purposes."""
|
|
||||||
|
|
||||||
streaming: bool = False
|
|
||||||
"""Whether to stream the results or not."""
|
|
||||||
response: str
|
|
||||||
|
|
||||||
@property
|
|
||||||
def _llm_type(self) -> str:
|
|
||||||
return "fake-chat-model"
|
|
||||||
|
|
||||||
def _call(
|
|
||||||
self,
|
|
||||||
messages: list[BaseMessage],
|
|
||||||
stop: Optional[list[str]] = None,
|
|
||||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> str:
|
|
||||||
"""First try to lookup in queries, else return 'foo' or 'bar'."""
|
|
||||||
return self.response
|
|
||||||
|
|
||||||
@property
|
|
||||||
def _identifying_params(self) -> Mapping[str, Any]:
|
|
||||||
return {"response": self.response}
|
|
||||||
|
|
||||||
def get_num_tokens(self, text: str) -> int:
|
|
||||||
return 0
|
|
||||||
|
|
||||||
def _generate(
|
|
||||||
self,
|
|
||||||
messages: list[BaseMessage],
|
|
||||||
stop: Optional[list[str]] = None,
|
|
||||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> ChatResult:
|
|
||||||
output_str = self._call(messages, stop=stop, run_manager=run_manager, **kwargs)
|
|
||||||
if self.streaming:
|
|
||||||
for token in output_str:
|
|
||||||
if run_manager:
|
|
||||||
run_manager.on_llm_new_token(token)
|
|
||||||
time.sleep(0.01)
|
|
||||||
|
|
||||||
message = AIMessage(content=output_str)
|
|
||||||
generation = ChatGeneration(message=message)
|
|
||||||
llm_output = {"token_usage": {
|
|
||||||
'prompt_tokens': 0,
|
|
||||||
'completion_tokens': 0,
|
|
||||||
'total_tokens': 0,
|
|
||||||
}}
|
|
||||||
return ChatResult(generations=[generation], llm_output=llm_output)
|
|
||||||
@ -1,46 +0,0 @@
|
|||||||
from typing import Any, Optional
|
|
||||||
|
|
||||||
from langchain import LLMChain as LCLLMChain
|
|
||||||
from langchain.callbacks.manager import CallbackManagerForChainRun
|
|
||||||
from langchain.schema import Generation, LLMResult
|
|
||||||
from langchain.schema.language_model import BaseLanguageModel
|
|
||||||
|
|
||||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
|
||||||
from core.entities.message_entities import lc_messages_to_prompt_messages
|
|
||||||
from core.model_manager import ModelInstance
|
|
||||||
from core.rag.retrieval.agent.fake_llm import FakeLLM
|
|
||||||
|
|
||||||
|
|
||||||
class LLMChain(LCLLMChain):
|
|
||||||
model_config: ModelConfigWithCredentialsEntity
|
|
||||||
"""The language model instance to use."""
|
|
||||||
llm: BaseLanguageModel = FakeLLM(response="")
|
|
||||||
parameters: dict[str, Any] = {}
|
|
||||||
|
|
||||||
def generate(
|
|
||||||
self,
|
|
||||||
input_list: list[dict[str, Any]],
|
|
||||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
|
||||||
) -> LLMResult:
|
|
||||||
"""Generate LLM result from inputs."""
|
|
||||||
prompts, stop = self.prep_prompts(input_list, run_manager=run_manager)
|
|
||||||
messages = prompts[0].to_messages()
|
|
||||||
prompt_messages = lc_messages_to_prompt_messages(messages)
|
|
||||||
|
|
||||||
model_instance = ModelInstance(
|
|
||||||
provider_model_bundle=self.model_config.provider_model_bundle,
|
|
||||||
model=self.model_config.model,
|
|
||||||
)
|
|
||||||
|
|
||||||
result = model_instance.invoke_llm(
|
|
||||||
prompt_messages=prompt_messages,
|
|
||||||
stream=False,
|
|
||||||
stop=stop,
|
|
||||||
model_parameters=self.parameters
|
|
||||||
)
|
|
||||||
|
|
||||||
generations = [
|
|
||||||
[Generation(text=result.message.content)]
|
|
||||||
]
|
|
||||||
|
|
||||||
return LLMResult(generations=generations)
|
|
||||||
@ -1,179 +0,0 @@
|
|||||||
from collections.abc import Sequence
|
|
||||||
from typing import Any, Optional, Union
|
|
||||||
|
|
||||||
from langchain.agents import BaseSingleActionAgent, OpenAIFunctionsAgent
|
|
||||||
from langchain.agents.openai_functions_agent.base import _format_intermediate_steps, _parse_ai_message
|
|
||||||
from langchain.callbacks.base import BaseCallbackManager
|
|
||||||
from langchain.callbacks.manager import Callbacks
|
|
||||||
from langchain.prompts.chat import BaseMessagePromptTemplate
|
|
||||||
from langchain.schema import AgentAction, AgentFinish, AIMessage, SystemMessage
|
|
||||||
from langchain.tools import BaseTool
|
|
||||||
from pydantic import root_validator
|
|
||||||
|
|
||||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
|
||||||
from core.entities.message_entities import lc_messages_to_prompt_messages
|
|
||||||
from core.model_manager import ModelInstance
|
|
||||||
from core.model_runtime.entities.message_entities import PromptMessageTool
|
|
||||||
from core.rag.retrieval.agent.fake_llm import FakeLLM
|
|
||||||
|
|
||||||
|
|
||||||
class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
|
|
||||||
"""
|
|
||||||
An Multi Dataset Retrieve Agent driven by Router.
|
|
||||||
"""
|
|
||||||
model_config: ModelConfigWithCredentialsEntity
|
|
||||||
|
|
||||||
class Config:
|
|
||||||
"""Configuration for this pydantic object."""
|
|
||||||
|
|
||||||
arbitrary_types_allowed = True
|
|
||||||
|
|
||||||
@root_validator
|
|
||||||
def validate_llm(cls, values: dict) -> dict:
|
|
||||||
return values
|
|
||||||
|
|
||||||
def should_use_agent(self, query: str):
|
|
||||||
"""
|
|
||||||
return should use agent
|
|
||||||
|
|
||||||
:param query:
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
return True
|
|
||||||
|
|
||||||
def plan(
|
|
||||||
self,
|
|
||||||
intermediate_steps: list[tuple[AgentAction, str]],
|
|
||||||
callbacks: Callbacks = None,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> Union[AgentAction, AgentFinish]:
|
|
||||||
"""Given input, decided what to do.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
intermediate_steps: Steps the LLM has taken to date, along with observations
|
|
||||||
**kwargs: User inputs.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Action specifying what tool to use.
|
|
||||||
"""
|
|
||||||
if len(self.tools) == 0:
|
|
||||||
return AgentFinish(return_values={"output": ''}, log='')
|
|
||||||
elif len(self.tools) == 1:
|
|
||||||
tool = next(iter(self.tools))
|
|
||||||
rst = tool.run(tool_input={'query': kwargs['input']})
|
|
||||||
# output = ''
|
|
||||||
# rst_json = json.loads(rst)
|
|
||||||
# for item in rst_json:
|
|
||||||
# output += f'{item["content"]}\n'
|
|
||||||
return AgentFinish(return_values={"output": rst}, log=rst)
|
|
||||||
|
|
||||||
if intermediate_steps:
|
|
||||||
_, observation = intermediate_steps[-1]
|
|
||||||
return AgentFinish(return_values={"output": observation}, log=observation)
|
|
||||||
|
|
||||||
try:
|
|
||||||
agent_decision = self.real_plan(intermediate_steps, callbacks, **kwargs)
|
|
||||||
if isinstance(agent_decision, AgentAction):
|
|
||||||
tool_inputs = agent_decision.tool_input
|
|
||||||
if isinstance(tool_inputs, dict) and 'query' in tool_inputs and 'chat_history' not in kwargs:
|
|
||||||
tool_inputs['query'] = kwargs['input']
|
|
||||||
agent_decision.tool_input = tool_inputs
|
|
||||||
else:
|
|
||||||
agent_decision.return_values['output'] = ''
|
|
||||||
return agent_decision
|
|
||||||
except Exception as e:
|
|
||||||
raise e
|
|
||||||
|
|
||||||
def real_plan(
|
|
||||||
self,
|
|
||||||
intermediate_steps: list[tuple[AgentAction, str]],
|
|
||||||
callbacks: Callbacks = None,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> Union[AgentAction, AgentFinish]:
|
|
||||||
"""Given input, decided what to do.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
intermediate_steps: Steps the LLM has taken to date, along with observations
|
|
||||||
**kwargs: User inputs.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Action specifying what tool to use.
|
|
||||||
"""
|
|
||||||
agent_scratchpad = _format_intermediate_steps(intermediate_steps)
|
|
||||||
selected_inputs = {
|
|
||||||
k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad"
|
|
||||||
}
|
|
||||||
full_inputs = dict(**selected_inputs, agent_scratchpad=agent_scratchpad)
|
|
||||||
prompt = self.prompt.format_prompt(**full_inputs)
|
|
||||||
messages = prompt.to_messages()
|
|
||||||
prompt_messages = lc_messages_to_prompt_messages(messages)
|
|
||||||
|
|
||||||
model_instance = ModelInstance(
|
|
||||||
provider_model_bundle=self.model_config.provider_model_bundle,
|
|
||||||
model=self.model_config.model,
|
|
||||||
)
|
|
||||||
|
|
||||||
tools = []
|
|
||||||
for function in self.functions:
|
|
||||||
tool = PromptMessageTool(
|
|
||||||
**function
|
|
||||||
)
|
|
||||||
|
|
||||||
tools.append(tool)
|
|
||||||
|
|
||||||
result = model_instance.invoke_llm(
|
|
||||||
prompt_messages=prompt_messages,
|
|
||||||
tools=tools,
|
|
||||||
stream=False,
|
|
||||||
model_parameters={
|
|
||||||
'temperature': 0.2,
|
|
||||||
'top_p': 0.3,
|
|
||||||
'max_tokens': 1500
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
ai_message = AIMessage(
|
|
||||||
content=result.message.content or "",
|
|
||||||
additional_kwargs={
|
|
||||||
'function_call': {
|
|
||||||
'id': result.message.tool_calls[0].id,
|
|
||||||
**result.message.tool_calls[0].function.dict()
|
|
||||||
} if result.message.tool_calls else None
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
agent_decision = _parse_ai_message(ai_message)
|
|
||||||
return agent_decision
|
|
||||||
|
|
||||||
async def aplan(
|
|
||||||
self,
|
|
||||||
intermediate_steps: list[tuple[AgentAction, str]],
|
|
||||||
callbacks: Callbacks = None,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> Union[AgentAction, AgentFinish]:
|
|
||||||
raise NotImplementedError()
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_llm_and_tools(
|
|
||||||
cls,
|
|
||||||
model_config: ModelConfigWithCredentialsEntity,
|
|
||||||
tools: Sequence[BaseTool],
|
|
||||||
callback_manager: Optional[BaseCallbackManager] = None,
|
|
||||||
extra_prompt_messages: Optional[list[BaseMessagePromptTemplate]] = None,
|
|
||||||
system_message: Optional[SystemMessage] = SystemMessage(
|
|
||||||
content="You are a helpful AI assistant."
|
|
||||||
),
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> BaseSingleActionAgent:
|
|
||||||
prompt = cls.create_prompt(
|
|
||||||
extra_prompt_messages=extra_prompt_messages,
|
|
||||||
system_message=system_message,
|
|
||||||
)
|
|
||||||
return cls(
|
|
||||||
model_config=model_config,
|
|
||||||
llm=FakeLLM(response=''),
|
|
||||||
prompt=prompt,
|
|
||||||
tools=tools,
|
|
||||||
callback_manager=callback_manager,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
@ -1,259 +0,0 @@
|
|||||||
import re
|
|
||||||
from collections.abc import Sequence
|
|
||||||
from typing import Any, Optional, Union, cast
|
|
||||||
|
|
||||||
from langchain import BasePromptTemplate, PromptTemplate
|
|
||||||
from langchain.agents import Agent, AgentOutputParser, StructuredChatAgent
|
|
||||||
from langchain.agents.structured_chat.base import HUMAN_MESSAGE_TEMPLATE
|
|
||||||
from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX
|
|
||||||
from langchain.callbacks.base import BaseCallbackManager
|
|
||||||
from langchain.callbacks.manager import Callbacks
|
|
||||||
from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate, SystemMessagePromptTemplate
|
|
||||||
from langchain.schema import AgentAction, AgentFinish, OutputParserException
|
|
||||||
from langchain.tools import BaseTool
|
|
||||||
|
|
||||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
|
||||||
from core.rag.retrieval.agent.llm_chain import LLMChain
|
|
||||||
|
|
||||||
FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
|
|
||||||
The nouns in the format of "Thought", "Action", "Action Input", "Final Answer" must be expressed in English.
|
|
||||||
Valid "action" values: "Final Answer" or {tool_names}
|
|
||||||
|
|
||||||
Provide only ONE action per $JSON_BLOB, as shown:
|
|
||||||
|
|
||||||
```
|
|
||||||
{{{{
|
|
||||||
"action": $TOOL_NAME,
|
|
||||||
"action_input": $INPUT
|
|
||||||
}}}}
|
|
||||||
```
|
|
||||||
|
|
||||||
Follow this format:
|
|
||||||
|
|
||||||
Question: input question to answer
|
|
||||||
Thought: consider previous and subsequent steps
|
|
||||||
Action:
|
|
||||||
```
|
|
||||||
$JSON_BLOB
|
|
||||||
```
|
|
||||||
Observation: action result
|
|
||||||
... (repeat Thought/Action/Observation N times)
|
|
||||||
Thought: I know what to respond
|
|
||||||
Action:
|
|
||||||
```
|
|
||||||
{{{{
|
|
||||||
"action": "Final Answer",
|
|
||||||
"action_input": "Final response to human"
|
|
||||||
}}}}
|
|
||||||
```"""
|
|
||||||
|
|
||||||
|
|
||||||
class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
|
|
||||||
dataset_tools: Sequence[BaseTool]
|
|
||||||
|
|
||||||
class Config:
|
|
||||||
"""Configuration for this pydantic object."""
|
|
||||||
|
|
||||||
arbitrary_types_allowed = True
|
|
||||||
|
|
||||||
def should_use_agent(self, query: str):
|
|
||||||
"""
|
|
||||||
return should use agent
|
|
||||||
Using the ReACT mode to determine whether an agent is needed is costly,
|
|
||||||
so it's better to just use an Agent for reasoning, which is cheaper.
|
|
||||||
|
|
||||||
:param query:
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
return True
|
|
||||||
|
|
||||||
def plan(
|
|
||||||
self,
|
|
||||||
intermediate_steps: list[tuple[AgentAction, str]],
|
|
||||||
callbacks: Callbacks = None,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> Union[AgentAction, AgentFinish]:
|
|
||||||
"""Given input, decided what to do.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
intermediate_steps: Steps the LLM has taken to date,
|
|
||||||
along with observations
|
|
||||||
callbacks: Callbacks to run.
|
|
||||||
**kwargs: User inputs.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Action specifying what tool to use.
|
|
||||||
"""
|
|
||||||
if len(self.dataset_tools) == 0:
|
|
||||||
return AgentFinish(return_values={"output": ''}, log='')
|
|
||||||
elif len(self.dataset_tools) == 1:
|
|
||||||
tool = next(iter(self.dataset_tools))
|
|
||||||
rst = tool.run(tool_input={'query': kwargs['input']})
|
|
||||||
return AgentFinish(return_values={"output": rst}, log=rst)
|
|
||||||
|
|
||||||
if intermediate_steps:
|
|
||||||
_, observation = intermediate_steps[-1]
|
|
||||||
return AgentFinish(return_values={"output": observation}, log=observation)
|
|
||||||
|
|
||||||
full_inputs = self.get_full_inputs(intermediate_steps, **kwargs)
|
|
||||||
|
|
||||||
try:
|
|
||||||
full_output = self.llm_chain.predict(callbacks=callbacks, **full_inputs)
|
|
||||||
except Exception as e:
|
|
||||||
raise e
|
|
||||||
|
|
||||||
try:
|
|
||||||
agent_decision = self.output_parser.parse(full_output)
|
|
||||||
if isinstance(agent_decision, AgentAction):
|
|
||||||
tool_inputs = agent_decision.tool_input
|
|
||||||
if isinstance(tool_inputs, dict) and 'query' in tool_inputs:
|
|
||||||
tool_inputs['query'] = kwargs['input']
|
|
||||||
agent_decision.tool_input = tool_inputs
|
|
||||||
elif isinstance(tool_inputs, str):
|
|
||||||
agent_decision.tool_input = kwargs['input']
|
|
||||||
else:
|
|
||||||
agent_decision.return_values['output'] = ''
|
|
||||||
return agent_decision
|
|
||||||
except OutputParserException:
|
|
||||||
return AgentFinish({"output": "I'm sorry, the answer of model is invalid, "
|
|
||||||
"I don't know how to respond to that."}, "")
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def create_prompt(
|
|
||||||
cls,
|
|
||||||
tools: Sequence[BaseTool],
|
|
||||||
prefix: str = PREFIX,
|
|
||||||
suffix: str = SUFFIX,
|
|
||||||
human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
|
|
||||||
format_instructions: str = FORMAT_INSTRUCTIONS,
|
|
||||||
input_variables: Optional[list[str]] = None,
|
|
||||||
memory_prompts: Optional[list[BasePromptTemplate]] = None,
|
|
||||||
) -> BasePromptTemplate:
|
|
||||||
tool_strings = []
|
|
||||||
for tool in tools:
|
|
||||||
args_schema = re.sub("}", "}}}}", re.sub("{", "{{{{", str(tool.args)))
|
|
||||||
tool_strings.append(f"{tool.name}: {tool.description}, args: {args_schema}")
|
|
||||||
formatted_tools = "\n".join(tool_strings)
|
|
||||||
unique_tool_names = set(tool.name for tool in tools)
|
|
||||||
tool_names = ", ".join('"' + name + '"' for name in unique_tool_names)
|
|
||||||
format_instructions = format_instructions.format(tool_names=tool_names)
|
|
||||||
template = "\n\n".join([prefix, formatted_tools, format_instructions, suffix])
|
|
||||||
if input_variables is None:
|
|
||||||
input_variables = ["input", "agent_scratchpad"]
|
|
||||||
_memory_prompts = memory_prompts or []
|
|
||||||
messages = [
|
|
||||||
SystemMessagePromptTemplate.from_template(template),
|
|
||||||
*_memory_prompts,
|
|
||||||
HumanMessagePromptTemplate.from_template(human_message_template),
|
|
||||||
]
|
|
||||||
return ChatPromptTemplate(input_variables=input_variables, messages=messages)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def create_completion_prompt(
|
|
||||||
cls,
|
|
||||||
tools: Sequence[BaseTool],
|
|
||||||
prefix: str = PREFIX,
|
|
||||||
format_instructions: str = FORMAT_INSTRUCTIONS,
|
|
||||||
input_variables: Optional[list[str]] = None,
|
|
||||||
) -> PromptTemplate:
|
|
||||||
"""Create prompt in the style of the zero shot agent.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
tools: List of tools the agent will have access to, used to format the
|
|
||||||
prompt.
|
|
||||||
prefix: String to put before the list of tools.
|
|
||||||
input_variables: List of input variables the final prompt will expect.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A PromptTemplate with the template assembled from the pieces here.
|
|
||||||
"""
|
|
||||||
suffix = """Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation:.
|
|
||||||
Question: {input}
|
|
||||||
Thought: {agent_scratchpad}
|
|
||||||
"""
|
|
||||||
|
|
||||||
tool_strings = "\n".join([f"{tool.name}: {tool.description}" for tool in tools])
|
|
||||||
tool_names = ", ".join([tool.name for tool in tools])
|
|
||||||
format_instructions = format_instructions.format(tool_names=tool_names)
|
|
||||||
template = "\n\n".join([prefix, tool_strings, format_instructions, suffix])
|
|
||||||
if input_variables is None:
|
|
||||||
input_variables = ["input", "agent_scratchpad"]
|
|
||||||
return PromptTemplate(template=template, input_variables=input_variables)
|
|
||||||
|
|
||||||
def _construct_scratchpad(
|
|
||||||
self, intermediate_steps: list[tuple[AgentAction, str]]
|
|
||||||
) -> str:
|
|
||||||
agent_scratchpad = ""
|
|
||||||
for action, observation in intermediate_steps:
|
|
||||||
agent_scratchpad += action.log
|
|
||||||
agent_scratchpad += f"\n{self.observation_prefix}{observation}\n{self.llm_prefix}"
|
|
||||||
|
|
||||||
if not isinstance(agent_scratchpad, str):
|
|
||||||
raise ValueError("agent_scratchpad should be of type string.")
|
|
||||||
if agent_scratchpad:
|
|
||||||
llm_chain = cast(LLMChain, self.llm_chain)
|
|
||||||
if llm_chain.model_config.mode == "chat":
|
|
||||||
return (
|
|
||||||
f"This was your previous work "
|
|
||||||
f"(but I haven't seen any of it! I only see what "
|
|
||||||
f"you return as final answer):\n{agent_scratchpad}"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
return agent_scratchpad
|
|
||||||
else:
|
|
||||||
return agent_scratchpad
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_llm_and_tools(
|
|
||||||
cls,
|
|
||||||
model_config: ModelConfigWithCredentialsEntity,
|
|
||||||
tools: Sequence[BaseTool],
|
|
||||||
callback_manager: Optional[BaseCallbackManager] = None,
|
|
||||||
output_parser: Optional[AgentOutputParser] = None,
|
|
||||||
prefix: str = PREFIX,
|
|
||||||
suffix: str = SUFFIX,
|
|
||||||
human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
|
|
||||||
format_instructions: str = FORMAT_INSTRUCTIONS,
|
|
||||||
input_variables: Optional[list[str]] = None,
|
|
||||||
memory_prompts: Optional[list[BasePromptTemplate]] = None,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> Agent:
|
|
||||||
"""Construct an agent from an LLM and tools."""
|
|
||||||
cls._validate_tools(tools)
|
|
||||||
if model_config.mode == "chat":
|
|
||||||
prompt = cls.create_prompt(
|
|
||||||
tools,
|
|
||||||
prefix=prefix,
|
|
||||||
suffix=suffix,
|
|
||||||
human_message_template=human_message_template,
|
|
||||||
format_instructions=format_instructions,
|
|
||||||
input_variables=input_variables,
|
|
||||||
memory_prompts=memory_prompts,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
prompt = cls.create_completion_prompt(
|
|
||||||
tools,
|
|
||||||
prefix=prefix,
|
|
||||||
format_instructions=format_instructions,
|
|
||||||
input_variables=input_variables
|
|
||||||
)
|
|
||||||
|
|
||||||
llm_chain = LLMChain(
|
|
||||||
model_config=model_config,
|
|
||||||
prompt=prompt,
|
|
||||||
callback_manager=callback_manager,
|
|
||||||
parameters={
|
|
||||||
'temperature': 0.2,
|
|
||||||
'top_p': 0.3,
|
|
||||||
'max_tokens': 1500
|
|
||||||
}
|
|
||||||
)
|
|
||||||
tool_names = [tool.name for tool in tools]
|
|
||||||
_output_parser = output_parser
|
|
||||||
return cls(
|
|
||||||
llm_chain=llm_chain,
|
|
||||||
allowed_tools=tool_names,
|
|
||||||
output_parser=_output_parser,
|
|
||||||
dataset_tools=tools,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
@ -1,117 +0,0 @@
|
|||||||
import logging
|
|
||||||
from typing import Optional, Union
|
|
||||||
|
|
||||||
from langchain.agents import AgentExecutor as LCAgentExecutor
|
|
||||||
from langchain.agents import BaseMultiActionAgent, BaseSingleActionAgent
|
|
||||||
from langchain.callbacks.manager import Callbacks
|
|
||||||
from langchain.tools import BaseTool
|
|
||||||
from pydantic import BaseModel, Extra
|
|
||||||
|
|
||||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
|
||||||
from core.entities.agent_entities import PlanningStrategy
|
|
||||||
from core.entities.message_entities import prompt_messages_to_lc_messages
|
|
||||||
from core.helper import moderation
|
|
||||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
|
||||||
from core.model_runtime.errors.invoke import InvokeError
|
|
||||||
from core.rag.retrieval.agent.multi_dataset_router_agent import MultiDatasetRouterAgent
|
|
||||||
from core.rag.retrieval.agent.output_parser.structured_chat import StructuredChatOutputParser
|
|
||||||
from core.rag.retrieval.agent.structed_multi_dataset_router_agent import StructuredMultiDatasetRouterAgent
|
|
||||||
from core.tools.tool.dataset_retriever.dataset_multi_retriever_tool import DatasetMultiRetrieverTool
|
|
||||||
from core.tools.tool.dataset_retriever.dataset_retriever_tool import DatasetRetrieverTool
|
|
||||||
|
|
||||||
|
|
||||||
class AgentConfiguration(BaseModel):
|
|
||||||
strategy: PlanningStrategy
|
|
||||||
model_config: ModelConfigWithCredentialsEntity
|
|
||||||
tools: list[BaseTool]
|
|
||||||
summary_model_config: Optional[ModelConfigWithCredentialsEntity] = None
|
|
||||||
memory: Optional[TokenBufferMemory] = None
|
|
||||||
callbacks: Callbacks = None
|
|
||||||
max_iterations: int = 6
|
|
||||||
max_execution_time: Optional[float] = None
|
|
||||||
early_stopping_method: str = "generate"
|
|
||||||
# `generate` will continue to complete the last inference after reaching the iteration limit or request time limit
|
|
||||||
|
|
||||||
class Config:
|
|
||||||
"""Configuration for this pydantic object."""
|
|
||||||
|
|
||||||
extra = Extra.forbid
|
|
||||||
arbitrary_types_allowed = True
|
|
||||||
|
|
||||||
|
|
||||||
class AgentExecuteResult(BaseModel):
|
|
||||||
strategy: PlanningStrategy
|
|
||||||
output: Optional[str]
|
|
||||||
configuration: AgentConfiguration
|
|
||||||
|
|
||||||
|
|
||||||
class AgentExecutor:
|
|
||||||
def __init__(self, configuration: AgentConfiguration):
|
|
||||||
self.configuration = configuration
|
|
||||||
self.agent = self._init_agent()
|
|
||||||
|
|
||||||
def _init_agent(self) -> Union[BaseSingleActionAgent, BaseMultiActionAgent]:
|
|
||||||
if self.configuration.strategy == PlanningStrategy.ROUTER:
|
|
||||||
self.configuration.tools = [t for t in self.configuration.tools
|
|
||||||
if isinstance(t, DatasetRetrieverTool)
|
|
||||||
or isinstance(t, DatasetMultiRetrieverTool)]
|
|
||||||
agent = MultiDatasetRouterAgent.from_llm_and_tools(
|
|
||||||
model_config=self.configuration.model_config,
|
|
||||||
tools=self.configuration.tools,
|
|
||||||
extra_prompt_messages=prompt_messages_to_lc_messages(self.configuration.memory.get_history_prompt_messages())
|
|
||||||
if self.configuration.memory else None,
|
|
||||||
verbose=True
|
|
||||||
)
|
|
||||||
elif self.configuration.strategy == PlanningStrategy.REACT_ROUTER:
|
|
||||||
self.configuration.tools = [t for t in self.configuration.tools
|
|
||||||
if isinstance(t, DatasetRetrieverTool)
|
|
||||||
or isinstance(t, DatasetMultiRetrieverTool)]
|
|
||||||
agent = StructuredMultiDatasetRouterAgent.from_llm_and_tools(
|
|
||||||
model_config=self.configuration.model_config,
|
|
||||||
tools=self.configuration.tools,
|
|
||||||
output_parser=StructuredChatOutputParser(),
|
|
||||||
verbose=True
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
raise NotImplementedError(f"Unknown Agent Strategy: {self.configuration.strategy}")
|
|
||||||
|
|
||||||
return agent
|
|
||||||
|
|
||||||
def should_use_agent(self, query: str) -> bool:
|
|
||||||
return self.agent.should_use_agent(query)
|
|
||||||
|
|
||||||
def run(self, query: str) -> AgentExecuteResult:
|
|
||||||
moderation_result = moderation.check_moderation(
|
|
||||||
self.configuration.model_config,
|
|
||||||
query
|
|
||||||
)
|
|
||||||
|
|
||||||
if moderation_result:
|
|
||||||
return AgentExecuteResult(
|
|
||||||
output="I apologize for any confusion, but I'm an AI assistant to be helpful, harmless, and honest.",
|
|
||||||
strategy=self.configuration.strategy,
|
|
||||||
configuration=self.configuration
|
|
||||||
)
|
|
||||||
|
|
||||||
agent_executor = LCAgentExecutor.from_agent_and_tools(
|
|
||||||
agent=self.agent,
|
|
||||||
tools=self.configuration.tools,
|
|
||||||
max_iterations=self.configuration.max_iterations,
|
|
||||||
max_execution_time=self.configuration.max_execution_time,
|
|
||||||
early_stopping_method=self.configuration.early_stopping_method,
|
|
||||||
callbacks=self.configuration.callbacks
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
output = agent_executor.run(input=query)
|
|
||||||
except InvokeError as ex:
|
|
||||||
raise ex
|
|
||||||
except Exception as ex:
|
|
||||||
logging.exception("agent_executor run failed")
|
|
||||||
output = None
|
|
||||||
|
|
||||||
return AgentExecuteResult(
|
|
||||||
output=output,
|
|
||||||
strategy=self.configuration.strategy,
|
|
||||||
configuration=self.configuration
|
|
||||||
)
|
|
||||||
Loading…
Reference in New Issue