refactor(node_execution_entities): Improve the type of metadata

Signed-off-by: -LAN- <laipz8200@outlook.com>
pull/19430/head
-LAN- 1 year ago
parent d4fb9c7472
commit efcad48e91
No known key found for this signature in database
GPG Key ID: 6BA0D108DED011FF

@ -6,7 +6,7 @@ from pydantic import BaseModel, ConfigDict
from core.model_runtime.entities.llm_entities import LLMResult
from core.model_runtime.utils.encoders import jsonable_encoder
from core.workflow.entities.node_entities import AgentNodeStrategyInit
from core.workflow.entities.node_entities import AgentNodeStrategyInit, NodeRunMetadataKey
from models.workflow import WorkflowNodeExecutionStatus
@ -244,7 +244,7 @@ class NodeStartStreamResponse(StreamResponse):
title: str
index: int
predecessor_node_id: Optional[str] = None
inputs: Optional[dict] = None
inputs: Optional[Mapping[str, Any]] = None
created_at: int
extras: dict = {}
parallel_id: Optional[str] = None
@ -301,13 +301,13 @@ class NodeFinishStreamResponse(StreamResponse):
title: str
index: int
predecessor_node_id: Optional[str] = None
inputs: Optional[dict] = None
process_data: Optional[dict] = None
outputs: Optional[dict] = None
inputs: Optional[Mapping[str, Any]] = None
process_data: Optional[Mapping[str, Any]] = None
outputs: Optional[Mapping[str, Any]] = None
status: str
error: Optional[str] = None
elapsed_time: float
execution_metadata: Optional[dict] = None
execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None
created_at: int
finished_at: int
files: Optional[Sequence[Mapping[str, Any]]] = []
@ -370,13 +370,13 @@ class NodeRetryStreamResponse(StreamResponse):
title: str
index: int
predecessor_node_id: Optional[str] = None
inputs: Optional[dict] = None
process_data: Optional[dict] = None
outputs: Optional[dict] = None
inputs: Optional[Mapping[str, Any]] = None
process_data: Optional[Mapping[str, Any]] = None
outputs: Optional[Mapping[str, Any]] = None
status: str
error: Optional[str] = None
elapsed_time: float
execution_metadata: Optional[dict] = None
execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None
created_at: int
finished_at: int
files: Optional[Sequence[Mapping[str, Any]]] = []

