From 9d79fdfe513403da8e60f87a59da60bcaa1daee7 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Fri, 9 May 2025 12:24:13 +0800 Subject: [PATCH] Refactor workflow cycle manager to use domain model instead of DB model --- api/core/workflow/workflow_cycle_manager.py | 307 ++++++++++---------- api/models/model.py | 2 +- 2 files changed, 157 insertions(+), 152 deletions(-) diff --git a/api/core/workflow/workflow_cycle_manager.py b/api/core/workflow/workflow_cycle_manager.py index 01d5db4303..668f6f5631 100644 --- a/api/core/workflow/workflow_cycle_manager.py +++ b/api/core/workflow/workflow_cycle_manager.py @@ -46,26 +46,29 @@ from core.app.entities.task_entities import ( ) from core.app.task_pipeline.exc import WorkflowRunNotFoundError from core.file import FILE_MODEL_IDENTITY, File -from core.model_runtime.utils.encoders import jsonable_encoder from core.ops.entities.trace_entity import TraceTaskName from core.ops.ops_trace_manager import TraceQueueManager, TraceTask from core.tools.tool_manager import ToolManager from core.workflow.entities.node_entities import NodeRunMetadataKey +from core.workflow.entities.node_execution_entities import ( + NodeExecution, + NodeExecutionStatus, +) from core.workflow.enums import SystemVariableKey from core.workflow.nodes import NodeType from core.workflow.nodes.tool.entities import ToolNodeData from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository from core.workflow.workflow_entry import WorkflowEntry -from models.account import Account -from models.enums import CreatedByRole, WorkflowRunTriggeredFrom -from models.model import EndUser -from models.workflow import ( +from models import ( + Account, + CreatedByRole, + EndUser, Workflow, - WorkflowNodeExecution, WorkflowNodeExecutionStatus, WorkflowNodeExecutionTriggeredFrom, WorkflowRun, WorkflowRunStatus, + WorkflowRunTriggeredFrom, ) @@ -78,7 +81,6 @@ class WorkflowCycleManager: workflow_node_execution_repository: WorkflowNodeExecutionRepository, ) -> None: self._workflow_run: WorkflowRun | None = None - self._workflow_node_executions: dict[str, WorkflowNodeExecution] = {} self._application_generate_entity = application_generate_entity self._workflow_system_variables = workflow_system_variables self._workflow_node_execution_repository = workflow_node_execution_repository @@ -258,21 +260,22 @@ class WorkflowCycleManager: workflow_run.exceptions_count = exceptions_count # Use the instance repository to find running executions for a workflow run - running_workflow_node_executions = self._workflow_node_execution_repository.get_running_executions( + running_domain_executions = self._workflow_node_execution_repository.get_running_executions( workflow_run_id=workflow_run.id ) - # Update the cache with the retrieved executions - for execution in running_workflow_node_executions: - if execution.node_execution_id: - self._workflow_node_executions[execution.node_execution_id] = execution + # Update the domain models + now = datetime.now(UTC).replace(tzinfo=None) + for domain_execution in running_domain_executions: + if domain_execution.node_execution_id: + # Update the domain model + domain_execution.status = NodeExecutionStatus.FAILED + domain_execution.error = error + domain_execution.finished_at = now + domain_execution.elapsed_time = (now - domain_execution.created_at).total_seconds() - for workflow_node_execution in running_workflow_node_executions: - now = datetime.now(UTC).replace(tzinfo=None) - workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value - workflow_node_execution.error = error - workflow_node_execution.finished_at = now - workflow_node_execution.elapsed_time = (now - workflow_node_execution.created_at).total_seconds() + # Update the repository with the domain model + self._workflow_node_execution_repository.save(domain_execution) if trace_manager: trace_manager.add_trace_task( @@ -286,63 +289,70 @@ class WorkflowCycleManager: return workflow_run - def _handle_node_execution_start( - self, *, workflow_run: WorkflowRun, event: QueueNodeStartedEvent - ) -> WorkflowNodeExecution: - workflow_node_execution = WorkflowNodeExecution() - workflow_node_execution.id = str(uuid4()) - workflow_node_execution.tenant_id = workflow_run.tenant_id - workflow_node_execution.app_id = workflow_run.app_id - workflow_node_execution.workflow_id = workflow_run.workflow_id - workflow_node_execution.triggered_from = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value - workflow_node_execution.workflow_run_id = workflow_run.id - workflow_node_execution.predecessor_node_id = event.predecessor_node_id - workflow_node_execution.index = event.node_run_index - workflow_node_execution.node_execution_id = event.node_execution_id - workflow_node_execution.node_id = event.node_id - workflow_node_execution.node_type = event.node_type.value - workflow_node_execution.title = event.node_data.title - workflow_node_execution.status = WorkflowNodeExecutionStatus.RUNNING.value - workflow_node_execution.created_by_role = workflow_run.created_by_role - workflow_node_execution.created_by = workflow_run.created_by - workflow_node_execution.execution_metadata = json.dumps( - { - NodeRunMetadataKey.PARALLEL_MODE_RUN_ID: event.parallel_mode_run_id, - NodeRunMetadataKey.ITERATION_ID: event.in_iteration_id, - NodeRunMetadataKey.LOOP_ID: event.in_loop_id, - } + def _handle_node_execution_start(self, *, workflow_run: WorkflowRun, event: QueueNodeStartedEvent) -> NodeExecution: + # Create a domain model + created_at = datetime.now(UTC).replace(tzinfo=None) + metadata = { + str(NodeRunMetadataKey.PARALLEL_MODE_RUN_ID): event.parallel_mode_run_id, + str(NodeRunMetadataKey.ITERATION_ID): event.in_iteration_id, + str(NodeRunMetadataKey.LOOP_ID): event.in_loop_id, + } + + domain_execution = NodeExecution( + id=str(uuid4()), + workflow_id=workflow_run.workflow_id, + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + workflow_run_id=workflow_run.id, + predecessor_node_id=event.predecessor_node_id, + index=event.node_run_index, + node_execution_id=event.node_execution_id, + node_id=event.node_id, + node_type=event.node_type.value, + title=event.node_data.title, + status=NodeExecutionStatus.RUNNING, + created_by_role=workflow_run.created_by_role, + created_by=workflow_run.created_by, + metadata=metadata, + created_at=created_at, ) - workflow_node_execution.created_at = datetime.now(UTC).replace(tzinfo=None) - # Use the instance repository to save the workflow node execution - self._workflow_node_execution_repository.save(workflow_node_execution) + # Use the instance repository to save the domain model + self._workflow_node_execution_repository.save(domain_execution) - self._workflow_node_executions[event.node_execution_id] = workflow_node_execution - return workflow_node_execution + return domain_execution - def _handle_workflow_node_execution_success(self, *, event: QueueNodeSucceededEvent) -> WorkflowNodeExecution: - workflow_node_execution = self._get_workflow_node_execution(node_execution_id=event.node_execution_id) + def _handle_workflow_node_execution_success(self, *, event: QueueNodeSucceededEvent) -> NodeExecution: + # Get the domain model from repository + domain_execution = self._workflow_node_execution_repository.get_by_node_execution_id(event.node_execution_id) + if not domain_execution: + raise ValueError(f"Domain node execution not found: {event.node_execution_id}") + + # Process data inputs = WorkflowEntry.handle_special_values(event.inputs) process_data = WorkflowEntry.handle_special_values(event.process_data) outputs = WorkflowEntry.handle_special_values(event.outputs) - execution_metadata_dict = dict(event.execution_metadata or {}) - execution_metadata = json.dumps(jsonable_encoder(execution_metadata_dict)) if execution_metadata_dict else None + + # Convert metadata keys to strings + execution_metadata_dict = {} + if event.execution_metadata: + for key, value in event.execution_metadata.items(): + execution_metadata_dict[str(key)] = value + finished_at = datetime.now(UTC).replace(tzinfo=None) elapsed_time = (finished_at - event.start_at).total_seconds() - process_data = WorkflowEntry.handle_special_values(event.process_data) + # Update domain model + domain_execution.status = NodeExecutionStatus.SUCCEEDED + domain_execution.update_from_mapping( + inputs=inputs, process_data=process_data, outputs=outputs, metadata=execution_metadata_dict + ) + domain_execution.finished_at = finished_at + domain_execution.elapsed_time = elapsed_time - workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value - workflow_node_execution.inputs = json.dumps(inputs) if inputs else None - workflow_node_execution.process_data = json.dumps(process_data) if process_data else None - workflow_node_execution.outputs = json.dumps(outputs) if outputs else None - workflow_node_execution.execution_metadata = execution_metadata - workflow_node_execution.finished_at = finished_at - workflow_node_execution.elapsed_time = elapsed_time + # Update the repository with the domain model + self._workflow_node_execution_repository.save(domain_execution) - # Use the instance repository to update the workflow node execution - self._workflow_node_execution_repository.update(workflow_node_execution) - return workflow_node_execution + return domain_execution def _handle_workflow_node_execution_failed( self, @@ -351,43 +361,52 @@ class WorkflowCycleManager: | QueueNodeInIterationFailedEvent | QueueNodeInLoopFailedEvent | QueueNodeExceptionEvent, - ) -> WorkflowNodeExecution: + ) -> NodeExecution: """ Workflow node execution failed :param event: queue node failed event :return: """ - workflow_node_execution = self._get_workflow_node_execution(node_execution_id=event.node_execution_id) + # Get the domain model from repository + domain_execution = self._workflow_node_execution_repository.get_by_node_execution_id(event.node_execution_id) + if not domain_execution: + raise ValueError(f"Domain node execution not found: {event.node_execution_id}") + # Process data inputs = WorkflowEntry.handle_special_values(event.inputs) process_data = WorkflowEntry.handle_special_values(event.process_data) outputs = WorkflowEntry.handle_special_values(event.outputs) + + # Convert metadata keys to strings + execution_metadata_dict = {} + if event.execution_metadata: + for key, value in event.execution_metadata.items(): + execution_metadata_dict[str(key)] = value + finished_at = datetime.now(UTC).replace(tzinfo=None) elapsed_time = (finished_at - event.start_at).total_seconds() - execution_metadata = ( - json.dumps(jsonable_encoder(event.execution_metadata)) if event.execution_metadata else None - ) - process_data = WorkflowEntry.handle_special_values(event.process_data) - workflow_node_execution.status = ( - WorkflowNodeExecutionStatus.FAILED.value + + # Update domain model + domain_execution.status = ( + NodeExecutionStatus.FAILED if not isinstance(event, QueueNodeExceptionEvent) - else WorkflowNodeExecutionStatus.EXCEPTION.value + else NodeExecutionStatus.EXCEPTION ) - workflow_node_execution.error = event.error - workflow_node_execution.inputs = json.dumps(inputs) if inputs else None - workflow_node_execution.process_data = json.dumps(process_data) if process_data else None - workflow_node_execution.outputs = json.dumps(outputs) if outputs else None - workflow_node_execution.finished_at = finished_at - workflow_node_execution.elapsed_time = elapsed_time - workflow_node_execution.execution_metadata = execution_metadata + domain_execution.error = event.error + domain_execution.update_from_mapping( + inputs=inputs, process_data=process_data, outputs=outputs, metadata=execution_metadata_dict + ) + domain_execution.finished_at = finished_at + domain_execution.elapsed_time = elapsed_time - self._workflow_node_execution_repository.update(workflow_node_execution) + # Update the repository with the domain model + self._workflow_node_execution_repository.save(domain_execution) - return workflow_node_execution + return domain_execution def _handle_workflow_node_execution_retried( self, *, workflow_run: WorkflowRun, event: QueueNodeRetryEvent - ) -> WorkflowNodeExecution: + ) -> NodeExecution: """ Workflow node execution failed :param workflow_run: workflow run @@ -399,47 +418,50 @@ class WorkflowCycleManager: elapsed_time = (finished_at - created_at).total_seconds() inputs = WorkflowEntry.handle_special_values(event.inputs) outputs = WorkflowEntry.handle_special_values(event.outputs) + + # Convert metadata keys to strings origin_metadata = { - NodeRunMetadataKey.ITERATION_ID: event.in_iteration_id, - NodeRunMetadataKey.PARALLEL_MODE_RUN_ID: event.parallel_mode_run_id, - NodeRunMetadataKey.LOOP_ID: event.in_loop_id, + str(NodeRunMetadataKey.ITERATION_ID): event.in_iteration_id, + str(NodeRunMetadataKey.PARALLEL_MODE_RUN_ID): event.parallel_mode_run_id, + str(NodeRunMetadataKey.LOOP_ID): event.in_loop_id, } - merged_metadata = ( - {**jsonable_encoder(event.execution_metadata), **origin_metadata} - if event.execution_metadata is not None - else origin_metadata + + # Convert execution metadata keys to strings + execution_metadata_dict = {} + if event.execution_metadata: + for key, value in event.execution_metadata.items(): + execution_metadata_dict[str(key)] = value + + merged_metadata = {**execution_metadata_dict, **origin_metadata} if execution_metadata_dict else origin_metadata + + # Create a domain model + domain_execution = NodeExecution( + id=str(uuid4()), + workflow_id=workflow_run.workflow_id, + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + workflow_run_id=workflow_run.id, + predecessor_node_id=event.predecessor_node_id, + node_execution_id=event.node_execution_id, + node_id=event.node_id, + node_type=event.node_type.value, + title=event.node_data.title, + status=NodeExecutionStatus.RETRY, + created_by_role=workflow_run.created_by_role, + created_by=workflow_run.created_by, + created_at=created_at, + finished_at=finished_at, + elapsed_time=elapsed_time, + error=event.error, + index=event.node_run_index, ) - execution_metadata = json.dumps(merged_metadata) - - workflow_node_execution = WorkflowNodeExecution() - workflow_node_execution.id = str(uuid4()) - workflow_node_execution.tenant_id = workflow_run.tenant_id - workflow_node_execution.app_id = workflow_run.app_id - workflow_node_execution.workflow_id = workflow_run.workflow_id - workflow_node_execution.triggered_from = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value - workflow_node_execution.workflow_run_id = workflow_run.id - workflow_node_execution.predecessor_node_id = event.predecessor_node_id - workflow_node_execution.node_execution_id = event.node_execution_id - workflow_node_execution.node_id = event.node_id - workflow_node_execution.node_type = event.node_type.value - workflow_node_execution.title = event.node_data.title - workflow_node_execution.status = WorkflowNodeExecutionStatus.RETRY.value - workflow_node_execution.created_by_role = workflow_run.created_by_role - workflow_node_execution.created_by = workflow_run.created_by - workflow_node_execution.created_at = created_at - workflow_node_execution.finished_at = finished_at - workflow_node_execution.elapsed_time = elapsed_time - workflow_node_execution.error = event.error - workflow_node_execution.inputs = json.dumps(inputs) if inputs else None - workflow_node_execution.outputs = json.dumps(outputs) if outputs else None - workflow_node_execution.execution_metadata = execution_metadata - workflow_node_execution.index = event.node_run_index - - # Use the instance repository to save the workflow node execution - self._workflow_node_execution_repository.save(workflow_node_execution) - - self._workflow_node_executions[event.node_execution_id] = workflow_node_execution - return workflow_node_execution + + # Update with mappings + domain_execution.update_from_mapping(inputs=inputs, outputs=outputs, metadata=merged_metadata) + + # Use the instance repository to save the domain model + self._workflow_node_execution_repository.save(domain_execution) + + return domain_execution def _workflow_start_to_stream_response( self, @@ -515,7 +537,7 @@ class WorkflowCycleManager: *, event: QueueNodeStartedEvent, task_id: str, - workflow_node_execution: WorkflowNodeExecution, + workflow_node_execution: NodeExecution, ) -> Optional[NodeStartStreamResponse]: if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}: return None @@ -532,7 +554,7 @@ class WorkflowCycleManager: title=workflow_node_execution.title, index=workflow_node_execution.index, predecessor_node_id=workflow_node_execution.predecessor_node_id, - inputs=workflow_node_execution.inputs_dict, + inputs=workflow_node_execution.inputs, created_at=int(workflow_node_execution.created_at.timestamp()), parallel_id=event.parallel_id, parallel_start_node_id=event.parallel_start_node_id, @@ -565,7 +587,7 @@ class WorkflowCycleManager: | QueueNodeInLoopFailedEvent | QueueNodeExceptionEvent, task_id: str, - workflow_node_execution: WorkflowNodeExecution, + workflow_node_execution: NodeExecution, ) -> Optional[NodeFinishStreamResponse]: if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}: return None @@ -584,16 +606,16 @@ class WorkflowCycleManager: index=workflow_node_execution.index, title=workflow_node_execution.title, predecessor_node_id=workflow_node_execution.predecessor_node_id, - inputs=workflow_node_execution.inputs_dict, - process_data=workflow_node_execution.process_data_dict, - outputs=workflow_node_execution.outputs_dict, + inputs=workflow_node_execution.inputs, + process_data=workflow_node_execution.process_data, + outputs=workflow_node_execution.outputs, status=workflow_node_execution.status, error=workflow_node_execution.error, elapsed_time=workflow_node_execution.elapsed_time, - execution_metadata=workflow_node_execution.execution_metadata_dict, + execution_metadata=workflow_node_execution.metadata, created_at=int(workflow_node_execution.created_at.timestamp()), finished_at=int(workflow_node_execution.finished_at.timestamp()), - files=self._fetch_files_from_node_outputs(workflow_node_execution.outputs_dict or {}), + files=self._fetch_files_from_node_outputs(workflow_node_execution.outputs or {}), parallel_id=event.parallel_id, parallel_start_node_id=event.parallel_start_node_id, parent_parallel_id=event.parent_parallel_id, @@ -608,7 +630,7 @@ class WorkflowCycleManager: *, event: QueueNodeRetryEvent, task_id: str, - workflow_node_execution: WorkflowNodeExecution, + workflow_node_execution: NodeExecution, ) -> Optional[Union[NodeRetryStreamResponse, NodeFinishStreamResponse]]: if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}: return None @@ -627,16 +649,16 @@ class WorkflowCycleManager: index=workflow_node_execution.index, title=workflow_node_execution.title, predecessor_node_id=workflow_node_execution.predecessor_node_id, - inputs=workflow_node_execution.inputs_dict, - process_data=workflow_node_execution.process_data_dict, - outputs=workflow_node_execution.outputs_dict, + inputs=workflow_node_execution.inputs, + process_data=workflow_node_execution.process_data, + outputs=workflow_node_execution.outputs, status=workflow_node_execution.status, error=workflow_node_execution.error, elapsed_time=workflow_node_execution.elapsed_time, - execution_metadata=workflow_node_execution.execution_metadata_dict, + execution_metadata=workflow_node_execution.metadata, created_at=int(workflow_node_execution.created_at.timestamp()), finished_at=int(workflow_node_execution.finished_at.timestamp()), - files=self._fetch_files_from_node_outputs(workflow_node_execution.outputs_dict or {}), + files=self._fetch_files_from_node_outputs(workflow_node_execution.outputs or {}), parallel_id=event.parallel_id, parallel_start_node_id=event.parallel_start_node_id, parent_parallel_id=event.parent_parallel_id, @@ -908,23 +930,6 @@ class WorkflowCycleManager: return workflow_run - def _get_workflow_node_execution(self, node_execution_id: str) -> WorkflowNodeExecution: - # First check the cache for performance - if node_execution_id in self._workflow_node_executions: - cached_execution = self._workflow_node_executions[node_execution_id] - # No need to merge with session since expire_on_commit=False - return cached_execution - - # If not in cache, use the instance repository to get by node_execution_id - execution = self._workflow_node_execution_repository.get_by_node_execution_id(node_execution_id) - - if not execution: - raise ValueError(f"Workflow node execution not found: {node_execution_id}") - - # Update cache - self._workflow_node_executions[node_execution_id] = execution - return execution - def _handle_agent_log(self, task_id: str, event: QueueAgentLogEvent) -> AgentLogStreamResponse: """ Handle agent log diff --git a/api/models/model.py b/api/models/model.py index ab426649c5..efb1442010 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -1417,7 +1417,7 @@ class EndUser(Base, UserMixin): ) id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) - tenant_id = db.Column(StringUUID, nullable=False) + tenant_id: Mapped[str] = db.Column(StringUUID, nullable=False) app_id = db.Column(StringUUID, nullable=True) type = db.Column(db.String(255), nullable=False) external_user_id = db.Column(db.String(255), nullable=True)