diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index 0a2401f953..cb217dd76d 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -62,12 +62,13 @@ from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk from core.model_runtime.entities.llm_entities import LLMUsage from core.model_runtime.utils.encoders import jsonable_encoder from core.ops.ops_trace_manager import TraceQueueManager +from core.workflow.entities.workflow_execution_entities import WorkflowType from core.workflow.enums import SystemVariableKey from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState from core.workflow.nodes import NodeType from core.workflow.repository.workflow_execution_repository import WorkflowExecutionRepository from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository -from core.workflow.workflow_cycle_manager import WorkflowCycleManager +from core.workflow.workflow_cycle_manager import TempWorkflowEntity, WorkflowCycleManager from events.message_event import message_was_created from extensions.ext_database import db from models import Conversation, EndUser, Message, MessageFile @@ -128,6 +129,12 @@ class AdvancedChatAppGenerateTaskPipeline: SystemVariableKey.WORKFLOW_ID: workflow.id, SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id, }, + workflow_entity=TempWorkflowEntity( + id_=workflow.id, + type_=WorkflowType(workflow.type), + version=workflow.version, + graph=workflow.graph_dict, + ), workflow_execution_repository=workflow_execution_repository, workflow_node_execution_repository=workflow_node_execution_repository, ) diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py index b187f2ca96..f66876e95c 100644 --- a/api/core/app/apps/workflow/generate_task_pipeline.py +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -55,11 +55,11 @@ from core.app.entities.task_entities import ( from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk from core.ops.ops_trace_manager import TraceQueueManager -from core.workflow.entities.workflow_execution_entities import WorkflowExecution +from core.workflow.entities.workflow_execution_entities import WorkflowExecution, WorkflowType from core.workflow.enums import SystemVariableKey from core.workflow.repository.workflow_execution_repository import WorkflowExecutionRepository from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository -from core.workflow.workflow_cycle_manager import WorkflowCycleManager +from core.workflow.workflow_cycle_manager import TempWorkflowEntity, WorkflowCycleManager from extensions.ext_database import db from models.account import Account from models.enums import CreatorUserRole @@ -116,6 +116,12 @@ class WorkflowAppGenerateTaskPipeline: SystemVariableKey.WORKFLOW_ID: workflow.id, SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_execution_id, }, + workflow_entity=TempWorkflowEntity( + id_=workflow.id, + type_=WorkflowType(workflow.type), + version=workflow.version, + graph=workflow.graph_dict, + ), workflow_execution_repository=workflow_execution_repository, workflow_node_execution_repository=workflow_node_execution_repository, ) diff --git a/api/core/workflow/workflow_cycle_manager.py b/api/core/workflow/workflow_cycle_manager.py index c41770e2eb..bc85f0a215 100644 --- a/api/core/workflow/workflow_cycle_manager.py +++ b/api/core/workflow/workflow_cycle_manager.py @@ -1,9 +1,9 @@ from collections.abc import Mapping +from dataclasses import dataclass from datetime import UTC, datetime from typing import Any, Optional, Union from uuid import uuid4 -from sqlalchemy import select from sqlalchemy.orm import Session from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity @@ -30,17 +30,25 @@ from core.workflow.repository.workflow_execution_repository import WorkflowExecu from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository from core.workflow.workflow_entry import WorkflowEntry from models import ( - Workflow, WorkflowRunStatus, ) +@dataclass +class TempWorkflowEntity: + id_: str + type_: WorkflowType + version: str + graph: Mapping[str, Any] + + class WorkflowCycleManager: def __init__( self, *, application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity], workflow_system_variables: dict[SystemVariableKey, Any], + workflow_entity: TempWorkflowEntity, workflow_execution_repository: WorkflowExecutionRepository, workflow_node_execution_repository: WorkflowNodeExecutionRepository, ) -> None: @@ -48,6 +56,7 @@ class WorkflowCycleManager: self._workflow_system_variables = workflow_system_variables self._workflow_execution_repository = workflow_execution_repository self._workflow_node_execution_repository = workflow_node_execution_repository + self._temp_workflow_entity = workflow_entity def handle_workflow_run_start( self, @@ -55,11 +64,6 @@ class WorkflowCycleManager: session: Session, workflow_id: str, ) -> WorkflowExecution: - workflow_stmt = select(Workflow).where(Workflow.id == workflow_id) - workflow = session.scalar(workflow_stmt) - if not workflow: - raise ValueError(f"Workflow not found: {workflow_id}") - inputs = {**self._application_generate_entity.inputs} for key, value in (self._workflow_system_variables or {}).items(): if key.value == "conversation": @@ -74,10 +78,10 @@ class WorkflowCycleManager: execution_id = str(self._workflow_system_variables.get(SystemVariableKey.WORKFLOW_RUN_ID) or uuid4()) execution = WorkflowExecution.new( id=execution_id, - workflow_id=workflow.id, - type=WorkflowType(workflow.type), - workflow_version=workflow.version, - graph=workflow.graph_dict, + workflow_id=self._temp_workflow_entity.id_, + type=self._temp_workflow_entity.type_, + workflow_version=self._temp_workflow_entity.version, + graph=self._temp_workflow_entity.graph, inputs=inputs, started_at=datetime.now(UTC).replace(tzinfo=None), ) diff --git a/api/tests/unit_tests/core/workflow/test_workflow_cycle_manager.py b/api/tests/unit_tests/core/workflow/test_workflow_cycle_manager.py index 3e17fe0e4e..0d12406c49 100644 --- a/api/tests/unit_tests/core/workflow/test_workflow_cycle_manager.py +++ b/api/tests/unit_tests/core/workflow/test_workflow_cycle_manager.py @@ -19,7 +19,7 @@ from core.workflow.enums import SystemVariableKey from core.workflow.nodes import NodeType from core.workflow.repository.workflow_execution_repository import WorkflowExecutionRepository from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository -from core.workflow.workflow_cycle_manager import WorkflowCycleManager +from core.workflow.workflow_cycle_manager import TempWorkflowEntity, WorkflowCycleManager from models.enums import CreatorUserRole from models.model import AppMode from models.workflow import ( @@ -93,16 +93,38 @@ def mock_workflow_execution_repository(): return repo +@pytest.fixture +def real_workflow_entity(): + return TempWorkflowEntity( + id_="test-workflow-id", # Matches ID used in other fixtures + type_=WorkflowType.CHAT, + version="1.0.0", + graph={ + "nodes": [ + { + "id": "node1", + "type": "chat", # NodeType is a string enum + "name": "Chat Node", + "data": {"model": "gpt-3.5-turbo", "prompt": "test prompt"}, + } + ], + "edges": [], + }, + ) + + @pytest.fixture def workflow_cycle_manager( real_app_generate_entity, real_workflow_system_variables, mock_workflow_execution_repository, mock_node_execution_repository, + real_workflow_entity, ): return WorkflowCycleManager( application_generate_entity=real_app_generate_entity, workflow_system_variables=real_workflow_system_variables, + workflow_entity=real_workflow_entity, workflow_execution_repository=mock_workflow_execution_repository, workflow_node_execution_repository=mock_node_execution_repository, )