refactor(node_execution_entities): Improve the type of metadata

Signed-off-by: -LAN- <laipz8200@outlook.com>
pull/19430/head
-LAN- 1 year ago
parent d4fb9c7472
commit efcad48e91
No known key found for this signature in database
GPG Key ID: 6BA0D108DED011FF

@ -6,7 +6,7 @@ from pydantic import BaseModel, ConfigDict
from core.model_runtime.entities.llm_entities import LLMResult from core.model_runtime.entities.llm_entities import LLMResult
from core.model_runtime.utils.encoders import jsonable_encoder from core.model_runtime.utils.encoders import jsonable_encoder
from core.workflow.entities.node_entities import AgentNodeStrategyInit from core.workflow.entities.node_entities import AgentNodeStrategyInit, NodeRunMetadataKey
from models.workflow import WorkflowNodeExecutionStatus from models.workflow import WorkflowNodeExecutionStatus
@ -244,7 +244,7 @@ class NodeStartStreamResponse(StreamResponse):
title: str title: str
index: int index: int
predecessor_node_id: Optional[str] = None predecessor_node_id: Optional[str] = None
inputs: Optional[dict] = None inputs: Optional[Mapping[str, Any]] = None
created_at: int created_at: int
extras: dict = {} extras: dict = {}
parallel_id: Optional[str] = None parallel_id: Optional[str] = None
@ -301,13 +301,13 @@ class NodeFinishStreamResponse(StreamResponse):
title: str title: str
index: int index: int
predecessor_node_id: Optional[str] = None predecessor_node_id: Optional[str] = None
inputs: Optional[dict] = None inputs: Optional[Mapping[str, Any]] = None
process_data: Optional[dict] = None process_data: Optional[Mapping[str, Any]] = None
outputs: Optional[dict] = None outputs: Optional[Mapping[str, Any]] = None
status: str status: str
error: Optional[str] = None error: Optional[str] = None
elapsed_time: float elapsed_time: float
execution_metadata: Optional[dict] = None execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None
created_at: int created_at: int
finished_at: int finished_at: int
files: Optional[Sequence[Mapping[str, Any]]] = [] files: Optional[Sequence[Mapping[str, Any]]] = []
@ -370,13 +370,13 @@ class NodeRetryStreamResponse(StreamResponse):
title: str title: str
index: int index: int
predecessor_node_id: Optional[str] = None predecessor_node_id: Optional[str] = None
inputs: Optional[dict] = None inputs: Optional[Mapping[str, Any]] = None
process_data: Optional[dict] = None process_data: Optional[Mapping[str, Any]] = None
outputs: Optional[dict] = None outputs: Optional[Mapping[str, Any]] = None
status: str status: str
error: Optional[str] = None error: Optional[str] = None
elapsed_time: float elapsed_time: float
execution_metadata: Optional[dict] = None execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None
created_at: int created_at: int
finished_at: int finished_at: int
files: Optional[Sequence[Mapping[str, Any]]] = [] files: Optional[Sequence[Mapping[str, Any]]] = []

