ensure usage is present in process_data for LLM nodes

pull/21766/head
Davide Delbianco 11 months ago
parent 2270e41ec8
commit e8193afdef
No known key found for this signature in database
GPG Key ID: 3C00412F2A31305E

@ -221,15 +221,6 @@ class LLMNode(BaseNode[LLMNodeData]):
jinja2_variables=self.node_data.prompt_config.jinja2_variables, jinja2_variables=self.node_data.prompt_config.jinja2_variables,
) )
process_data = {
"model_mode": model_config.mode,
"prompts": PromptMessageUtil.prompt_messages_to_prompt_for_saving(
model_mode=model_config.mode, prompt_messages=prompt_messages
),
"model_provider": model_config.provider,
"model_name": model_config.model,
}
# handle invoke result # handle invoke result
generator = self._invoke_llm( generator = self._invoke_llm(
node_data_model=self.node_data.model, node_data_model=self.node_data.model,
@ -253,6 +244,17 @@ class LLMNode(BaseNode[LLMNodeData]):
elif isinstance(event, LLMStructuredOutput): elif isinstance(event, LLMStructuredOutput):
structured_output = event structured_output = event
process_data = {
"model_mode": model_config.mode,
"prompts": PromptMessageUtil.prompt_messages_to_prompt_for_saving(
model_mode=model_config.mode, prompt_messages=prompt_messages
),
"usage": jsonable_encoder(usage),
"finish_reason": finish_reason,
"model_provider": model_config.provider,
"model_name": model_config.model,
}
outputs = {"text": result_text, "usage": jsonable_encoder(usage), "finish_reason": finish_reason} outputs = {"text": result_text, "usage": jsonable_encoder(usage), "finish_reason": finish_reason}
if structured_output: if structured_output:
outputs["structured_output"] = structured_output.structured_output outputs["structured_output"] = structured_output.structured_output

@ -19,24 +19,16 @@ from core.model_runtime.entities.message_entities import (
UserPromptMessage, UserPromptMessage,
) )
from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey
from core.model_runtime.model_providers.__base.large_language_model import ( from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
LargeLanguageModel,
)
from core.model_runtime.utils.encoders import jsonable_encoder from core.model_runtime.utils.encoders import jsonable_encoder
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import ( from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
ChatModelMessage,
CompletionModelPromptTemplate,
)
from core.prompt.simple_prompt_transform import ModelMode from core.prompt.simple_prompt_transform import ModelMode
from core.prompt.utils.prompt_message_util import PromptMessageUtil from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.variables.types import SegmentType from core.variables.types import SegmentType
from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import ( from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
WorkflowNodeExecutionMetadataKey,
WorkflowNodeExecutionStatus,
)
from core.workflow.nodes.base.node import BaseNode from core.workflow.nodes.base.node import BaseNode
from core.workflow.nodes.enums import NodeType from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.llm import ModelConfig, llm_utils from core.workflow.nodes.llm import ModelConfig, llm_utils
@ -112,10 +104,7 @@ class ParameterExtractorNode(BaseNode):
"model": { "model": {
"prompt_templates": { "prompt_templates": {
"completion_model": { "completion_model": {
"conversation_histories_role": { "conversation_histories_role": {"user_prefix": "Human", "assistant_prefix": "Assistant"},
"user_prefix": "Human",
"assistant_prefix": "Assistant",
},
"stop": ["Human:"], "stop": ["Human:"],
} }
} }
@ -205,7 +194,8 @@ class ParameterExtractorNode(BaseNode):
"prompts": PromptMessageUtil.prompt_messages_to_prompt_for_saving( "prompts": PromptMessageUtil.prompt_messages_to_prompt_for_saving(
model_mode=model_config.mode, prompt_messages=prompt_messages model_mode=model_config.mode, prompt_messages=prompt_messages
), ),
"function": ({} if not prompt_message_tools else jsonable_encoder(prompt_message_tools[0])), "usage": None,
"function": {} if not prompt_message_tools else jsonable_encoder(prompt_message_tools[0]),
"tool_call": None, "tool_call": None,
"model_provider": model_config.provider, "model_provider": model_config.provider,
"model_name": model_config.model, "model_name": model_config.model,
@ -219,6 +209,7 @@ class ParameterExtractorNode(BaseNode):
tools=prompt_message_tools, tools=prompt_message_tools,
stop=model_config.stop, stop=model_config.stop,
) )
process_data["usage"] = jsonable_encoder(usage)
process_data["tool_call"] = jsonable_encoder(tool_call) process_data["tool_call"] = jsonable_encoder(tool_call)
process_data["llm_text"] = text process_data["llm_text"] = text
except ParameterExtractorNodeError as e: except ParameterExtractorNodeError as e:
@ -235,11 +226,7 @@ class ParameterExtractorNode(BaseNode):
status=WorkflowNodeExecutionStatus.FAILED, status=WorkflowNodeExecutionStatus.FAILED,
inputs=inputs, inputs=inputs,
process_data=process_data, process_data=process_data,
outputs={ outputs={"__is_success": 0, "__reason": "Failed to invoke model", "__error": str(e)},
"__is_success": 0,
"__reason": "Failed to invoke model",
"__error": str(e),
},
error=str(e), error=str(e),
metadata={}, metadata={},
) )
@ -377,8 +364,7 @@ class ParameterExtractorNode(BaseNode):
], ],
), ),
ToolPromptMessage( ToolPromptMessage(
content="Great! You have called the function with the correct parameters.", content="Great! You have called the function with the correct parameters.", tool_call_id=id
tool_call_id=id,
), ),
AssistantPromptMessage( AssistantPromptMessage(
content="I have extracted the parameters, let's move on.", content="I have extracted the parameters, let's move on.",
@ -452,18 +438,10 @@ class ParameterExtractorNode(BaseNode):
""" """
prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True) prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
rest_token = self._calculate_rest_token( rest_token = self._calculate_rest_token(
node_data=node_data, node_data=node_data, query=query, variable_pool=variable_pool, model_config=model_config, context=""
query=query,
variable_pool=variable_pool,
model_config=model_config,
context="",
) )
prompt_template = self._get_prompt_engineering_prompt_template( prompt_template = self._get_prompt_engineering_prompt_template(
node_data=node_data, node_data=node_data, query=query, variable_pool=variable_pool, memory=memory, max_token_limit=rest_token
query=query,
variable_pool=variable_pool,
memory=memory,
max_token_limit=rest_token,
) )
prompt_messages = prompt_transform.get_prompt( prompt_messages = prompt_transform.get_prompt(
prompt_template=prompt_template, prompt_template=prompt_template,
@ -494,11 +472,7 @@ class ParameterExtractorNode(BaseNode):
""" """
prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True) prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
rest_token = self._calculate_rest_token( rest_token = self._calculate_rest_token(
node_data=node_data, node_data=node_data, query=query, variable_pool=variable_pool, model_config=model_config, context=""
query=query,
variable_pool=variable_pool,
model_config=model_config,
context="",
) )
prompt_template = self._get_prompt_engineering_prompt_template( prompt_template = self._get_prompt_engineering_prompt_template(
node_data=node_data, node_data=node_data,
@ -727,8 +701,7 @@ class ParameterExtractorNode(BaseNode):
if memory and node_data.memory and node_data.memory.window: if memory and node_data.memory and node_data.memory.window:
memory_str = memory.get_history_prompt_text( memory_str = memory.get_history_prompt_text(
max_token_limit=max_token_limit, max_token_limit=max_token_limit, message_limit=node_data.memory.window.size
message_limit=node_data.memory.window.size,
) )
if model_mode == ModelMode.CHAT: if model_mode == ModelMode.CHAT:
system_prompt_messages = ChatModelMessage( system_prompt_messages = ChatModelMessage(
@ -755,8 +728,7 @@ class ParameterExtractorNode(BaseNode):
if memory and node_data.memory and node_data.memory.window: if memory and node_data.memory and node_data.memory.window:
memory_str = memory.get_history_prompt_text( memory_str = memory.get_history_prompt_text(
max_token_limit=max_token_limit, max_token_limit=max_token_limit, message_limit=node_data.memory.window.size
message_limit=node_data.memory.window.size,
) )
if model_mode == ModelMode.CHAT: if model_mode == ModelMode.CHAT:
system_prompt_messages = ChatModelMessage( system_prompt_messages = ChatModelMessage(
@ -795,10 +767,7 @@ class ParameterExtractorNode(BaseNode):
if not model_schema: if not model_schema:
raise ModelSchemaNotFoundError("Model schema not found") raise ModelSchemaNotFoundError("Model schema not found")
if set(model_schema.features or []) & { if set(model_schema.features or []) & {ModelFeature.MULTI_TOOL_CALL, ModelFeature.MULTI_TOOL_CALL}:
ModelFeature.MULTI_TOOL_CALL,
ModelFeature.MULTI_TOOL_CALL,
}:
prompt_template = self._get_function_calling_prompt_template(node_data, query, variable_pool, None, 2000) prompt_template = self._get_function_calling_prompt_template(node_data, query, variable_pool, None, 2000)
else: else:
prompt_template = self._get_prompt_engineering_prompt_template(node_data, query, variable_pool, None, 2000) prompt_template = self._get_prompt_engineering_prompt_template(node_data, query, variable_pool, None, 2000)

@ -11,10 +11,7 @@ from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.simple_prompt_transform import ModelMode from core.prompt.simple_prompt_transform import ModelMode
from core.prompt.utils.prompt_message_util import PromptMessageUtil from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import ( from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
WorkflowNodeExecutionMetadataKey,
WorkflowNodeExecutionStatus,
)
from core.workflow.nodes.enums import NodeType from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.event import ModelInvokeCompletedEvent from core.workflow.nodes.event import ModelInvokeCompletedEvent
from core.workflow.nodes.llm import ( from core.workflow.nodes.llm import (
@ -143,6 +140,8 @@ class QuestionClassifierNode(LLMNode):
"prompts": PromptMessageUtil.prompt_messages_to_prompt_for_saving( "prompts": PromptMessageUtil.prompt_messages_to_prompt_for_saving(
model_mode=model_config.mode, prompt_messages=prompt_messages model_mode=model_config.mode, prompt_messages=prompt_messages
), ),
"usage": jsonable_encoder(usage),
"finish_reason": finish_reason,
"model_provider": model_config.provider, "model_provider": model_config.provider,
"model_name": model_config.model, "model_name": model_config.model,
} }
@ -150,7 +149,6 @@ class QuestionClassifierNode(LLMNode):
"class_name": category_name, "class_name": category_name,
"class_id": category_id, "class_id": category_id,
"usage": jsonable_encoder(usage), "usage": jsonable_encoder(usage),
"finish_reason": finish_reason,
} }
return NodeRunResult( return NodeRunResult(
@ -240,8 +238,7 @@ class QuestionClassifierNode(LLMNode):
model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE) model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
if model_context_tokens: if model_context_tokens:
model_instance = ModelInstance( model_instance = ModelInstance(
provider_model_bundle=model_config.provider_model_bundle, provider_model_bundle=model_config.provider_model_bundle, model=model_config.model
model=model_config.model,
) )
curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages) curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages)
@ -280,13 +277,12 @@ class QuestionClassifierNode(LLMNode):
if memory: if memory:
memory_str = memory.get_history_prompt_text( memory_str = memory.get_history_prompt_text(
max_token_limit=max_token_limit, max_token_limit=max_token_limit,
message_limit=(node_data.memory.window.size if node_data.memory and node_data.memory.window else None), message_limit=node_data.memory.window.size if node_data.memory and node_data.memory.window else None,
) )
prompt_messages: list[LLMNodeChatModelMessage] = [] prompt_messages: list[LLMNodeChatModelMessage] = []
if model_mode == ModelMode.CHAT: if model_mode == ModelMode.CHAT:
system_prompt_messages = LLMNodeChatModelMessage( system_prompt_messages = LLMNodeChatModelMessage(
role=PromptMessageRole.SYSTEM, role=PromptMessageRole.SYSTEM, text=QUESTION_CLASSIFIER_SYSTEM_PROMPT.format(histories=memory_str)
text=QUESTION_CLASSIFIER_SYSTEM_PROMPT.format(histories=memory_str),
) )
prompt_messages.append(system_prompt_messages) prompt_messages.append(system_prompt_messages)
user_prompt_message_1 = LLMNodeChatModelMessage( user_prompt_message_1 = LLMNodeChatModelMessage(
@ -294,8 +290,7 @@ class QuestionClassifierNode(LLMNode):
) )
prompt_messages.append(user_prompt_message_1) prompt_messages.append(user_prompt_message_1)
assistant_prompt_message_1 = LLMNodeChatModelMessage( assistant_prompt_message_1 = LLMNodeChatModelMessage(
role=PromptMessageRole.ASSISTANT, role=PromptMessageRole.ASSISTANT, text=QUESTION_CLASSIFIER_ASSISTANT_PROMPT_1
text=QUESTION_CLASSIFIER_ASSISTANT_PROMPT_1,
) )
prompt_messages.append(assistant_prompt_message_1) prompt_messages.append(assistant_prompt_message_1)
user_prompt_message_2 = LLMNodeChatModelMessage( user_prompt_message_2 = LLMNodeChatModelMessage(
@ -303,8 +298,7 @@ class QuestionClassifierNode(LLMNode):
) )
prompt_messages.append(user_prompt_message_2) prompt_messages.append(user_prompt_message_2)
assistant_prompt_message_2 = LLMNodeChatModelMessage( assistant_prompt_message_2 = LLMNodeChatModelMessage(
role=PromptMessageRole.ASSISTANT, role=PromptMessageRole.ASSISTANT, text=QUESTION_CLASSIFIER_ASSISTANT_PROMPT_2
text=QUESTION_CLASSIFIER_ASSISTANT_PROMPT_2,
) )
prompt_messages.append(assistant_prompt_message_2) prompt_messages.append(assistant_prompt_message_2)
user_prompt_message_3 = LLMNodeChatModelMessage( user_prompt_message_3 = LLMNodeChatModelMessage(

@ -520,7 +520,6 @@ class DraftVariableSaver:
_EXCLUDE_VARIABLE_NAMES_MAPPING: dict[NodeType, frozenset[str]] = { _EXCLUDE_VARIABLE_NAMES_MAPPING: dict[NodeType, frozenset[str]] = {
NodeType.LLM: frozenset(["finish_reason"]), NodeType.LLM: frozenset(["finish_reason"]),
NodeType.LOOP: frozenset(["loop_round"]), NodeType.LOOP: frozenset(["loop_round"]),
NodeType.QUESTION_CLASSIFIER: frozenset(["finish_reason"]),
} }
# Database session used for persisting draft variables. # Database session used for persisting draft variables.

@ -555,7 +555,7 @@ export const PARAMETER_EXTRACTOR_COMMON_STRUCT: Var[] = [
type: VarType.string, type: VarType.string,
}, },
{ {
variable: 'usage', variable: '__usage',
type: VarType.object, type: VarType.object,
}, },
] ]

Loading…
Cancel
Save