From efcad48e91961e2cd18210c9931f85545b4514a8 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Wed, 14 May 2025 15:26:52 +0800 Subject: [PATCH] refactor(node_execution_entities): Improve the type of metadata Signed-off-by: -LAN- --- api/core/app/entities/task_entities.py | 20 +++++++++---------- .../entities/langfuse_trace_entity.py | 5 +++-- api/core/ops/langfuse_trace/langfuse_trace.py | 2 +- .../entities/langsmith_trace_entity.py | 5 +++-- .../ops/langsmith_trace/langsmith_trace.py | 5 +++-- api/core/ops/opik_trace/opik_trace.py | 5 +++-- .../entities/weave_trace_entity.py | 5 +++-- api/core/ops/weave_trace/weave_trace.py | 5 +++-- .../entities/node_execution_entities.py | 11 +++++----- api/core/workflow/workflow_cycle_manager.py | 20 +++++++++---------- 10 files changed, 45 insertions(+), 38 deletions(-) diff --git a/api/core/app/entities/task_entities.py b/api/core/app/entities/task_entities.py index 817699bd20..0c2d617f80 100644 --- a/api/core/app/entities/task_entities.py +++ b/api/core/app/entities/task_entities.py @@ -6,7 +6,7 @@ from pydantic import BaseModel, ConfigDict from core.model_runtime.entities.llm_entities import LLMResult from core.model_runtime.utils.encoders import jsonable_encoder -from core.workflow.entities.node_entities import AgentNodeStrategyInit +from core.workflow.entities.node_entities import AgentNodeStrategyInit, NodeRunMetadataKey from models.workflow import WorkflowNodeExecutionStatus @@ -244,7 +244,7 @@ class NodeStartStreamResponse(StreamResponse): title: str index: int predecessor_node_id: Optional[str] = None - inputs: Optional[dict] = None + inputs: Optional[Mapping[str, Any]] = None created_at: int extras: dict = {} parallel_id: Optional[str] = None @@ -301,13 +301,13 @@ class NodeFinishStreamResponse(StreamResponse): title: str index: int predecessor_node_id: Optional[str] = None - inputs: Optional[dict] = None - process_data: Optional[dict] = None - outputs: Optional[dict] = None + inputs: Optional[Mapping[str, Any]] = None + process_data: Optional[Mapping[str, Any]] = None + outputs: Optional[Mapping[str, Any]] = None status: str error: Optional[str] = None elapsed_time: float - execution_metadata: Optional[dict] = None + execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None created_at: int finished_at: int files: Optional[Sequence[Mapping[str, Any]]] = [] @@ -370,13 +370,13 @@ class NodeRetryStreamResponse(StreamResponse): title: str index: int predecessor_node_id: Optional[str] = None - inputs: Optional[dict] = None - process_data: Optional[dict] = None - outputs: Optional[dict] = None + inputs: Optional[Mapping[str, Any]] = None + process_data: Optional[Mapping[str, Any]] = None + outputs: Optional[Mapping[str, Any]] = None status: str error: Optional[str] = None elapsed_time: float - execution_metadata: Optional[dict] = None + execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None created_at: int finished_at: int files: Optional[Sequence[Mapping[str, Any]]] = [] diff --git a/api/core/ops/langfuse_trace/entities/langfuse_trace_entity.py b/api/core/ops/langfuse_trace/entities/langfuse_trace_entity.py index f486da3a6d..46ba1c45b9 100644 --- a/api/core/ops/langfuse_trace/entities/langfuse_trace_entity.py +++ b/api/core/ops/langfuse_trace/entities/langfuse_trace_entity.py @@ -1,3 +1,4 @@ +from collections.abc import Mapping from datetime import datetime from enum import StrEnum from typing import Any, Optional, Union @@ -155,10 +156,10 @@ class LangfuseSpan(BaseModel): description="The status message of the span. Additional field for context of the event. E.g. the error " "message of an error event.", ) - input: Optional[Union[str, dict[str, Any], list, None]] = Field( + input: Optional[Union[str, Mapping[str, Any], list, None]] = Field( default=None, description="The input of the span. Can be any JSON object." ) - output: Optional[Union[str, dict[str, Any], list, None]] = Field( + output: Optional[Union[str, Mapping[str, Any], list, None]] = Field( default=None, description="The output of the span. Can be any JSON object." ) version: Optional[str] = Field( diff --git a/api/core/ops/langfuse_trace/langfuse_trace.py b/api/core/ops/langfuse_trace/langfuse_trace.py index bb46055689..120c36f53d 100644 --- a/api/core/ops/langfuse_trace/langfuse_trace.py +++ b/api/core/ops/langfuse_trace/langfuse_trace.py @@ -160,7 +160,7 @@ class LangFuseDataTrace(BaseTraceInstance): finished_at = created_at + timedelta(seconds=elapsed_time) execution_metadata = node_execution.metadata if node_execution.metadata else {} - metadata = execution_metadata.copy() + metadata = {str(k): v for k, v in execution_metadata.items()} metadata.update( { "workflow_run_id": trace_info.workflow_run_id, diff --git a/api/core/ops/langsmith_trace/entities/langsmith_trace_entity.py b/api/core/ops/langsmith_trace/entities/langsmith_trace_entity.py index 348b7ba501..4fd01136ba 100644 --- a/api/core/ops/langsmith_trace/entities/langsmith_trace_entity.py +++ b/api/core/ops/langsmith_trace/entities/langsmith_trace_entity.py @@ -1,3 +1,4 @@ +from collections.abc import Mapping from datetime import datetime from enum import StrEnum from typing import Any, Optional, Union @@ -30,8 +31,8 @@ class LangSmithMultiModel(BaseModel): class LangSmithRunModel(LangSmithTokenUsage, LangSmithMultiModel): name: Optional[str] = Field(..., description="Name of the run") - inputs: Optional[Union[str, dict[str, Any], list, None]] = Field(None, description="Inputs of the run") - outputs: Optional[Union[str, dict[str, Any], list, None]] = Field(None, description="Outputs of the run") + inputs: Optional[Union[str, Mapping[str, Any], list, None]] = Field(None, description="Inputs of the run") + outputs: Optional[Union[str, Mapping[str, Any], list, None]] = Field(None, description="Outputs of the run") run_type: LangSmithRunType = Field(..., description="Type of the run") start_time: Optional[datetime | str] = Field(None, description="Start time of the run") end_time: Optional[datetime | str] = Field(None, description="End time of the run") diff --git a/api/core/ops/langsmith_trace/langsmith_trace.py b/api/core/ops/langsmith_trace/langsmith_trace.py index 4420feca91..6631727c79 100644 --- a/api/core/ops/langsmith_trace/langsmith_trace.py +++ b/api/core/ops/langsmith_trace/langsmith_trace.py @@ -28,6 +28,7 @@ from core.ops.langsmith_trace.entities.langsmith_trace_entity import ( ) from core.ops.utils import filter_none_values, generate_dotted_order from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository +from core.workflow.entities.node_entities import NodeRunMetadataKey from core.workflow.nodes.enums import NodeType from extensions.ext_database import db from models import Account, App, EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom @@ -184,8 +185,8 @@ class LangSmithDataTrace(BaseTraceInstance): finished_at = created_at + timedelta(seconds=elapsed_time) execution_metadata = node_execution.metadata if node_execution.metadata else {} - node_total_tokens = execution_metadata.get("total_tokens", 0) - metadata = execution_metadata.copy() + node_total_tokens = execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS) or 0 + metadata = {str(key): value for key, value in execution_metadata.items()} metadata.update( { "workflow_run_id": trace_info.workflow_run_id, diff --git a/api/core/ops/opik_trace/opik_trace.py b/api/core/ops/opik_trace/opik_trace.py index 24718c9393..c22df55357 100644 --- a/api/core/ops/opik_trace/opik_trace.py +++ b/api/core/ops/opik_trace/opik_trace.py @@ -22,6 +22,7 @@ from core.ops.entities.trace_entity import ( WorkflowTraceInfo, ) from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository +from core.workflow.entities.node_entities import NodeRunMetadataKey from core.workflow.nodes.enums import NodeType from extensions.ext_database import db from models import Account, App, EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom @@ -197,7 +198,7 @@ class OpikDataTrace(BaseTraceInstance): finished_at = created_at + timedelta(seconds=elapsed_time) execution_metadata = node_execution.metadata if node_execution.metadata else {} - metadata = execution_metadata.copy() + metadata = {str(k): v for k, v in execution_metadata.items()} metadata.update( { "workflow_run_id": trace_info.workflow_run_id, @@ -243,7 +244,7 @@ class OpikDataTrace(BaseTraceInstance): parent_span_id = trace_info.workflow_app_log_id or trace_info.workflow_run_id if not total_tokens: - total_tokens = execution_metadata.get("total_tokens", 0) + total_tokens = execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS) or 0 span_data = { "trace_id": opik_trace_id, diff --git a/api/core/ops/weave_trace/entities/weave_trace_entity.py b/api/core/ops/weave_trace/entities/weave_trace_entity.py index e423f5ccbb..7f489f37ac 100644 --- a/api/core/ops/weave_trace/entities/weave_trace_entity.py +++ b/api/core/ops/weave_trace/entities/weave_trace_entity.py @@ -1,3 +1,4 @@ +from collections.abc import Mapping from typing import Any, Optional, Union from pydantic import BaseModel, Field, field_validator @@ -19,8 +20,8 @@ class WeaveMultiModel(BaseModel): class WeaveTraceModel(WeaveTokenUsage, WeaveMultiModel): id: str = Field(..., description="ID of the trace") op: str = Field(..., description="Name of the operation") - inputs: Optional[Union[str, dict[str, Any], list, None]] = Field(None, description="Inputs of the trace") - outputs: Optional[Union[str, dict[str, Any], list, None]] = Field(None, description="Outputs of the trace") + inputs: Optional[Union[str, Mapping[str, Any], list, None]] = Field(None, description="Inputs of the trace") + outputs: Optional[Union[str, Mapping[str, Any], list, None]] = Field(None, description="Outputs of the trace") attributes: Optional[Union[str, dict[str, Any], list, None]] = Field( None, description="Metadata and attributes associated with trace" ) diff --git a/api/core/ops/weave_trace/weave_trace.py b/api/core/ops/weave_trace/weave_trace.py index fb8627fb85..a4f38dfbba 100644 --- a/api/core/ops/weave_trace/weave_trace.py +++ b/api/core/ops/weave_trace/weave_trace.py @@ -23,6 +23,7 @@ from core.ops.entities.trace_entity import ( ) from core.ops.weave_trace.entities.weave_trace_entity import WeaveTraceModel from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository +from core.workflow.entities.node_entities import NodeRunMetadataKey from core.workflow.nodes.enums import NodeType from extensions.ext_database import db from models import Account, App, EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom @@ -178,8 +179,8 @@ class WeaveDataTrace(BaseTraceInstance): finished_at = created_at + timedelta(seconds=elapsed_time) execution_metadata = node_execution.metadata if node_execution.metadata else {} - node_total_tokens = execution_metadata.get("total_tokens", 0) - attributes = execution_metadata.copy() + node_total_tokens = execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS) or 0 + attributes = {str(k): v for k, v in execution_metadata.items()} attributes.update( { "workflow_run_id": trace_info.workflow_run_id, diff --git a/api/core/workflow/entities/node_execution_entities.py b/api/core/workflow/entities/node_execution_entities.py index 58da8a2cfe..5e5ead062f 100644 --- a/api/core/workflow/entities/node_execution_entities.py +++ b/api/core/workflow/entities/node_execution_entities.py @@ -13,6 +13,7 @@ from typing import Any, Optional from pydantic import BaseModel, Field +from core.workflow.entities.node_entities import NodeRunMetadataKey from core.workflow.nodes.enums import NodeType @@ -55,9 +56,9 @@ class NodeExecution(BaseModel): title: str # Display title of the node # Execution data - inputs: Optional[dict[str, Any]] = None # Input variables used by this node - process_data: Optional[dict[str, Any]] = None # Intermediate processing data - outputs: Optional[dict[str, Any]] = None # Output variables produced by this node + inputs: Optional[Mapping[str, Any]] = None # Input variables used by this node + process_data: Optional[Mapping[str, Any]] = None # Intermediate processing data + outputs: Optional[Mapping[str, Any]] = None # Output variables produced by this node # Execution state status: NodeExecutionStatus = NodeExecutionStatus.RUNNING # Current execution status @@ -65,7 +66,7 @@ class NodeExecution(BaseModel): elapsed_time: float = Field(default=0.0) # Time taken for execution in seconds # Additional metadata - metadata: Optional[dict[str, Any]] = None # Execution metadata (tokens, cost, etc.) + metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None # Execution metadata (tokens, cost, etc.) # Timing information created_at: datetime # When execution started @@ -76,7 +77,7 @@ class NodeExecution(BaseModel): inputs: Optional[Mapping[str, Any]] = None, process_data: Optional[Mapping[str, Any]] = None, outputs: Optional[Mapping[str, Any]] = None, - metadata: Optional[Mapping[str, Any]] = None, + metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None, ) -> None: """ Update the model from mappings. diff --git a/api/core/workflow/workflow_cycle_manager.py b/api/core/workflow/workflow_cycle_manager.py index a512d2022d..6d33d7372c 100644 --- a/api/core/workflow/workflow_cycle_manager.py +++ b/api/core/workflow/workflow_cycle_manager.py @@ -292,9 +292,9 @@ class WorkflowCycleManager: # Create a domain model created_at = datetime.now(UTC).replace(tzinfo=None) metadata = { - str(NodeRunMetadataKey.PARALLEL_MODE_RUN_ID): event.parallel_mode_run_id, - str(NodeRunMetadataKey.ITERATION_ID): event.in_iteration_id, - str(NodeRunMetadataKey.LOOP_ID): event.in_loop_id, + NodeRunMetadataKey.PARALLEL_MODE_RUN_ID: event.parallel_mode_run_id, + NodeRunMetadataKey.ITERATION_ID: event.in_iteration_id, + NodeRunMetadataKey.LOOP_ID: event.in_loop_id, } domain_execution = NodeExecution( @@ -332,7 +332,7 @@ class WorkflowCycleManager: execution_metadata_dict = {} if event.execution_metadata: for key, value in event.execution_metadata.items(): - execution_metadata_dict[str(key)] = value + execution_metadata_dict[key] = value finished_at = datetime.now(UTC).replace(tzinfo=None) elapsed_time = (finished_at - event.start_at).total_seconds() @@ -377,7 +377,7 @@ class WorkflowCycleManager: execution_metadata_dict = {} if event.execution_metadata: for key, value in event.execution_metadata.items(): - execution_metadata_dict[str(key)] = value + execution_metadata_dict[key] = value finished_at = datetime.now(UTC).replace(tzinfo=None) elapsed_time = (finished_at - event.start_at).total_seconds() @@ -417,16 +417,16 @@ class WorkflowCycleManager: # Convert metadata keys to strings origin_metadata = { - str(NodeRunMetadataKey.ITERATION_ID): event.in_iteration_id, - str(NodeRunMetadataKey.PARALLEL_MODE_RUN_ID): event.parallel_mode_run_id, - str(NodeRunMetadataKey.LOOP_ID): event.in_loop_id, + NodeRunMetadataKey.ITERATION_ID: event.in_iteration_id, + NodeRunMetadataKey.PARALLEL_MODE_RUN_ID: event.parallel_mode_run_id, + NodeRunMetadataKey.LOOP_ID: event.in_loop_id, } # Convert execution metadata keys to strings - execution_metadata_dict = {} + execution_metadata_dict: dict[NodeRunMetadataKey, str | None] = {} if event.execution_metadata: for key, value in event.execution_metadata.items(): - execution_metadata_dict[str(key)] = value + execution_metadata_dict[key] = value merged_metadata = {**execution_metadata_dict, **origin_metadata} if execution_metadata_dict else origin_metadata