From 5d3cea5b95a50d27f7d7782d77dd3a04421f46b4 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Wed, 21 May 2025 18:20:10 +0800 Subject: [PATCH] refactor(workflow_cycle_manager): Refactors `_handle_workflow_run_start` to use WorkflowExecution Signed-off-by: -LAN- --- .../advanced_chat/generate_task_pipeline.py | 12 ++-- api/core/app/entities/task_entities.py | 2 +- .../workflow_app_generate_task_pipeline.py | 10 ++- api/core/workflow/workflow_cycle_manager.py | 64 +++++++------------ 4 files changed, 34 insertions(+), 54 deletions(-) diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index 6bf69cee6e..a62d0f8dcf 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -297,21 +297,19 @@ class AdvancedChatAppGenerateTaskPipeline: with Session(db.engine, expire_on_commit=False) as session: # init workflow run - workflow_run = self._workflow_cycle_manager._handle_workflow_run_start( + workflow_execution = self._workflow_cycle_manager._handle_workflow_run_start( session=session, workflow_id=self._workflow_id, - user_id=self._user_id, - created_by_role=self._created_by_role, ) - self._workflow_run_id = workflow_run.id + self._workflow_run_id = workflow_execution.id message = self._get_message(session=session) if not message: raise ValueError(f"Message not found: {self._message_id}") - message.workflow_run_id = workflow_run.id + message.workflow_run_id = workflow_execution.id workflow_start_resp = self._workflow_cycle_manager._workflow_start_to_stream_response( - session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run + task_id=self._application_generate_entity.task_id, + workflow_execution=workflow_execution, ) - session.commit() yield workflow_start_resp elif isinstance( diff --git a/api/core/app/entities/task_entities.py b/api/core/app/entities/task_entities.py index 0c2d617f80..bf0ce9fbf0 100644 --- a/api/core/app/entities/task_entities.py +++ b/api/core/app/entities/task_entities.py @@ -190,7 +190,7 @@ class WorkflowStartStreamResponse(StreamResponse): id: str workflow_id: str sequence_number: int - inputs: dict + inputs: Mapping[str, Any] created_at: int event: StreamEvent = StreamEvent.WORKFLOW_STARTED diff --git a/api/core/workflow/workflow_app_generate_task_pipeline.py b/api/core/workflow/workflow_app_generate_task_pipeline.py index a70c635963..f7350236a2 100644 --- a/api/core/workflow/workflow_app_generate_task_pipeline.py +++ b/api/core/workflow/workflow_app_generate_task_pipeline.py @@ -261,17 +261,15 @@ class WorkflowAppGenerateTaskPipeline: with Session(db.engine, expire_on_commit=False) as session: # init workflow run - workflow_run = self._workflow_cycle_manager._handle_workflow_run_start( + workflow_execution = self._workflow_cycle_manager._handle_workflow_run_start( session=session, workflow_id=self._workflow_id, - user_id=self._user_id, - created_by_role=self._created_by_role, ) - self._workflow_run_id = workflow_run.id + self._workflow_run_id = workflow_execution.id start_resp = self._workflow_cycle_manager._workflow_start_to_stream_response( - session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run + task_id=self._application_generate_entity.task_id, + workflow_execution=workflow_execution, ) - session.commit() yield start_resp elif isinstance( diff --git a/api/core/workflow/workflow_cycle_manager.py b/api/core/workflow/workflow_cycle_manager.py index b6fc2fc5c2..349163f8ad 100644 --- a/api/core/workflow/workflow_cycle_manager.py +++ b/api/core/workflow/workflow_cycle_manager.py @@ -8,7 +8,7 @@ from uuid import uuid4 from sqlalchemy import func, select from sqlalchemy.orm import Session -from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity +from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity from core.app.entities.queue_entities import ( QueueAgentLogEvent, QueueIterationCompletedEvent, @@ -54,6 +54,7 @@ from core.workflow.entities.node_execution_entities import ( NodeExecution, NodeExecutionStatus, ) +from core.workflow.entities.workflow_execution_entities import WorkflowExecution, WorkflowType from core.workflow.enums import SystemVariableKey from core.workflow.nodes import NodeType from core.workflow.nodes.tool.entities import ToolNodeData @@ -68,7 +69,6 @@ from models import ( WorkflowNodeExecutionStatus, WorkflowRun, WorkflowRunStatus, - WorkflowRunTriggeredFrom, ) @@ -92,9 +92,7 @@ class WorkflowCycleManager: *, session: Session, workflow_id: str, - user_id: str, - created_by_role: CreatorUserRole, - ) -> WorkflowRun: + ) -> WorkflowExecution: workflow_stmt = select(Workflow).where(Workflow.id == workflow_id) workflow = session.scalar(workflow_stmt) if not workflow: @@ -113,38 +111,26 @@ class WorkflowCycleManager: continue inputs[f"sys.{key.value}"] = value - triggered_from = ( - WorkflowRunTriggeredFrom.DEBUGGING - if self._application_generate_entity.invoke_from == InvokeFrom.DEBUGGER - else WorkflowRunTriggeredFrom.APP_RUN - ) - # handle special values inputs = dict(WorkflowEntry.handle_special_values(inputs) or {}) # init workflow run # TODO: This workflow_run_id should always not be None, maybe we can use a more elegant way to handle this - workflow_run_id = str(self._workflow_system_variables.get(SystemVariableKey.WORKFLOW_RUN_ID) or uuid4()) - - workflow_run = WorkflowRun() - workflow_run.id = workflow_run_id - workflow_run.tenant_id = workflow.tenant_id - workflow_run.app_id = workflow.app_id - workflow_run.sequence_number = new_sequence_number - workflow_run.workflow_id = workflow.id - workflow_run.type = workflow.type - workflow_run.triggered_from = triggered_from.value - workflow_run.version = workflow.version - workflow_run.graph = workflow.graph - workflow_run.inputs = json.dumps(inputs) - workflow_run.status = WorkflowRunStatus.RUNNING - workflow_run.created_by_role = created_by_role - workflow_run.created_by = user_id - workflow_run.created_at = datetime.now(UTC).replace(tzinfo=None) - - session.add(workflow_run) + execution_id = str(self._workflow_system_variables.get(SystemVariableKey.WORKFLOW_RUN_ID) or uuid4()) + execution = WorkflowExecution.new( + id=execution_id, + workflow_id=workflow.id, + sequence_number=new_sequence_number, + type=WorkflowType(workflow.type), + workflow_version=workflow.version, + graph=workflow.graph_dict, + inputs=inputs, + started_at=datetime.now(UTC).replace(tzinfo=None), + ) - return workflow_run + self._workflow_execution_repository.save(execution) + + return execution def _handle_workflow_run_success( self, @@ -462,20 +448,18 @@ class WorkflowCycleManager: def _workflow_start_to_stream_response( self, *, - session: Session, task_id: str, - workflow_run: WorkflowRun, + workflow_execution: WorkflowExecution, ) -> WorkflowStartStreamResponse: - _ = session return WorkflowStartStreamResponse( task_id=task_id, - workflow_run_id=workflow_run.id, + workflow_run_id=workflow_execution.id, data=WorkflowStartStreamResponse.Data( - id=workflow_run.id, - workflow_id=workflow_run.workflow_id, - sequence_number=workflow_run.sequence_number, - inputs=dict(workflow_run.inputs_dict or {}), - created_at=int(workflow_run.created_at.timestamp()), + id=workflow_execution.id, + workflow_id=workflow_execution.workflow_id, + sequence_number=workflow_execution.sequence_number, + inputs=workflow_execution.inputs, + created_at=int(workflow_execution.started_at.timestamp()), ), )