Ensure `usage` is present in `process_data` for LLM nodes

pull/21766/head
Davide Delbianco 11 months ago
parent 2270e41ec8
commit e8193afdef
No known key found for this signature in database
GPG Key ID: 3C00412F2A31305E

@ -221,15 +221,6 @@ class LLMNode(BaseNode[LLMNodeData]):
jinja2_variables=self.node_data.prompt_config.jinja2_variables,
)
process_data = {
"model_mode": model_config.mode,
"prompts": PromptMessageUtil.prompt_messages_to_prompt_for_saving(
model_mode=model_config.mode, prompt_messages=prompt_messages
),
"model_provider": model_config.provider,
"model_name": model_config.model,
}
# handle invoke result
generator = self._invoke_llm(
node_data_model=self.node_data.model,
@ -253,6 +244,17 @@ class LLMNode(BaseNode[LLMNodeData]):
elif isinstance(event, LLMStructuredOutput):
structured_output = event
process_data = {
"model_mode": model_config.mode,
"prompts": PromptMessageUtil.prompt_messages_to_prompt_for_saving(
model_mode=model_config.mode, prompt_messages=prompt_messages
),
"usage": jsonable_encoder(usage),
"finish_reason": finish_reason,
"model_provider": model_config.provider,
"model_name": model_config.model,
}
outputs = {"text": result_text, "usage": jsonable_encoder(usage), "finish_reason": finish_reason}
if structured_output:
outputs["structured_output"] = structured_output.structured_output

@ -19,24 +19,16 @@ from core.model_runtime.entities.message_entities import (
UserPromptMessage,
)
from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey
from core.model_runtime.model_providers.__base.large_language_model import (
LargeLanguageModel,
)
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.utils.encoders import jsonable_encoder
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import (
ChatModelMessage,
CompletionModelPromptTemplate,
)
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
from core.prompt.simple_prompt_transform import ModelMode
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.variables.types import SegmentType
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import (
WorkflowNodeExecutionMetadataKey,
WorkflowNodeExecutionStatus,
)
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.nodes.base.node import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.llm import ModelConfig, llm_utils
@ -112,10 +104,7 @@ class ParameterExtractorNode(BaseNode):
"model": {
"prompt_templates": {
"completion_model": {
"conversation_histories_role": {
"user_prefix": "Human",
"assistant_prefix": "Assistant",
},
"conversation_histories_role": {"user_prefix": "Human", "assistant_prefix": "Assistant"},
"stop": ["Human:"],
}
}
@ -205,7 +194,8 @@ class ParameterExtractorNode(BaseNode):
"prompts": PromptMessageUtil.prompt_messages_to_prompt_for_saving(
model_mode=model_config.mode, prompt_messages=prompt_messages
),
"function": ({} if not prompt_message_tools else jsonable_encoder(prompt_message_tools[0])),
"usage": None,
"function": {} if not prompt_message_tools else jsonable_encoder(prompt_message_tools[0]),
"tool_call": None,
"model_provider": model_config.provider,
"model_name": model_config.model,
@ -219,6 +209,7 @@ class ParameterExtractorNode(BaseNode):
tools=prompt_message_tools,
stop=model_config.stop,
)
process_data["usage"] = jsonable_encoder(usage)
process_data["tool_call"] = jsonable_encoder(tool_call)
process_data["llm_text"] = text
except ParameterExtractorNodeError as e:
@ -235,11 +226,7 @@ class ParameterExtractorNode(BaseNode):
status=WorkflowNodeExecutionStatus.FAILED,
inputs=inputs,
process_data=process_data,
outputs={
"__is_success": 0,
"__reason": "Failed to invoke model",
"__error": str(e),
},
outputs={"__is_success": 0, "__reason": "Failed to invoke model", "__error": str(e)},
error=str(e),
metadata={},
)
@ -377,8 +364,7 @@ class ParameterExtractorNode(BaseNode):
],
),
ToolPromptMessage(
content="Great! You have called the function with the correct parameters.",
tool_call_id=id,
content="Great! You have called the function with the correct parameters.", tool_call_id=id
),
AssistantPromptMessage(
content="I have extracted the parameters, let's move on.",
@ -452,18 +438,10 @@ class ParameterExtractorNode(BaseNode):
"""
prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
rest_token = self._calculate_rest_token(
node_data=node_data,
query=query,
variable_pool=variable_pool,
model_config=model_config,
context="",
node_data=node_data, query=query, variable_pool=variable_pool, model_config=model_config, context=""
)
prompt_template = self._get_prompt_engineering_prompt_template(
node_data=node_data,
query=query,
variable_pool=variable_pool,
memory=memory,
max_token_limit=rest_token,
node_data=node_data, query=query, variable_pool=variable_pool, memory=memory, max_token_limit=rest_token
)
prompt_messages = prompt_transform.get_prompt(
prompt_template=prompt_template,
@ -494,11 +472,7 @@ class ParameterExtractorNode(BaseNode):
"""
prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
rest_token = self._calculate_rest_token(
node_data=node_data,
query=query,
variable_pool=variable_pool,
model_config=model_config,
context="",
node_data=node_data, query=query, variable_pool=variable_pool, model_config=model_config, context=""
)
prompt_template = self._get_prompt_engineering_prompt_template(
node_data=node_data,
@ -727,8 +701,7 @@ class ParameterExtractorNode(BaseNode):
if memory and node_data.memory and node_data.memory.window:
memory_str = memory.get_history_prompt_text(
max_token_limit=max_token_limit,
message_limit=node_data.memory.window.size,
max_token_limit=max_token_limit, message_limit=node_data.memory.window.size
)
if model_mode == ModelMode.CHAT:
system_prompt_messages = ChatModelMessage(
@ -755,8 +728,7 @@ class ParameterExtractorNode(BaseNode):
if memory and node_data.memory and node_data.memory.window:
memory_str = memory.get_history_prompt_text(
max_token_limit=max_token_limit,
message_limit=node_data.memory.window.size,
max_token_limit=max_token_limit, message_limit=node_data.memory.window.size
)
if model_mode == ModelMode.CHAT:
system_prompt_messages = ChatModelMessage(
@ -795,10 +767,7 @@ class ParameterExtractorNode(BaseNode):
if not model_schema:
raise ModelSchemaNotFoundError("Model schema not found")
if set(model_schema.features or []) & {
ModelFeature.MULTI_TOOL_CALL,
ModelFeature.MULTI_TOOL_CALL,
}:
if set(model_schema.features or []) & {ModelFeature.MULTI_TOOL_CALL, ModelFeature.MULTI_TOOL_CALL}:
prompt_template = self._get_function_calling_prompt_template(node_data, query, variable_pool, None, 2000)
else:
prompt_template = self._get_prompt_engineering_prompt_template(node_data, query, variable_pool, None, 2000)

@ -11,10 +11,7 @@ from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.simple_prompt_transform import ModelMode
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import (
WorkflowNodeExecutionMetadataKey,
WorkflowNodeExecutionStatus,
)
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.event import ModelInvokeCompletedEvent
from core.workflow.nodes.llm import (
@ -143,6 +140,8 @@ class QuestionClassifierNode(LLMNode):
"prompts": PromptMessageUtil.prompt_messages_to_prompt_for_saving(
model_mode=model_config.mode, prompt_messages=prompt_messages
),
"usage": jsonable_encoder(usage),
"finish_reason": finish_reason,
"model_provider": model_config.provider,
"model_name": model_config.model,
}
@ -150,7 +149,6 @@ class QuestionClassifierNode(LLMNode):
"class_name": category_name,
"class_id": category_id,
"usage": jsonable_encoder(usage),
"finish_reason": finish_reason,
}
return NodeRunResult(
@ -240,8 +238,7 @@ class QuestionClassifierNode(LLMNode):
model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
if model_context_tokens:
model_instance = ModelInstance(
provider_model_bundle=model_config.provider_model_bundle,
model=model_config.model,
provider_model_bundle=model_config.provider_model_bundle, model=model_config.model
)
curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages)
@ -280,13 +277,12 @@ class QuestionClassifierNode(LLMNode):
if memory:
memory_str = memory.get_history_prompt_text(
max_token_limit=max_token_limit,
message_limit=(node_data.memory.window.size if node_data.memory and node_data.memory.window else None),
message_limit=node_data.memory.window.size if node_data.memory and node_data.memory.window else None,
)
prompt_messages: list[LLMNodeChatModelMessage] = []
if model_mode == ModelMode.CHAT:
system_prompt_messages = LLMNodeChatModelMessage(
role=PromptMessageRole.SYSTEM,
text=QUESTION_CLASSIFIER_SYSTEM_PROMPT.format(histories=memory_str),
role=PromptMessageRole.SYSTEM, text=QUESTION_CLASSIFIER_SYSTEM_PROMPT.format(histories=memory_str)
)
prompt_messages.append(system_prompt_messages)
user_prompt_message_1 = LLMNodeChatModelMessage(
@ -294,8 +290,7 @@ class QuestionClassifierNode(LLMNode):
)
prompt_messages.append(user_prompt_message_1)
assistant_prompt_message_1 = LLMNodeChatModelMessage(
role=PromptMessageRole.ASSISTANT,
text=QUESTION_CLASSIFIER_ASSISTANT_PROMPT_1,
role=PromptMessageRole.ASSISTANT, text=QUESTION_CLASSIFIER_ASSISTANT_PROMPT_1
)
prompt_messages.append(assistant_prompt_message_1)
user_prompt_message_2 = LLMNodeChatModelMessage(
@ -303,8 +298,7 @@ class QuestionClassifierNode(LLMNode):
)
prompt_messages.append(user_prompt_message_2)
assistant_prompt_message_2 = LLMNodeChatModelMessage(
role=PromptMessageRole.ASSISTANT,
text=QUESTION_CLASSIFIER_ASSISTANT_PROMPT_2,
role=PromptMessageRole.ASSISTANT, text=QUESTION_CLASSIFIER_ASSISTANT_PROMPT_2
)
prompt_messages.append(assistant_prompt_message_2)
user_prompt_message_3 = LLMNodeChatModelMessage(

@ -520,7 +520,6 @@ class DraftVariableSaver:
_EXCLUDE_VARIABLE_NAMES_MAPPING: dict[NodeType, frozenset[str]] = {
NodeType.LLM: frozenset(["finish_reason"]),
NodeType.LOOP: frozenset(["loop_round"]),
NodeType.QUESTION_CLASSIFIER: frozenset(["finish_reason"]),
}
# Database session used for persisting draft variables.

@ -555,7 +555,7 @@ export const PARAMETER_EXTRACTOR_COMMON_STRUCT: Var[] = [
type: VarType.string,
},
{
variable: 'usage',
variable: '__usage',
type: VarType.object,
},
]

Loading…
Cancel
Save