refactor(node_execution_entities): Change node_type's type from str to NodeType

Signed-off-by: -LAN- <laipz8200@outlook.com>
pull/19430/head
-LAN- 1 year ago
parent 4664dfaba0
commit 84aaa4228e
No known key found for this signature in database
GPG Key ID: 6BA0D108DED011FF

@ -29,6 +29,7 @@ from core.ops.langfuse_trace.entities.langfuse_trace_entity import (
) )
from core.ops.utils import filter_none_values from core.ops.utils import filter_none_values
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
from core.workflow.nodes.enums import NodeType
from extensions.ext_database import db from extensions.ext_database import db
from models import Account, App, EndUser, WorkflowNodeExecutionTriggeredFrom from models import Account, App, EndUser, WorkflowNodeExecutionTriggeredFrom
@ -149,7 +150,7 @@ class LangFuseDataTrace(BaseTraceInstance):
node_name = node_execution.title node_name = node_execution.title
node_type = node_execution.node_type node_type = node_execution.node_type
status = node_execution.status status = node_execution.status
if node_type == "llm": if node_type == NodeType.LLM:
inputs = node_execution.process_data.get("prompts", {}) if node_execution.process_data else {} inputs = node_execution.process_data.get("prompts", {}) if node_execution.process_data else {}
else: else:
inputs = node_execution.inputs if node_execution.inputs else {} inputs = node_execution.inputs if node_execution.inputs else {}

@ -28,6 +28,7 @@ from core.ops.langsmith_trace.entities.langsmith_trace_entity import (
) )
from core.ops.utils import filter_none_values, generate_dotted_order from core.ops.utils import filter_none_values, generate_dotted_order
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
from core.workflow.nodes.enums import NodeType
from extensions.ext_database import db from extensions.ext_database import db
from models import Account, App, EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom from models import Account, App, EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom
@ -173,7 +174,7 @@ class LangSmithDataTrace(BaseTraceInstance):
node_name = node_execution.title node_name = node_execution.title
node_type = node_execution.node_type node_type = node_execution.node_type
status = node_execution.status status = node_execution.status
if node_type == "llm": if node_type == NodeType.LLM:
inputs = node_execution.process_data.get("prompts", {}) if node_execution.process_data else {} inputs = node_execution.process_data.get("prompts", {}) if node_execution.process_data else {}
else: else:
inputs = node_execution.inputs if node_execution.inputs else {} inputs = node_execution.inputs if node_execution.inputs else {}
@ -207,7 +208,7 @@ class LangSmithDataTrace(BaseTraceInstance):
"ls_model_name": process_data.get("model_name", ""), "ls_model_name": process_data.get("model_name", ""),
} }
) )
elif node_type == "knowledge-retrieval": elif node_type == NodeType.KNOWLEDGE_RETRIEVAL:
run_type = LangSmithRunType.retriever run_type = LangSmithRunType.retriever
else: else:
run_type = LangSmithRunType.tool run_type = LangSmithRunType.tool

@ -22,6 +22,7 @@ from core.ops.entities.trace_entity import (
WorkflowTraceInfo, WorkflowTraceInfo,
) )
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
from core.workflow.nodes.enums import NodeType
from extensions.ext_database import db from extensions.ext_database import db
from models import Account, App, EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom from models import Account, App, EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom
@ -186,7 +187,7 @@ class OpikDataTrace(BaseTraceInstance):
node_name = node_execution.title node_name = node_execution.title
node_type = node_execution.node_type node_type = node_execution.node_type
status = node_execution.status status = node_execution.status
if node_type == "llm": if node_type == NodeType.LLM:
inputs = node_execution.process_data.get("prompts", {}) if node_execution.process_data else {} inputs = node_execution.process_data.get("prompts", {}) if node_execution.process_data else {}
else: else:
inputs = node_execution.inputs if node_execution.inputs else {} inputs = node_execution.inputs if node_execution.inputs else {}

@ -23,6 +23,7 @@ from core.ops.entities.trace_entity import (
) )
from core.ops.weave_trace.entities.weave_trace_entity import WeaveTraceModel from core.ops.weave_trace.entities.weave_trace_entity import WeaveTraceModel
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
from core.workflow.nodes.enums import NodeType
from extensions.ext_database import db from extensions.ext_database import db
from models import Account, App, EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom from models import Account, App, EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom
@ -167,7 +168,7 @@ class WeaveDataTrace(BaseTraceInstance):
node_name = node_execution.title node_name = node_execution.title
node_type = node_execution.node_type node_type = node_execution.node_type
status = node_execution.status status = node_execution.status
if node_type == "llm": if node_type == NodeType.LLM:
inputs = node_execution.process_data.get("prompts", {}) if node_execution.process_data else {} inputs = node_execution.process_data.get("prompts", {}) if node_execution.process_data else {}
else: else:
inputs = node_execution.inputs if node_execution.inputs else {} inputs = node_execution.inputs if node_execution.inputs else {}

