|
|
|
|
@ -22,7 +22,15 @@ from core.model_runtime.utils.encoders import jsonable_encoder
|
|
|
|
|
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
|
|
|
|
|
from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
|
|
|
|
|
from core.prompt.utils.prompt_message_util import PromptMessageUtil
|
|
|
|
|
from core.variables import ArrayAnySegment, ArrayFileSegment, FileSegment, NoneSegment
|
|
|
|
|
from core.variables import (
|
|
|
|
|
ArrayAnySegment,
|
|
|
|
|
ArrayFileSegment,
|
|
|
|
|
ArraySegment,
|
|
|
|
|
FileSegment,
|
|
|
|
|
NoneSegment,
|
|
|
|
|
ObjectSegment,
|
|
|
|
|
StringSegment,
|
|
|
|
|
)
|
|
|
|
|
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
|
|
|
|
|
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
|
|
|
|
|
from core.workflow.enums import SystemVariableKey
|
|
|
|
|
@ -263,50 +271,44 @@ class LLMNode(BaseNode[LLMNodeData]):
|
|
|
|
|
return variables
|
|
|
|
|
|
|
|
|
|
for variable_selector in node_data.prompt_config.jinja2_variables or []:
|
|
|
|
|
variable = variable_selector.variable
|
|
|
|
|
value = self.graph_runtime_state.variable_pool.get_any(variable_selector.value_selector)
|
|
|
|
|
variable_name = variable_selector.variable
|
|
|
|
|
variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector)
|
|
|
|
|
if variable is None:
|
|
|
|
|
raise ValueError(f"Variable {variable_selector.variable} not found")
|
|
|
|
|
|
|
|
|
|
def parse_dict(d: dict) -> str:
|
|
|
|
|
def parse_dict(input_dict: Mapping[str, Any]) -> str:
|
|
|
|
|
"""
|
|
|
|
|
Parse dict into string
|
|
|
|
|
"""
|
|
|
|
|
# check if it's a context structure
|
|
|
|
|
if "metadata" in d and "_source" in d["metadata"] and "content" in d:
|
|
|
|
|
return d["content"]
|
|
|
|
|
if "metadata" in input_dict and "_source" in input_dict["metadata"] and "content" in input_dict:
|
|
|
|
|
return input_dict["content"]
|
|
|
|
|
|
|
|
|
|
# else, parse the dict
|
|
|
|
|
try:
|
|
|
|
|
return json.dumps(d, ensure_ascii=False)
|
|
|
|
|
return json.dumps(input_dict, ensure_ascii=False)
|
|
|
|
|
except Exception:
|
|
|
|
|
return str(d)
|
|
|
|
|
return str(input_dict)
|
|
|
|
|
|
|
|
|
|
if isinstance(value, str):
|
|
|
|
|
value = value
|
|
|
|
|
elif isinstance(value, list):
|
|
|
|
|
if isinstance(variable, ArraySegment):
|
|
|
|
|
result = ""
|
|
|
|
|
for item in value:
|
|
|
|
|
for item in variable.value:
|
|
|
|
|
if isinstance(item, dict):
|
|
|
|
|
result += parse_dict(item)
|
|
|
|
|
elif isinstance(item, str):
|
|
|
|
|
result += item
|
|
|
|
|
elif isinstance(item, int | float):
|
|
|
|
|
result += str(item)
|
|
|
|
|
else:
|
|
|
|
|
result += str(item)
|
|
|
|
|
result += "\n"
|
|
|
|
|
value = result.strip()
|
|
|
|
|
elif isinstance(value, dict):
|
|
|
|
|
value = parse_dict(value)
|
|
|
|
|
elif isinstance(value, int | float):
|
|
|
|
|
value = str(value)
|
|
|
|
|
elif isinstance(variable, ObjectSegment):
|
|
|
|
|
value = parse_dict(variable.value)
|
|
|
|
|
else:
|
|
|
|
|
value = str(value)
|
|
|
|
|
value = variable.text
|
|
|
|
|
|
|
|
|
|
variables[variable] = value
|
|
|
|
|
variables[variable_name] = value
|
|
|
|
|
|
|
|
|
|
return variables
|
|
|
|
|
|
|
|
|
|
def _fetch_inputs(self, node_data: LLMNodeData) -> dict[str, str]:
|
|
|
|
|
def _fetch_inputs(self, node_data: LLMNodeData) -> dict[str, Any]:
|
|
|
|
|
inputs = {}
|
|
|
|
|
prompt_template = node_data.prompt_template
|
|
|
|
|
|
|
|
|
|
@ -363,14 +365,14 @@ class LLMNode(BaseNode[LLMNodeData]):
|
|
|
|
|
if not node_data.context.variable_selector:
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
context_value = self.graph_runtime_state.variable_pool.get_any(node_data.context.variable_selector)
|
|
|
|
|
if context_value:
|
|
|
|
|
if isinstance(context_value, str):
|
|
|
|
|
yield RunRetrieverResourceEvent(retriever_resources=[], context=context_value)
|
|
|
|
|
elif isinstance(context_value, list):
|
|
|
|
|
context_value_variable = self.graph_runtime_state.variable_pool.get(node_data.context.variable_selector)
|
|
|
|
|
if context_value_variable:
|
|
|
|
|
if isinstance(context_value_variable, StringSegment):
|
|
|
|
|
yield RunRetrieverResourceEvent(retriever_resources=[], context=context_value_variable.value)
|
|
|
|
|
elif isinstance(context_value_variable, ArraySegment):
|
|
|
|
|
context_str = ""
|
|
|
|
|
original_retriever_resource = []
|
|
|
|
|
for item in context_value:
|
|
|
|
|
for item in context_value_variable.value:
|
|
|
|
|
if isinstance(item, str):
|
|
|
|
|
context_str += item + "\n"
|
|
|
|
|
else:
|
|
|
|
|
@ -484,11 +486,12 @@ class LLMNode(BaseNode[LLMNodeData]):
|
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
# get conversation id
|
|
|
|
|
conversation_id = self.graph_runtime_state.variable_pool.get_any(
|
|
|
|
|
conversation_id_variable = self.graph_runtime_state.variable_pool.get(
|
|
|
|
|
["sys", SystemVariableKey.CONVERSATION_ID.value]
|
|
|
|
|
)
|
|
|
|
|
if conversation_id is None:
|
|
|
|
|
if not isinstance(conversation_id_variable, StringSegment):
|
|
|
|
|
return None
|
|
|
|
|
conversation_id = conversation_id_variable.value
|
|
|
|
|
|
|
|
|
|
# get conversation
|
|
|
|
|
conversation = (
|
|
|
|
|
|