@ -1,3 +1,4 @@
from collections.abc import Mapping
from datetime import datetime from datetime import datetime
from enum import StrEnum from enum import StrEnum
from typing import Any, Optional, Union from typing import Any, Optional, Union
@ -155,10 +156,10 @@ class LangfuseSpan(BaseModel):
description="The status message of the span. Additional field for context of the event. E.g. the error " description="The status message of the span. Additional field for context of the event. E.g. the error "
"message of an error event.", "message of an error event.",
) )
input: Optional[Union[str, dict[str, Any], list, None]] = Field( input: Optional[Union[str, Mapping[str, Any], list, None]] = Field(
default=None, description="The input of the span. Can be any JSON object." default=None, description="The input of the span. Can be any JSON object."
) )
output: Optional[Union[str, dict[str, Any], list, None]] = Field( output: Optional[Union[str, Mapping[str, Any], list, None]] = Field(
default=None, description="The output of the span. Can be any JSON object." default=None, description="The output of the span. Can be any JSON object."
) )
version: Optional[str] = Field( version: Optional[str] = Field(

@ -160,7 +160,7 @@ class LangFuseDataTrace(BaseTraceInstance):
finished_at = created_at + timedelta(seconds=elapsed_time) finished_at = created_at + timedelta(seconds=elapsed_time)
execution_metadata = node_execution.metadata if node_execution.metadata else {} execution_metadata = node_execution.metadata if node_execution.metadata else {}
metadata = execution_metadata.copy() metadata = {str(k): v for k, v in execution_metadata.items()}
metadata.update( metadata.update(
{ {
"workflow_run_id": trace_info.workflow_run_id, "workflow_run_id": trace_info.workflow_run_id,

@ -1,3 +1,4 @@
from collections.abc import Mapping
from datetime import datetime from datetime import datetime
from enum import StrEnum from enum import StrEnum
from typing import Any, Optional, Union from typing import Any, Optional, Union
@ -30,8 +31,8 @@ class LangSmithMultiModel(BaseModel):
class LangSmithRunModel(LangSmithTokenUsage, LangSmithMultiModel): class LangSmithRunModel(LangSmithTokenUsage, LangSmithMultiModel):
name: Optional[str] = Field(..., description="Name of the run") name: Optional[str] = Field(..., description="Name of the run")
inputs: Optional[Union[str, dict[str, Any], list, None]] = Field(None, description="Inputs of the run") inputs: Optional[Union[str, Mapping[str, Any], list, None]] = Field(None, description="Inputs of the run")
outputs: Optional[Union[str, dict[str, Any], list, None]] = Field(None, description="Outputs of the run") outputs: Optional[Union[str, Mapping[str, Any], list, None]] = Field(None, description="Outputs of the run")
run_type: LangSmithRunType = Field(..., description="Type of the run") run_type: LangSmithRunType = Field(..., description="Type of the run")
start_time: Optional[datetime | str] = Field(None, description="Start time of the run") start_time: Optional[datetime | str] = Field(None, description="Start time of the run")
end_time: Optional[datetime | str] = Field(None, description="End time of the run") end_time: Optional[datetime | str] = Field(None, description="End time of the run")

@ -28,6 +28,7 @@ from core.ops.langsmith_trace.entities.langsmith_trace_entity import (
) )
from core.ops.utils import filter_none_values, generate_dotted_order from core.ops.utils import filter_none_values, generate_dotted_order
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
from core.workflow.entities.node_entities import NodeRunMetadataKey
from core.workflow.nodes.enums import NodeType from core.workflow.nodes.enums import NodeType
from extensions.ext_database import db from extensions.ext_database import db
from models import Account, App, EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom from models import Account, App, EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom
@ -184,8 +185,8 @@ class LangSmithDataTrace(BaseTraceInstance):
finished_at = created_at + timedelta(seconds=elapsed_time) finished_at = created_at + timedelta(seconds=elapsed_time)
execution_metadata = node_execution.metadata if node_execution.metadata else {} execution_metadata = node_execution.metadata if node_execution.metadata else {}
node_total_tokens = execution_metadata.get("total_tokens", 0) node_total_tokens = execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS) or 0
metadata = execution_metadata.copy() metadata = {str(key): value for key, value in execution_metadata.items()}
metadata.update( metadata.update(
{ {
"workflow_run_id": trace_info.workflow_run_id, "workflow_run_id": trace_info.workflow_run_id,

@ -22,6 +22,7 @@ from core.ops.entities.trace_entity import (
WorkflowTraceInfo, WorkflowTraceInfo,
) )
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
from core.workflow.entities.node_entities import NodeRunMetadataKey
from core.workflow.nodes.enums import NodeType from core.workflow.nodes.enums import NodeType
from extensions.ext_database import db from extensions.ext_database import db
from models import Account, App, EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom from models import Account, App, EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom
@ -197,7 +198,7 @@ class OpikDataTrace(BaseTraceInstance):
finished_at = created_at + timedelta(seconds=elapsed_time) finished_at = created_at + timedelta(seconds=elapsed_time)
execution_metadata = node_execution.metadata if node_execution.metadata else {} execution_metadata = node_execution.metadata if node_execution.metadata else {}
metadata = execution_metadata.copy() metadata = {str(k): v for k, v in execution_metadata.items()}
metadata.update( metadata.update(
{ {
"workflow_run_id": trace_info.workflow_run_id, "workflow_run_id": trace_info.workflow_run_id,
@ -243,7 +244,7 @@ class OpikDataTrace(BaseTraceInstance):
parent_span_id = trace_info.workflow_app_log_id or trace_info.workflow_run_id parent_span_id = trace_info.workflow_app_log_id or trace_info.workflow_run_id
if not total_tokens: if not total_tokens:
total_tokens = execution_metadata.get("total_tokens", 0) total_tokens = execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS) or 0
span_data = { span_data = {
"trace_id": opik_trace_id, "trace_id": opik_trace_id,

@ -1,3 +1,4 @@
from collections.abc import Mapping
from typing import Any, Optional, Union from typing import Any, Optional, Union
from pydantic import BaseModel, Field, field_validator from pydantic import BaseModel, Field, field_validator
@ -19,8 +20,8 @@ class WeaveMultiModel(BaseModel):
class WeaveTraceModel(WeaveTokenUsage, WeaveMultiModel): class WeaveTraceModel(WeaveTokenUsage, WeaveMultiModel):
id: str = Field(..., description="ID of the trace") id: str = Field(..., description="ID of the trace")
op: str = Field(..., description="Name of the operation") op: str = Field(..., description="Name of the operation")
inputs: Optional[Union[str, dict[str, Any], list, None]] = Field(None, description="Inputs of the trace") inputs: Optional[Union[str, Mapping[str, Any], list, None]] = Field(None, description="Inputs of the trace")
outputs: Optional[Union[str, dict[str, Any], list, None]] = Field(None, description="Outputs of the trace") outputs: Optional[Union[str, Mapping[str, Any], list, None]] = Field(None, description="Outputs of the trace")
attributes: Optional[Union[str, dict[str, Any], list, None]] = Field( attributes: Optional[Union[str, dict[str, Any], list, None]] = Field(
None, description="Metadata and attributes associated with trace" None, description="Metadata and attributes associated with trace"
) )

@ -23,6 +23,7 @@ from core.ops.entities.trace_entity import (
) )
from core.ops.weave_trace.entities.weave_trace_entity import WeaveTraceModel from core.ops.weave_trace.entities.weave_trace_entity import WeaveTraceModel
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
from core.workflow.entities.node_entities import NodeRunMetadataKey
from core.workflow.nodes.enums import NodeType from core.workflow.nodes.enums import NodeType
from extensions.ext_database import db from extensions.ext_database import db
from models import Account, App, EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom from models import Account, App, EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom
@ -178,8 +179,8 @@ class WeaveDataTrace(BaseTraceInstance):
finished_at = created_at + timedelta(seconds=elapsed_time) finished_at = created_at + timedelta(seconds=elapsed_time)
execution_metadata = node_execution.metadata if node_execution.metadata else {} execution_metadata = node_execution.metadata if node_execution.metadata else {}
node_total_tokens = execution_metadata.get("total_tokens", 0) node_total_tokens = execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS) or 0
attributes = execution_metadata.copy() attributes = {str(k): v for k, v in execution_metadata.items()}
attributes.update( attributes.update(
{ {
"workflow_run_id": trace_info.workflow_run_id, "workflow_run_id": trace_info.workflow_run_id,

@ -13,6 +13,7 @@ from typing import Any, Optional
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from core.workflow.entities.node_entities import NodeRunMetadataKey
from core.workflow.nodes.enums import NodeType from core.workflow.nodes.enums import NodeType
@ -55,9 +56,9 @@ class NodeExecution(BaseModel):
title: str # Display title of the node title: str # Display title of the node
# Execution data # Execution data
inputs: Optional[dict[str, Any]] = None # Input variables used by this node inputs: Optional[Mapping[str, Any]] = None # Input variables used by this node
process_data: Optional[dict[str, Any]] = None # Intermediate processing data process_data: Optional[Mapping[str, Any]] = None # Intermediate processing data
outputs: Optional[dict[str, Any]] = None # Output variables produced by this node outputs: Optional[Mapping[str, Any]] = None # Output variables produced by this node
# Execution state # Execution state
status: NodeExecutionStatus = NodeExecutionStatus.RUNNING # Current execution status status: NodeExecutionStatus = NodeExecutionStatus.RUNNING # Current execution status
@ -65,7 +66,7 @@ class NodeExecution(BaseModel):
elapsed_time: float = Field(default=0.0) # Time taken for execution in seconds elapsed_time: float = Field(default=0.0) # Time taken for execution in seconds
# Additional metadata # Additional metadata
metadata: Optional[dict[str, Any]] = None # Execution metadata (tokens, cost, etc.) metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None # Execution metadata (tokens, cost, etc.)
# Timing information # Timing information
created_at: datetime # When execution started created_at: datetime # When execution started
@ -76,7 +77,7 @@ class NodeExecution(BaseModel):
inputs: Optional[Mapping[str, Any]] = None, inputs: Optional[Mapping[str, Any]] = None,
process_data: Optional[Mapping[str, Any]] = None, process_data: Optional[Mapping[str, Any]] = None,
outputs: Optional[Mapping[str, Any]] = None, outputs: Optional[Mapping[str, Any]] = None,
metadata: Optional[Mapping[str, Any]] = None, metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None,
) -> None: ) -> None:
""" """
Update the model from mappings. Update the model from mappings.

@ -292,9 +292,9 @@ class WorkflowCycleManager:
# Create a domain model # Create a domain model
created_at = datetime.now(UTC).replace(tzinfo=None) created_at = datetime.now(UTC).replace(tzinfo=None)
metadata = { metadata = {
str(NodeRunMetadataKey.PARALLEL_MODE_RUN_ID): event.parallel_mode_run_id, NodeRunMetadataKey.PARALLEL_MODE_RUN_ID: event.parallel_mode_run_id,
str(NodeRunMetadataKey.ITERATION_ID): event.in_iteration_id, NodeRunMetadataKey.ITERATION_ID: event.in_iteration_id,
str(NodeRunMetadataKey.LOOP_ID): event.in_loop_id, NodeRunMetadataKey.LOOP_ID: event.in_loop_id,
} }
domain_execution = NodeExecution( domain_execution = NodeExecution(
@ -332,7 +332,7 @@ class WorkflowCycleManager:
execution_metadata_dict = {} execution_metadata_dict = {}
if event.execution_metadata: if event.execution_metadata:
for key, value in event.execution_metadata.items(): for key, value in event.execution_metadata.items():
execution_metadata_dict[str(key)] = value execution_metadata_dict[key] = value
finished_at = datetime.now(UTC).replace(tzinfo=None) finished_at = datetime.now(UTC).replace(tzinfo=None)
elapsed_time = (finished_at - event.start_at).total_seconds() elapsed_time = (finished_at - event.start_at).total_seconds()
@ -377,7 +377,7 @@ class WorkflowCycleManager:
execution_metadata_dict = {} execution_metadata_dict = {}
if event.execution_metadata: if event.execution_metadata:
for key, value in event.execution_metadata.items(): for key, value in event.execution_metadata.items():
execution_metadata_dict[str(key)] = value execution_metadata_dict[key] = value
finished_at = datetime.now(UTC).replace(tzinfo=None) finished_at = datetime.now(UTC).replace(tzinfo=None)
elapsed_time = (finished_at - event.start_at).total_seconds() elapsed_time = (finished_at - event.start_at).total_seconds()
@ -417,16 +417,16 @@ class WorkflowCycleManager:
# Convert metadata keys to strings # Convert metadata keys to strings
origin_metadata = { origin_metadata = {
str(NodeRunMetadataKey.ITERATION_ID): event.in_iteration_id, NodeRunMetadataKey.ITERATION_ID: event.in_iteration_id,
str(NodeRunMetadataKey.PARALLEL_MODE_RUN_ID): event.parallel_mode_run_id, NodeRunMetadataKey.PARALLEL_MODE_RUN_ID: event.parallel_mode_run_id,
str(NodeRunMetadataKey.LOOP_ID): event.in_loop_id, NodeRunMetadataKey.LOOP_ID: event.in_loop_id,
} }
# Convert execution metadata keys to strings # Convert execution metadata keys to strings
execution_metadata_dict = {} execution_metadata_dict: dict[NodeRunMetadataKey, str | None] = {}
if event.execution_metadata: if event.execution_metadata:
for key, value in event.execution_metadata.items(): for key, value in event.execution_metadata.items():
execution_metadata_dict[str(key)] = value execution_metadata_dict[key] = value
merged_metadata = {**execution_metadata_dict, **origin_metadata} if execution_metadata_dict else origin_metadata merged_metadata = {**execution_metadata_dict, **origin_metadata} if execution_metadata_dict else origin_metadata

Loading…
Cancel
Save