@ -15,6 +15,7 @@ from core.workflow.entities.node_execution_entities import (
NodeExecution, NodeExecution,
NodeExecutionStatus, NodeExecutionStatus,
) )
from core.workflow.nodes.enums import NodeType
from core.workflow.repository.workflow_node_execution_repository import OrderConfig, WorkflowNodeExecutionRepository from core.workflow.repository.workflow_node_execution_repository import OrderConfig, WorkflowNodeExecutionRepository
from models import ( from models import (
Account, Account,
@ -113,7 +114,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
index=db_model.index, index=db_model.index,
predecessor_node_id=db_model.predecessor_node_id, predecessor_node_id=db_model.predecessor_node_id,
node_id=db_model.node_id, node_id=db_model.node_id,
node_type=db_model.node_type, node_type=NodeType(db_model.node_type),
title=db_model.title, title=db_model.title,
inputs=inputs, inputs=inputs,
process_data=process_data, process_data=process_data,

@ -13,6 +13,8 @@ from typing import Any, Optional
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from core.workflow.nodes.enums import NodeType
class NodeExecutionStatus(StrEnum): class NodeExecutionStatus(StrEnum):
""" """
@ -49,7 +51,7 @@ class NodeExecution(BaseModel):
index: int # Sequence number for ordering in trace visualization index: int # Sequence number for ordering in trace visualization
predecessor_node_id: Optional[str] = None # ID of the node that executed before this one predecessor_node_id: Optional[str] = None # ID of the node that executed before this one
node_id: str # ID of the node being executed node_id: str # ID of the node being executed
node_type: str # Type of node (e.g., start, llm, knowledge) node_type: NodeType # Type of node (e.g., start, llm, knowledge)
title: str # Display title of the node title: str # Display title of the node
# Execution data # Execution data

@ -305,7 +305,7 @@ class WorkflowCycleManager:
index=event.node_run_index, index=event.node_run_index,
node_execution_id=event.node_execution_id, node_execution_id=event.node_execution_id,
node_id=event.node_id, node_id=event.node_id,
node_type=event.node_type.value, node_type=event.node_type,
title=event.node_data.title, title=event.node_data.title,
status=NodeExecutionStatus.RUNNING, status=NodeExecutionStatus.RUNNING,
metadata=metadata, metadata=metadata,
@ -438,7 +438,7 @@ class WorkflowCycleManager:
predecessor_node_id=event.predecessor_node_id, predecessor_node_id=event.predecessor_node_id,
node_execution_id=event.node_execution_id, node_execution_id=event.node_execution_id,
node_id=event.node_id, node_id=event.node_id,
node_type=event.node_type.value, node_type=event.node_type,
title=event.node_data.title, title=event.node_data.title,
status=NodeExecutionStatus.RETRY, status=NodeExecutionStatus.RETRY,
created_at=created_at, created_at=created_at,
@ -532,7 +532,7 @@ class WorkflowCycleManager:
task_id: str, task_id: str,
workflow_node_execution: NodeExecution, workflow_node_execution: NodeExecution,
) -> Optional[NodeStartStreamResponse]: ) -> Optional[NodeStartStreamResponse]:
if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}: if workflow_node_execution.node_type in {NodeType.ITERATION, NodeType.LOOP}:
return None return None
if not workflow_node_execution.workflow_run_id: if not workflow_node_execution.workflow_run_id:
return None return None
@ -582,7 +582,7 @@ class WorkflowCycleManager:
task_id: str, task_id: str,
workflow_node_execution: NodeExecution, workflow_node_execution: NodeExecution,
) -> Optional[NodeFinishStreamResponse]: ) -> Optional[NodeFinishStreamResponse]:
if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}: if workflow_node_execution.node_type in {NodeType.ITERATION, NodeType.LOOP}:
return None return None
if not workflow_node_execution.workflow_run_id: if not workflow_node_execution.workflow_run_id:
return None return None
@ -625,7 +625,7 @@ class WorkflowCycleManager:
task_id: str, task_id: str,
workflow_node_execution: NodeExecution, workflow_node_execution: NodeExecution,
) -> Optional[Union[NodeRetryStreamResponse, NodeFinishStreamResponse]]: ) -> Optional[Union[NodeRetryStreamResponse, NodeFinishStreamResponse]]:
if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}: if workflow_node_execution.node_type in {NodeType.ITERATION, NodeType.LOOP}:
return None return None
if not workflow_node_execution.workflow_run_id: if not workflow_node_execution.workflow_run_id:
return None return None

@ -220,7 +220,7 @@ def test_handle_node_execution_start(workflow_cycle_manager, mock_workflow_run):
assert result.workflow_run_id == mock_workflow_run.id assert result.workflow_run_id == mock_workflow_run.id
assert result.node_execution_id == event.node_execution_id assert result.node_execution_id == event.node_execution_id
assert result.node_id == event.node_id assert result.node_id == event.node_id
assert result.node_type == event.node_type.value assert result.node_type == event.node_type
assert result.title == event.node_data.title assert result.title == event.node_data.title
assert result.status == WorkflowNodeExecutionStatus.RUNNING.value assert result.status == WorkflowNodeExecutionStatus.RUNNING.value
# NodeExecution doesn't have created_by_role and created_by attributes, they're handled at repository level # NodeExecution doesn't have created_by_role and created_by attributes, they're handled at repository level

Loading…
Cancel
Save