@ -1,3 +1,4 @@
from collections.abc import Mapping
from datetime import datetime
from enum import StrEnum
from typing import Any, Optional, Union
@ -155,10 +156,10 @@ class LangfuseSpan(BaseModel):
description="The status message of the span. Additional field for context of the event. E.g. the error "
"message of an error event.",
)
input: Optional[Union[str, dict[str, Any], list, None]] = Field(
input: Optional[Union[str, Mapping[str, Any], list, None]] = Field(
default=None, description="The input of the span. Can be any JSON object."
)
output: Optional[Union[str, dict[str, Any], list, None]] = Field(
output: Optional[Union[str, Mapping[str, Any], list, None]] = Field(
default=None, description="The output of the span. Can be any JSON object."
)
version: Optional[str] = Field(

@ -160,7 +160,7 @@ class LangFuseDataTrace(BaseTraceInstance):
finished_at = created_at + timedelta(seconds=elapsed_time)
execution_metadata = node_execution.metadata if node_execution.metadata else {}
metadata = execution_metadata.copy()
metadata = {str(k): v for k, v in execution_metadata.items()}
metadata.update(
{
"workflow_run_id": trace_info.workflow_run_id,

@ -1,3 +1,4 @@
from collections.abc import Mapping
from datetime import datetime
from enum import StrEnum
from typing import Any, Optional, Union
@ -30,8 +31,8 @@ class LangSmithMultiModel(BaseModel):
class LangSmithRunModel(LangSmithTokenUsage, LangSmithMultiModel):
name: Optional[str] = Field(..., description="Name of the run")
inputs: Optional[Union[str, dict[str, Any], list, None]] = Field(None, description="Inputs of the run")
outputs: Optional[Union[str, dict[str, Any], list, None]] = Field(None, description="Outputs of the run")
inputs: Optional[Union[str, Mapping[str, Any], list, None]] = Field(None, description="Inputs of the run")
outputs: Optional[Union[str, Mapping[str, Any], list, None]] = Field(None, description="Outputs of the run")
run_type: LangSmithRunType = Field(..., description="Type of the run")
start_time: Optional[datetime | str] = Field(None, description="Start time of the run")
end_time: Optional[datetime | str] = Field(None, description="End time of the run")

@ -28,6 +28,7 @@ from core.ops.langsmith_trace.entities.langsmith_trace_entity import (
)
from core.ops.utils import filter_none_values, generate_dotted_order
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
from core.workflow.entities.node_entities import NodeRunMetadataKey
from core.workflow.nodes.enums import NodeType
from extensions.ext_database import db
from models import Account, App, EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom
@ -184,8 +185,8 @@ class LangSmithDataTrace(BaseTraceInstance):
finished_at = created_at + timedelta(seconds=elapsed_time)
execution_metadata = node_execution.metadata if node_execution.metadata else {}
node_total_tokens = execution_metadata.get("total_tokens", 0)
metadata = execution_metadata.copy()
node_total_tokens = execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS) or 0
metadata = {str(key): value for key, value in execution_metadata.items()}
metadata.update(
{
"workflow_run_id": trace_info.workflow_run_id,

@ -22,6 +22,7 @@ from core.ops.entities.trace_entity import (
WorkflowTraceInfo,
)
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
from core.workflow.entities.node_entities import NodeRunMetadataKey
from core.workflow.nodes.enums import NodeType
from extensions.ext_database import db
from models import Account, App, EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom
@ -197,7 +198,7 @@ class OpikDataTrace(BaseTraceInstance):
finished_at = created_at + timedelta(seconds=elapsed_time)
execution_metadata = node_execution.metadata if node_execution.metadata else {}
metadata = execution_metadata.copy()
metadata = {str(k): v for k, v in execution_metadata.items()}
metadata.update(
{
"workflow_run_id": trace_info.workflow_run_id,
@ -243,7 +244,7 @@ class OpikDataTrace(BaseTraceInstance):
parent_span_id = trace_info.workflow_app_log_id or trace_info.workflow_run_id
if not total_tokens:
total_tokens = execution_metadata.get("total_tokens", 0)
total_tokens = execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS) or 0
span_data = {
"trace_id": opik_trace_id,

@ -1,3 +1,4 @@
from collections.abc import Mapping
from typing import Any, Optional, Union
from pydantic import BaseModel, Field, field_validator
@ -19,8 +20,8 @@ class WeaveMultiModel(BaseModel):
class WeaveTraceModel(WeaveTokenUsage, WeaveMultiModel):
id: str = Field(..., description="ID of the trace")
op: str = Field(..., description="Name of the operation")
inputs: Optional[Union[str, dict[str, Any], list, None]] = Field(None, description="Inputs of the trace")
outputs: Optional[Union[str, dict[str, Any], list, None]] = Field(None, description="Outputs of the trace")
inputs: Optional[Union[str, Mapping[str, Any], list, None]] = Field(None, description="Inputs of the trace")
outputs: Optional[Union[str, Mapping[str, Any], list, None]] = Field(None, description="Outputs of the trace")
attributes: Optional[Union[str, dict[str, Any], list, None]] = Field(
None, description="Metadata and attributes associated with trace"
)

@ -23,6 +23,7 @@ from core.ops.entities.trace_entity import (
)
from core.ops.weave_trace.entities.weave_trace_entity import WeaveTraceModel
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
from core.workflow.entities.node_entities import NodeRunMetadataKey
from core.workflow.nodes.enums import NodeType
from extensions.ext_database import db
from models import Account, App, EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom
@ -178,8 +179,8 @@ class WeaveDataTrace(BaseTraceInstance):
finished_at = created_at + timedelta(seconds=elapsed_time)
execution_metadata = node_execution.metadata if node_execution.metadata else {}
node_total_tokens = execution_metadata.get("total_tokens", 0)
attributes = execution_metadata.copy()
node_total_tokens = execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS) or 0
attributes = {str(k): v for k, v in execution_metadata.items()}
attributes.update(
{
"workflow_run_id": trace_info.workflow_run_id,

@ -13,6 +13,7 @@ from typing import Any, Optional
from pydantic import BaseModel, Field
from core.workflow.entities.node_entities import NodeRunMetadataKey
from core.workflow.nodes.enums import NodeType
@ -55,9 +56,9 @@ class NodeExecution(BaseModel):
title: str # Display title of the node
# Execution data
inputs: Optional[dict[str, Any]] = None # Input variables used by this node
process_data: Optional[dict[str, Any]] = None # Intermediate processing data
outputs: Optional[dict[str, Any]] = None # Output variables produced by this node
inputs: Optional[Mapping[str, Any]] = None # Input variables used by this node
process_data: Optional[Mapping[str, Any]] = None # Intermediate processing data
outputs: Optional[Mapping[str, Any]] = None # Output variables produced by this node
# Execution state
status: NodeExecutionStatus = NodeExecutionStatus.RUNNING # Current execution status
@ -65,7 +66,7 @@ class NodeExecution(BaseModel):
elapsed_time: float = Field(default=0.0) # Time taken for execution in seconds
# Additional metadata
metadata: Optional[dict[str, Any]] = None # Execution metadata (tokens, cost, etc.)
metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None # Execution metadata (tokens, cost, etc.)
# Timing information
created_at: datetime # When execution started
@ -76,7 +77,7 @@ class NodeExecution(BaseModel):
inputs: Optional[Mapping[str, Any]] = None,
process_data: Optional[Mapping[str, Any]] = None,
outputs: Optional[Mapping[str, Any]] = None,
metadata: Optional[Mapping[str, Any]] = None,
metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None,
) -> None:
"""
Update the model from mappings.

@ -292,9 +292,9 @@ class WorkflowCycleManager:
# Create a domain model
created_at = datetime.now(UTC).replace(tzinfo=None)
metadata = {
str(NodeRunMetadataKey.PARALLEL_MODE_RUN_ID): event.parallel_mode_run_id,
str(NodeRunMetadataKey.ITERATION_ID): event.in_iteration_id,
str(NodeRunMetadataKey.LOOP_ID): event.in_loop_id,
NodeRunMetadataKey.PARALLEL_MODE_RUN_ID: event.parallel_mode_run_id,
NodeRunMetadataKey.ITERATION_ID: event.in_iteration_id,
NodeRunMetadataKey.LOOP_ID: event.in_loop_id,
}
domain_execution = NodeExecution(
@ -332,7 +332,7 @@ class WorkflowCycleManager:
execution_metadata_dict = {}
if event.execution_metadata:
for key, value in event.execution_metadata.items():
execution_metadata_dict[str(key)] = value
execution_metadata_dict[key] = value
finished_at = datetime.now(UTC).replace(tzinfo=None)
elapsed_time = (finished_at - event.start_at).total_seconds()
@ -377,7 +377,7 @@ class WorkflowCycleManager:
execution_metadata_dict = {}
if event.execution_metadata:
for key, value in event.execution_metadata.items():
execution_metadata_dict[str(key)] = value
execution_metadata_dict[key] = value
finished_at = datetime.now(UTC).replace(tzinfo=None)
elapsed_time = (finished_at - event.start_at).total_seconds()
@ -417,16 +417,16 @@ class WorkflowCycleManager:
# Convert metadata keys to strings
origin_metadata = {
str(NodeRunMetadataKey.ITERATION_ID): event.in_iteration_id,
str(NodeRunMetadataKey.PARALLEL_MODE_RUN_ID): event.parallel_mode_run_id,
str(NodeRunMetadataKey.LOOP_ID): event.in_loop_id,
NodeRunMetadataKey.ITERATION_ID: event.in_iteration_id,
NodeRunMetadataKey.PARALLEL_MODE_RUN_ID: event.parallel_mode_run_id,
NodeRunMetadataKey.LOOP_ID: event.in_loop_id,
}
# Convert execution metadata keys to strings
execution_metadata_dict = {}
execution_metadata_dict: dict[NodeRunMetadataKey, str | None] = {}
if event.execution_metadata:
for key, value in event.execution_metadata.items():
execution_metadata_dict[str(key)] = value
execution_metadata_dict[key] = value
merged_metadata = {**execution_metadata_dict, **origin_metadata} if execution_metadata_dict else origin_metadata

Loading…
Cancel
Save