fix: The statistics page cannot display the tokens consumed by agent node (#21861)

pull/21877/head
Novice 11 months ago committed by GitHub
parent ebc4fdc4b2
commit f3c8625fe2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -53,6 +53,37 @@ class LLMUsage(ModelUsage):
latency=0.0, latency=0.0,
) )
@classmethod
def from_metadata(cls, metadata: dict) -> "LLMUsage":
"""
Create LLMUsage instance from metadata dictionary with default values.
Args:
metadata: Dictionary containing usage metadata
Returns:
LLMUsage instance with values from metadata or defaults
"""
total_tokens = metadata.get("total_tokens", 0)
completion_tokens = metadata.get("completion_tokens", 0)
if total_tokens > 0 and completion_tokens == 0:
completion_tokens = total_tokens
return cls(
prompt_tokens=metadata.get("prompt_tokens", 0),
completion_tokens=completion_tokens,
total_tokens=total_tokens,
prompt_unit_price=Decimal(str(metadata.get("prompt_unit_price", 0))),
completion_unit_price=Decimal(str(metadata.get("completion_unit_price", 0))),
total_price=Decimal(str(metadata.get("total_price", 0))),
currency=metadata.get("currency", "USD"),
prompt_price_unit=Decimal(str(metadata.get("prompt_price_unit", 0))),
completion_price_unit=Decimal(str(metadata.get("completion_price_unit", 0))),
prompt_price=Decimal(str(metadata.get("prompt_price", 0))),
completion_price=Decimal(str(metadata.get("completion_price", 0))),
latency=metadata.get("latency", 0.0),
)
def plus(self, other: "LLMUsage") -> "LLMUsage": def plus(self, other: "LLMUsage") -> "LLMUsage":
""" """
Add two LLMUsage instances together. Add two LLMUsage instances together.

@ -6,6 +6,7 @@ from sqlalchemy.orm import Session
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
from core.file import File, FileTransferMethod from core.file import File, FileTransferMethod
from core.model_runtime.entities.llm_entities import LLMUsage
from core.plugin.impl.exc import PluginDaemonClientSideError from core.plugin.impl.exc import PluginDaemonClientSideError
from core.plugin.impl.plugin import PluginInstaller from core.plugin.impl.plugin import PluginInstaller
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
@ -208,7 +209,7 @@ class ToolNode(BaseNode[ToolNodeData]):
agent_logs: list[AgentLogEvent] = [] agent_logs: list[AgentLogEvent] = []
agent_execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] = {} agent_execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] = {}
llm_usage: LLMUsage | None = None
variables: dict[str, Any] = {} variables: dict[str, Any] = {}
for message in message_stream: for message in message_stream:
@ -276,9 +277,10 @@ class ToolNode(BaseNode[ToolNodeData]):
elif message.type == ToolInvokeMessage.MessageType.JSON: elif message.type == ToolInvokeMessage.MessageType.JSON:
assert isinstance(message.message, ToolInvokeMessage.JsonMessage) assert isinstance(message.message, ToolInvokeMessage.JsonMessage)
if self.node_type == NodeType.AGENT: if self.node_type == NodeType.AGENT:
msg_metadata = message.message.json_object.pop("execution_metadata", {}) msg_metadata: dict[str, Any] = message.message.json_object.pop("execution_metadata", {})
llm_usage = LLMUsage.from_metadata(msg_metadata)
agent_execution_metadata = { agent_execution_metadata = {
key: value WorkflowNodeExecutionMetadataKey(key): value
for key, value in msg_metadata.items() for key, value in msg_metadata.items()
if key in WorkflowNodeExecutionMetadataKey.__members__.values() if key in WorkflowNodeExecutionMetadataKey.__members__.values()
} }
@ -377,6 +379,7 @@ class ToolNode(BaseNode[ToolNodeData]):
WorkflowNodeExecutionMetadataKey.AGENT_LOG: agent_logs, WorkflowNodeExecutionMetadataKey.AGENT_LOG: agent_logs,
}, },
inputs=parameters_for_log, inputs=parameters_for_log,
llm_usage=llm_usage,
) )
) )

Loading…
Cancel
Save