diff --git a/api/core/ops/langfuse_trace/langfuse_trace.py b/api/core/ops/langfuse_trace/langfuse_trace.py index 836fe2d3f4..bb46055689 100644 --- a/api/core/ops/langfuse_trace/langfuse_trace.py +++ b/api/core/ops/langfuse_trace/langfuse_trace.py @@ -29,6 +29,7 @@ from core.ops.langfuse_trace.entities.langfuse_trace_entity import ( ) from core.ops.utils import filter_none_values from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository +from core.workflow.nodes.enums import NodeType from extensions.ext_database import db from models import Account, App, EndUser, WorkflowNodeExecutionTriggeredFrom @@ -149,7 +150,7 @@ class LangFuseDataTrace(BaseTraceInstance): node_name = node_execution.title node_type = node_execution.node_type status = node_execution.status - if node_type == "llm": + if node_type == NodeType.LLM: inputs = node_execution.process_data.get("prompts", {}) if node_execution.process_data else {} else: inputs = node_execution.inputs if node_execution.inputs else {} diff --git a/api/core/ops/langsmith_trace/langsmith_trace.py b/api/core/ops/langsmith_trace/langsmith_trace.py index fea1f235eb..4420feca91 100644 --- a/api/core/ops/langsmith_trace/langsmith_trace.py +++ b/api/core/ops/langsmith_trace/langsmith_trace.py @@ -28,6 +28,7 @@ from core.ops.langsmith_trace.entities.langsmith_trace_entity import ( ) from core.ops.utils import filter_none_values, generate_dotted_order from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository +from core.workflow.nodes.enums import NodeType from extensions.ext_database import db from models import Account, App, EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom @@ -173,7 +174,7 @@ class LangSmithDataTrace(BaseTraceInstance): node_name = node_execution.title node_type = node_execution.node_type status = node_execution.status - if node_type == "llm": + if node_type == NodeType.LLM: inputs = node_execution.process_data.get("prompts", {}) if node_execution.process_data else {} else: inputs = node_execution.inputs if node_execution.inputs else {} @@ -207,7 +208,7 @@ class LangSmithDataTrace(BaseTraceInstance): "ls_model_name": process_data.get("model_name", ""), } ) - elif node_type == "knowledge-retrieval": + elif node_type == NodeType.KNOWLEDGE_RETRIEVAL: run_type = LangSmithRunType.retriever else: run_type = LangSmithRunType.tool diff --git a/api/core/ops/opik_trace/opik_trace.py b/api/core/ops/opik_trace/opik_trace.py index 7fdbf12da1..24718c9393 100644 --- a/api/core/ops/opik_trace/opik_trace.py +++ b/api/core/ops/opik_trace/opik_trace.py @@ -22,6 +22,7 @@ from core.ops.entities.trace_entity import ( WorkflowTraceInfo, ) from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository +from core.workflow.nodes.enums import NodeType from extensions.ext_database import db from models import Account, App, EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom @@ -186,7 +187,7 @@ class OpikDataTrace(BaseTraceInstance): node_name = node_execution.title node_type = node_execution.node_type status = node_execution.status - if node_type == "llm": + if node_type == NodeType.LLM: inputs = node_execution.process_data.get("prompts", {}) if node_execution.process_data else {} else: inputs = node_execution.inputs if node_execution.inputs else {} diff --git a/api/core/ops/weave_trace/weave_trace.py b/api/core/ops/weave_trace/weave_trace.py index 70826448b7..fb8627fb85 100644 --- a/api/core/ops/weave_trace/weave_trace.py +++ b/api/core/ops/weave_trace/weave_trace.py @@ -23,6 +23,7 @@ from core.ops.entities.trace_entity import ( ) from core.ops.weave_trace.entities.weave_trace_entity import WeaveTraceModel from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository +from core.workflow.nodes.enums import NodeType from extensions.ext_database import db from models import Account, App, EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom @@ -167,7 +168,7 @@ class WeaveDataTrace(BaseTraceInstance): node_name = node_execution.title node_type = node_execution.node_type status = node_execution.status - if node_type == "llm": + if node_type == NodeType.LLM: inputs = node_execution.process_data.get("prompts", {}) if node_execution.process_data else {} else: inputs = node_execution.inputs if node_execution.inputs else {} diff --git a/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py b/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py index f2d3dc7cd9..0715160171 100644 --- a/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py +++ b/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py @@ -15,6 +15,7 @@ from core.workflow.entities.node_execution_entities import ( NodeExecution, NodeExecutionStatus, ) +from core.workflow.nodes.enums import NodeType from core.workflow.repository.workflow_node_execution_repository import OrderConfig, WorkflowNodeExecutionRepository from models import ( Account, @@ -113,7 +114,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository) index=db_model.index, predecessor_node_id=db_model.predecessor_node_id, node_id=db_model.node_id, - node_type=db_model.node_type, + node_type=NodeType(db_model.node_type), title=db_model.title, inputs=inputs, process_data=process_data, diff --git a/api/core/workflow/entities/node_execution_entities.py b/api/core/workflow/entities/node_execution_entities.py index 5fc8fc6073..58da8a2cfe 100644 --- a/api/core/workflow/entities/node_execution_entities.py +++ b/api/core/workflow/entities/node_execution_entities.py @@ -13,6 +13,8 @@ from typing import Any, Optional from pydantic import BaseModel, Field +from core.workflow.nodes.enums import NodeType + class NodeExecutionStatus(StrEnum): """ @@ -49,7 +51,7 @@ class NodeExecution(BaseModel): index: int # Sequence number for ordering in trace visualization predecessor_node_id: Optional[str] = None # ID of the node that executed before this one node_id: str # ID of the node being executed - node_type: str # Type of node (e.g., start, llm, knowledge) + node_type: NodeType # Type of node (e.g., start, llm, knowledge) title: str # Display title of the node # Execution data diff --git a/api/core/workflow/workflow_cycle_manager.py b/api/core/workflow/workflow_cycle_manager.py index bd9dd4005b..4be205d51e 100644 --- a/api/core/workflow/workflow_cycle_manager.py +++ b/api/core/workflow/workflow_cycle_manager.py @@ -305,7 +305,7 @@ class WorkflowCycleManager: index=event.node_run_index, node_execution_id=event.node_execution_id, node_id=event.node_id, - node_type=event.node_type.value, + node_type=event.node_type, title=event.node_data.title, status=NodeExecutionStatus.RUNNING, metadata=metadata, @@ -438,7 +438,7 @@ class WorkflowCycleManager: predecessor_node_id=event.predecessor_node_id, node_execution_id=event.node_execution_id, node_id=event.node_id, - node_type=event.node_type.value, + node_type=event.node_type, title=event.node_data.title, status=NodeExecutionStatus.RETRY, created_at=created_at, @@ -532,7 +532,7 @@ class WorkflowCycleManager: task_id: str, workflow_node_execution: NodeExecution, ) -> Optional[NodeStartStreamResponse]: - if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}: + if workflow_node_execution.node_type in {NodeType.ITERATION, NodeType.LOOP}: return None if not workflow_node_execution.workflow_run_id: return None @@ -582,7 +582,7 @@ class WorkflowCycleManager: task_id: str, workflow_node_execution: NodeExecution, ) -> Optional[NodeFinishStreamResponse]: - if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}: + if workflow_node_execution.node_type in {NodeType.ITERATION, NodeType.LOOP}: return None if not workflow_node_execution.workflow_run_id: return None @@ -625,7 +625,7 @@ class WorkflowCycleManager: task_id: str, workflow_node_execution: NodeExecution, ) -> Optional[Union[NodeRetryStreamResponse, NodeFinishStreamResponse]]: - if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}: + if workflow_node_execution.node_type in {NodeType.ITERATION, NodeType.LOOP}: return None if not workflow_node_execution.workflow_run_id: return None diff --git a/api/tests/unit_tests/core/workflow/test_workflow_cycle_manager.py b/api/tests/unit_tests/core/workflow/test_workflow_cycle_manager.py index e49359662b..4f7da89020 100644 --- a/api/tests/unit_tests/core/workflow/test_workflow_cycle_manager.py +++ b/api/tests/unit_tests/core/workflow/test_workflow_cycle_manager.py @@ -220,7 +220,7 @@ def test_handle_node_execution_start(workflow_cycle_manager, mock_workflow_run): assert result.workflow_run_id == mock_workflow_run.id assert result.node_execution_id == event.node_execution_id assert result.node_id == event.node_id - assert result.node_type == event.node_type.value + assert result.node_type == event.node_type assert result.title == event.node_data.title assert result.status == WorkflowNodeExecutionStatus.RUNNING.value # NodeExecution doesn't have created_by_role and created_by attributes, they're handled at repository level