From e8193afdef3d47ddb4182d13958fb365a9383cd6 Mon Sep 17 00:00:00 2001 From: Davide Delbianco Date: Wed, 2 Jul 2025 09:10:06 +0200 Subject: [PATCH] ensure usage is present in process_data for LLM nodes --- api/core/workflow/nodes/llm/node.py | 20 +++--- .../parameter_extractor_node.py | 61 +++++-------------- .../question_classifier_node.py | 22 +++---- .../workflow_draft_variable_service.py | 1 - web/app/components/workflow/constants.ts | 2 +- 5 files changed, 35 insertions(+), 71 deletions(-) diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index b5225ce548..9bfb402dc8 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -221,15 +221,6 @@ class LLMNode(BaseNode[LLMNodeData]): jinja2_variables=self.node_data.prompt_config.jinja2_variables, ) - process_data = { - "model_mode": model_config.mode, - "prompts": PromptMessageUtil.prompt_messages_to_prompt_for_saving( - model_mode=model_config.mode, prompt_messages=prompt_messages - ), - "model_provider": model_config.provider, - "model_name": model_config.model, - } - # handle invoke result generator = self._invoke_llm( node_data_model=self.node_data.model, @@ -253,6 +244,17 @@ class LLMNode(BaseNode[LLMNodeData]): elif isinstance(event, LLMStructuredOutput): structured_output = event + process_data = { + "model_mode": model_config.mode, + "prompts": PromptMessageUtil.prompt_messages_to_prompt_for_saving( + model_mode=model_config.mode, prompt_messages=prompt_messages + ), + "usage": jsonable_encoder(usage), + "finish_reason": finish_reason, + "model_provider": model_config.provider, + "model_name": model_config.model, + } + outputs = {"text": result_text, "usage": jsonable_encoder(usage), "finish_reason": finish_reason} if structured_output: outputs["structured_output"] = structured_output.structured_output diff --git a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py index cf2a78ee54..25a534256b 100644 --- a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py +++ b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py @@ -19,24 +19,16 @@ from core.model_runtime.entities.message_entities import ( UserPromptMessage, ) from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey -from core.model_runtime.model_providers.__base.large_language_model import ( - LargeLanguageModel, -) +from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.utils.encoders import jsonable_encoder from core.prompt.advanced_prompt_transform import AdvancedPromptTransform -from core.prompt.entities.advanced_prompt_entities import ( - ChatModelMessage, - CompletionModelPromptTemplate, -) +from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate from core.prompt.simple_prompt_transform import ModelMode from core.prompt.utils.prompt_message_util import PromptMessageUtil from core.variables.types import SegmentType from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.variable_pool import VariablePool -from core.workflow.entities.workflow_node_execution import ( - WorkflowNodeExecutionMetadataKey, - WorkflowNodeExecutionStatus, -) +from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus from core.workflow.nodes.base.node import BaseNode from core.workflow.nodes.enums import NodeType from core.workflow.nodes.llm import ModelConfig, llm_utils @@ -112,10 +104,7 @@ class ParameterExtractorNode(BaseNode): "model": { "prompt_templates": { "completion_model": { - "conversation_histories_role": { - "user_prefix": "Human", - "assistant_prefix": "Assistant", - }, + "conversation_histories_role": {"user_prefix": "Human", "assistant_prefix": "Assistant"}, "stop": ["Human:"], } } @@ -205,7 +194,8 @@ class ParameterExtractorNode(BaseNode): "prompts": PromptMessageUtil.prompt_messages_to_prompt_for_saving( model_mode=model_config.mode, prompt_messages=prompt_messages ), - "function": ({} if not prompt_message_tools else jsonable_encoder(prompt_message_tools[0])), + "usage": None, + "function": {} if not prompt_message_tools else jsonable_encoder(prompt_message_tools[0]), "tool_call": None, "model_provider": model_config.provider, "model_name": model_config.model, @@ -219,6 +209,7 @@ class ParameterExtractorNode(BaseNode): tools=prompt_message_tools, stop=model_config.stop, ) + process_data["usage"] = jsonable_encoder(usage) process_data["tool_call"] = jsonable_encoder(tool_call) process_data["llm_text"] = text except ParameterExtractorNodeError as e: @@ -235,11 +226,7 @@ class ParameterExtractorNode(BaseNode): status=WorkflowNodeExecutionStatus.FAILED, inputs=inputs, process_data=process_data, - outputs={ - "__is_success": 0, - "__reason": "Failed to invoke model", - "__error": str(e), - }, + outputs={"__is_success": 0, "__reason": "Failed to invoke model", "__error": str(e)}, error=str(e), metadata={}, ) @@ -377,8 +364,7 @@ class ParameterExtractorNode(BaseNode): ], ), ToolPromptMessage( - content="Great! You have called the function with the correct parameters.", - tool_call_id=id, + content="Great! You have called the function with the correct parameters.", tool_call_id=id ), AssistantPromptMessage( content="I have extracted the parameters, let's move on.", @@ -452,18 +438,10 @@ class ParameterExtractorNode(BaseNode): """ prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True) rest_token = self._calculate_rest_token( - node_data=node_data, - query=query, - variable_pool=variable_pool, - model_config=model_config, - context="", + node_data=node_data, query=query, variable_pool=variable_pool, model_config=model_config, context="" ) prompt_template = self._get_prompt_engineering_prompt_template( - node_data=node_data, - query=query, - variable_pool=variable_pool, - memory=memory, - max_token_limit=rest_token, + node_data=node_data, query=query, variable_pool=variable_pool, memory=memory, max_token_limit=rest_token ) prompt_messages = prompt_transform.get_prompt( prompt_template=prompt_template, @@ -494,11 +472,7 @@ class ParameterExtractorNode(BaseNode): """ prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True) rest_token = self._calculate_rest_token( - node_data=node_data, - query=query, - variable_pool=variable_pool, - model_config=model_config, - context="", + node_data=node_data, query=query, variable_pool=variable_pool, model_config=model_config, context="" ) prompt_template = self._get_prompt_engineering_prompt_template( node_data=node_data, @@ -727,8 +701,7 @@ class ParameterExtractorNode(BaseNode): if memory and node_data.memory and node_data.memory.window: memory_str = memory.get_history_prompt_text( - max_token_limit=max_token_limit, - message_limit=node_data.memory.window.size, + max_token_limit=max_token_limit, message_limit=node_data.memory.window.size ) if model_mode == ModelMode.CHAT: system_prompt_messages = ChatModelMessage( @@ -755,8 +728,7 @@ class ParameterExtractorNode(BaseNode): if memory and node_data.memory and node_data.memory.window: memory_str = memory.get_history_prompt_text( - max_token_limit=max_token_limit, - message_limit=node_data.memory.window.size, + max_token_limit=max_token_limit, message_limit=node_data.memory.window.size ) if model_mode == ModelMode.CHAT: system_prompt_messages = ChatModelMessage( @@ -795,10 +767,7 @@ class ParameterExtractorNode(BaseNode): if not model_schema: raise ModelSchemaNotFoundError("Model schema not found") - if set(model_schema.features or []) & { - ModelFeature.MULTI_TOOL_CALL, - ModelFeature.MULTI_TOOL_CALL, - }: + if set(model_schema.features or []) & {ModelFeature.MULTI_TOOL_CALL, ModelFeature.MULTI_TOOL_CALL}: prompt_template = self._get_function_calling_prompt_template(node_data, query, variable_pool, None, 2000) else: prompt_template = self._get_prompt_engineering_prompt_template(node_data, query, variable_pool, None, 2000) diff --git a/api/core/workflow/nodes/question_classifier/question_classifier_node.py b/api/core/workflow/nodes/question_classifier/question_classifier_node.py index 43f14188ed..74024ed90c 100644 --- a/api/core/workflow/nodes/question_classifier/question_classifier_node.py +++ b/api/core/workflow/nodes/question_classifier/question_classifier_node.py @@ -11,10 +11,7 @@ from core.prompt.advanced_prompt_transform import AdvancedPromptTransform from core.prompt.simple_prompt_transform import ModelMode from core.prompt.utils.prompt_message_util import PromptMessageUtil from core.workflow.entities.node_entities import NodeRunResult -from core.workflow.entities.workflow_node_execution import ( - WorkflowNodeExecutionMetadataKey, - WorkflowNodeExecutionStatus, -) +from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus from core.workflow.nodes.enums import NodeType from core.workflow.nodes.event import ModelInvokeCompletedEvent from core.workflow.nodes.llm import ( @@ -143,6 +140,8 @@ class QuestionClassifierNode(LLMNode): "prompts": PromptMessageUtil.prompt_messages_to_prompt_for_saving( model_mode=model_config.mode, prompt_messages=prompt_messages ), + "usage": jsonable_encoder(usage), + "finish_reason": finish_reason, "model_provider": model_config.provider, "model_name": model_config.model, } @@ -150,7 +149,6 @@ class QuestionClassifierNode(LLMNode): "class_name": category_name, "class_id": category_id, "usage": jsonable_encoder(usage), - "finish_reason": finish_reason, } return NodeRunResult( @@ -240,8 +238,7 @@ class QuestionClassifierNode(LLMNode): model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE) if model_context_tokens: model_instance = ModelInstance( - provider_model_bundle=model_config.provider_model_bundle, - model=model_config.model, + provider_model_bundle=model_config.provider_model_bundle, model=model_config.model ) curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages) @@ -280,13 +277,12 @@ class QuestionClassifierNode(LLMNode): if memory: memory_str = memory.get_history_prompt_text( max_token_limit=max_token_limit, - message_limit=(node_data.memory.window.size if node_data.memory and node_data.memory.window else None), + message_limit=node_data.memory.window.size if node_data.memory and node_data.memory.window else None, ) prompt_messages: list[LLMNodeChatModelMessage] = [] if model_mode == ModelMode.CHAT: system_prompt_messages = LLMNodeChatModelMessage( - role=PromptMessageRole.SYSTEM, - text=QUESTION_CLASSIFIER_SYSTEM_PROMPT.format(histories=memory_str), + role=PromptMessageRole.SYSTEM, text=QUESTION_CLASSIFIER_SYSTEM_PROMPT.format(histories=memory_str) ) prompt_messages.append(system_prompt_messages) user_prompt_message_1 = LLMNodeChatModelMessage( @@ -294,8 +290,7 @@ class QuestionClassifierNode(LLMNode): ) prompt_messages.append(user_prompt_message_1) assistant_prompt_message_1 = LLMNodeChatModelMessage( - role=PromptMessageRole.ASSISTANT, - text=QUESTION_CLASSIFIER_ASSISTANT_PROMPT_1, + role=PromptMessageRole.ASSISTANT, text=QUESTION_CLASSIFIER_ASSISTANT_PROMPT_1 ) prompt_messages.append(assistant_prompt_message_1) user_prompt_message_2 = LLMNodeChatModelMessage( @@ -303,8 +298,7 @@ class QuestionClassifierNode(LLMNode): ) prompt_messages.append(user_prompt_message_2) assistant_prompt_message_2 = LLMNodeChatModelMessage( - role=PromptMessageRole.ASSISTANT, - text=QUESTION_CLASSIFIER_ASSISTANT_PROMPT_2, + role=PromptMessageRole.ASSISTANT, text=QUESTION_CLASSIFIER_ASSISTANT_PROMPT_2 ) prompt_messages.append(assistant_prompt_message_2) user_prompt_message_3 = LLMNodeChatModelMessage( diff --git a/api/services/workflow_draft_variable_service.py b/api/services/workflow_draft_variable_service.py index e999bda73f..44fd72b5e4 100644 --- a/api/services/workflow_draft_variable_service.py +++ b/api/services/workflow_draft_variable_service.py @@ -520,7 +520,6 @@ class DraftVariableSaver: _EXCLUDE_VARIABLE_NAMES_MAPPING: dict[NodeType, frozenset[str]] = { NodeType.LLM: frozenset(["finish_reason"]), NodeType.LOOP: frozenset(["loop_round"]), - NodeType.QUESTION_CLASSIFIER: frozenset(["finish_reason"]), } # Database session used for persisting draft variables. diff --git a/web/app/components/workflow/constants.ts b/web/app/components/workflow/constants.ts index bc38a05359..0ef4dc9dea 100644 --- a/web/app/components/workflow/constants.ts +++ b/web/app/components/workflow/constants.ts @@ -555,7 +555,7 @@ export const PARAMETER_EXTRACTOR_COMMON_STRUCT: Var[] = [ type: VarType.string, }, { - variable: 'usage', + variable: '__usage', type: VarType.object, }, ]