From a97bc9a6b9160764b2ec04eb44ac546cff451014 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Wed, 21 May 2025 21:06:16 +0800 Subject: [PATCH] test(test_workflow_cycle_manager): Use real_app_generate_entity, real_workflow_system_variables, real_workflow and real_workflow_run Signed-off-by: -LAN- --- .../workflow/test_workflow_cycle_manager.py | 266 +++++++++++++----- 1 file changed, 196 insertions(+), 70 deletions(-) diff --git a/api/tests/unit_tests/core/workflow/test_workflow_cycle_manager.py b/api/tests/unit_tests/core/workflow/test_workflow_cycle_manager.py index 0f95b8f02b..cc5be0dcdf 100644 --- a/api/tests/unit_tests/core/workflow/test_workflow_cycle_manager.py +++ b/api/tests/unit_tests/core/workflow/test_workflow_cycle_manager.py @@ -11,8 +11,9 @@ from core.app.entities.queue_entities import ( QueueNodeStartedEvent, QueueNodeSucceededEvent, ) +from core.workflow.entities.node_entities import NodeRunMetadataKey from core.workflow.entities.node_execution_entities import NodeExecutionStatus -from core.workflow.entities.workflow_execution_entities import WorkflowExecution, WorkflowExecutionStatus +from core.workflow.entities.workflow_execution_entities import WorkflowExecutionStatus from core.workflow.enums import SystemVariableKey from core.workflow.nodes import NodeType from core.workflow.repository.workflow_execution_repository import WorkflowExecutionRepository @@ -21,27 +22,53 @@ from core.workflow.workflow_cycle_manager import WorkflowCycleManager from models.enums import CreatorUserRole from models.workflow import ( Workflow, - WorkflowNodeExecutionStatus, WorkflowRun, WorkflowRunStatus, ) @pytest.fixture -def mock_app_generate_entity(): - entity = MagicMock(spec=AdvancedChatAppGenerateEntity) - entity.inputs = {"query": "test query"} - entity.invoke_from = InvokeFrom.WEB_APP - # Create app_config as a separate mock - app_config = MagicMock() - app_config.tenant_id = "test-tenant-id" - app_config.app_id = "test-app-id" - entity.app_config = app_config +def real_app_generate_entity(): + from core.app.app_config.entities import AppAdditionalFeatures, WorkflowUIBasedAppConfig + from models.model import AppMode + + additional_features = AppAdditionalFeatures( + file_upload=None, + opening_statement=None, + suggested_questions=[], + suggested_questions_after_answer=False, + show_retrieve_source=False, + more_like_this=False, + speech_to_text=False, + text_to_speech=None, + trace_config=None, + ) + + app_config = WorkflowUIBasedAppConfig( + tenant_id="test-tenant-id", + app_id="test-app-id", + app_mode=AppMode.WORKFLOW, + additional_features=additional_features, + workflow_id="test-workflow-id", + ) + + entity = AdvancedChatAppGenerateEntity( + task_id="test-task-id", + app_config=app_config, + inputs={"query": "test query"}, + files=[], + user_id="test-user-id", + stream=False, + invoke_from=InvokeFrom.WEB_APP, + query="test query", + conversation_id="test-conversation-id", + ) + return entity @pytest.fixture -def mock_workflow_system_variables(): +def real_workflow_system_variables(): return { SystemVariableKey.QUERY: "test query", SystemVariableKey.CONVERSATION_ID: "test-conversation-id", @@ -69,14 +96,14 @@ def mock_workflow_execution_repository(): @pytest.fixture def workflow_cycle_manager( - mock_app_generate_entity, - mock_workflow_system_variables, + real_app_generate_entity, + real_workflow_system_variables, mock_workflow_execution_repository, mock_node_execution_repository, ): return WorkflowCycleManager( - application_generate_entity=mock_app_generate_entity, - workflow_system_variables=mock_workflow_system_variables, + application_generate_entity=real_app_generate_entity, + workflow_system_variables=real_workflow_system_variables, workflow_execution_repository=mock_workflow_execution_repository, workflow_node_execution_repository=mock_node_execution_repository, ) @@ -89,52 +116,66 @@ def mock_session(): @pytest.fixture -def mock_workflow(): - workflow = MagicMock(spec=Workflow) +def real_workflow(): + workflow = Workflow() workflow.id = "test-workflow-id" workflow.tenant_id = "test-tenant-id" workflow.app_id = "test-app-id" workflow.type = "chat" workflow.version = "1.0" - workflow.graph = json.dumps({"nodes": [], "edges": []}) - workflow.graph_dict = {"nodes": [], "edges": []} + + graph_data = {"nodes": [], "edges": []} + workflow.graph = json.dumps(graph_data) + workflow.features = json.dumps({"file_upload": {"enabled": False}}) + workflow.created_by = "test-user-id" + workflow.created_at = datetime.now(UTC).replace(tzinfo=None) + workflow.updated_at = datetime.now(UTC).replace(tzinfo=None) + workflow._environment_variables = "{}" + workflow._conversation_variables = "{}" + return workflow @pytest.fixture -def mock_workflow_run(): - workflow_run = MagicMock(spec=WorkflowRun) +def real_workflow_run(): + workflow_run = WorkflowRun() workflow_run.id = "test-workflow-run-id" workflow_run.tenant_id = "test-tenant-id" workflow_run.app_id = "test-app-id" workflow_run.workflow_id = "test-workflow-id" + workflow_run.sequence_number = 1 + workflow_run.type = "chat" + workflow_run.triggered_from = "app-run" + workflow_run.version = "1.0" + workflow_run.graph = json.dumps({"nodes": [], "edges": []}) + workflow_run.inputs = json.dumps({"query": "test query"}) workflow_run.status = WorkflowRunStatus.RUNNING + workflow_run.outputs = json.dumps({"answer": "test answer"}) workflow_run.created_by_role = CreatorUserRole.ACCOUNT workflow_run.created_by = "test-user-id" workflow_run.created_at = datetime.now(UTC).replace(tzinfo=None) - workflow_run.inputs_dict = {"query": "test query"} - workflow_run.outputs_dict = {"answer": "test answer"} + return workflow_run def test_init( workflow_cycle_manager, - mock_app_generate_entity, - mock_workflow_system_variables, + real_app_generate_entity, + real_workflow_system_variables, mock_workflow_execution_repository, mock_node_execution_repository, ): """Test initialization of WorkflowCycleManager""" - assert workflow_cycle_manager._application_generate_entity == mock_app_generate_entity - assert workflow_cycle_manager._workflow_system_variables == mock_workflow_system_variables + assert workflow_cycle_manager._application_generate_entity == real_app_generate_entity + assert workflow_cycle_manager._workflow_system_variables == real_workflow_system_variables assert workflow_cycle_manager._workflow_execution_repository == mock_workflow_execution_repository assert workflow_cycle_manager._workflow_node_execution_repository == mock_node_execution_repository -def test_handle_workflow_run_start(workflow_cycle_manager, mock_session, mock_workflow): +def test_handle_workflow_run_start(workflow_cycle_manager, mock_session, real_workflow): """Test _handle_workflow_run_start method""" # Mock session.scalar to return the workflow and max sequence - mock_session.scalar.side_effect = [mock_workflow, 5] + mock_session.scalar.side_effect = [real_workflow, 5] # Call the method workflow_execution = workflow_cycle_manager._handle_workflow_run_start( @@ -143,7 +184,7 @@ def test_handle_workflow_run_start(workflow_cycle_manager, mock_session, mock_wo ) # Verify the result - assert workflow_execution.workflow_id == mock_workflow.id + assert workflow_execution.workflow_id == real_workflow.id assert workflow_execution.sequence_number == 6 # max_sequence + 1 # Verify the workflow_execution_repository.save was called @@ -152,11 +193,24 @@ def test_handle_workflow_run_start(workflow_cycle_manager, mock_session, mock_wo def test_handle_workflow_run_success(workflow_cycle_manager, mock_workflow_execution_repository): """Test _handle_workflow_run_success method""" - # Create a mock WorkflowExecution - mock_workflow_execution = MagicMock() + # Create a real WorkflowExecution + from datetime import UTC, datetime - # Mock _get_workflow_execution_or_raise_error to return the mock_workflow_execution - workflow_cycle_manager._workflow_execution_repository.get.return_value = mock_workflow_execution + from core.workflow.entities.workflow_execution_entities import WorkflowExecution, WorkflowType + + workflow_execution = WorkflowExecution( + id="test-workflow-run-id", + workflow_id="test-workflow-id", + workflow_version="1.0", + sequence_number=1, + type=WorkflowType.CHAT, + graph={"nodes": [], "edges": []}, + inputs={"query": "test query"}, + started_at=datetime.now(UTC).replace(tzinfo=None), + ) + + # Mock _get_workflow_execution_or_raise_error to return the real workflow_execution + workflow_cycle_manager._workflow_execution_repository.get.return_value = workflow_execution # Call the method result = workflow_cycle_manager._handle_workflow_run_success( @@ -167,7 +221,7 @@ def test_handle_workflow_run_success(workflow_cycle_manager, mock_workflow_execu ) # Verify the result - assert result == mock_workflow_execution + assert result == workflow_execution assert result.status == WorkflowExecutionStatus.SUCCEEDED assert result.outputs == {"answer": "test answer"} assert result.total_tokens == 100 @@ -177,11 +231,24 @@ def test_handle_workflow_run_success(workflow_cycle_manager, mock_workflow_execu def test_handle_workflow_run_failed(workflow_cycle_manager, mock_workflow_execution_repository): """Test _handle_workflow_run_failed method""" - # Create a mock WorkflowExecution - mock_workflow_execution = MagicMock() + # Create a real WorkflowExecution + from datetime import UTC, datetime - # Mock _get_workflow_execution_or_raise_error to return the mock_workflow_execution - workflow_cycle_manager._workflow_execution_repository.get.return_value = mock_workflow_execution + from core.workflow.entities.workflow_execution_entities import WorkflowExecution, WorkflowType + + workflow_execution = WorkflowExecution( + id="test-workflow-run-id", + workflow_id="test-workflow-id", + workflow_version="1.0", + sequence_number=1, + type=WorkflowType.CHAT, + graph={"nodes": [], "edges": []}, + inputs={"query": "test query"}, + started_at=datetime.now(UTC).replace(tzinfo=None), + ) + + # Mock _get_workflow_execution_or_raise_error to return the real workflow_execution + workflow_cycle_manager._workflow_execution_repository.get.return_value = workflow_execution # Mock get_running_executions to return an empty list workflow_cycle_manager._workflow_node_execution_repository.get_running_executions.return_value = [] @@ -196,7 +263,7 @@ def test_handle_workflow_run_failed(workflow_cycle_manager, mock_workflow_execut ) # Verify the result - assert result == mock_workflow_execution + assert result == workflow_execution assert result.status == WorkflowExecutionStatus(WorkflowRunStatus.FAILED.value) assert result.error_message == "Test error message" assert result.total_tokens == 50 @@ -206,13 +273,24 @@ def test_handle_workflow_run_failed(workflow_cycle_manager, mock_workflow_execut def test_handle_node_execution_start(workflow_cycle_manager, mock_workflow_execution_repository): """Test _handle_node_execution_start method""" - # Create a mock WorkflowExecution - mock_workflow_execution = MagicMock() - mock_workflow_execution.id = "test-workflow-execution-id" - mock_workflow_execution.workflow_id = "test-workflow-id" + # Create a real WorkflowExecution + from datetime import UTC, datetime + + from core.workflow.entities.workflow_execution_entities import WorkflowExecution, WorkflowType - # Mock _get_workflow_execution_or_raise_error to return the mock_workflow_execution - workflow_cycle_manager._workflow_execution_repository.get.return_value = mock_workflow_execution + workflow_execution = WorkflowExecution( + id="test-workflow-execution-id", + workflow_id="test-workflow-id", + workflow_version="1.0", + sequence_number=1, + type=WorkflowType.CHAT, + graph={"nodes": [], "edges": []}, + inputs={"query": "test query"}, + started_at=datetime.now(UTC).replace(tzinfo=None), + ) + + # Mock _get_workflow_execution_or_raise_error to return the real workflow_execution + workflow_cycle_manager._workflow_execution_repository.get.return_value = workflow_execution # Create a mock event event = MagicMock(spec=QueueNodeStartedEvent) @@ -233,13 +311,13 @@ def test_handle_node_execution_start(workflow_cycle_manager, mock_workflow_execu # Call the method result = workflow_cycle_manager._handle_node_execution_start( - workflow_execution_id=mock_workflow_execution.id, + workflow_execution_id=workflow_execution.id, event=event, ) # Verify the result - assert result.workflow_id == mock_workflow_execution.workflow_id - assert result.workflow_run_id == mock_workflow_execution.id + assert result.workflow_id == workflow_execution.workflow_id + assert result.workflow_run_id == workflow_execution.id assert result.node_execution_id == event.node_execution_id assert result.node_id == event.node_id assert result.node_type == event.node_type @@ -252,17 +330,30 @@ def test_handle_node_execution_start(workflow_cycle_manager, mock_workflow_execu def test_get_workflow_execution_or_raise_error(workflow_cycle_manager, mock_workflow_execution_repository): """Test _get_workflow_execution_or_raise_error method""" - # Create a mock WorkflowExecution - mock_workflow_execution = MagicMock(spec=WorkflowExecution) + # Create a real WorkflowExecution + from datetime import UTC, datetime + + from core.workflow.entities.workflow_execution_entities import WorkflowExecution, WorkflowType + + workflow_execution = WorkflowExecution( + id="test-workflow-run-id", + workflow_id="test-workflow-id", + workflow_version="1.0", + sequence_number=1, + type=WorkflowType.CHAT, + graph={"nodes": [], "edges": []}, + inputs={"query": "test query"}, + started_at=datetime.now(UTC).replace(tzinfo=None), + ) - # Mock the repository get method to return the mock execution - workflow_cycle_manager._workflow_execution_repository.get.return_value = mock_workflow_execution + # Mock the repository get method to return the real execution + workflow_cycle_manager._workflow_execution_repository.get.return_value = workflow_execution # Call the method result = workflow_cycle_manager._get_workflow_execution_or_raise_error("test-workflow-run-id") # Verify the result - assert result == mock_workflow_execution + assert result == workflow_execution # Test error case workflow_cycle_manager._workflow_execution_repository.get.return_value = None @@ -280,12 +371,23 @@ def test_handle_workflow_node_execution_success(workflow_cycle_manager): event.inputs = {"input": "test input"} event.process_data = {"process": "test process"} event.outputs = {"output": "test output"} - event.execution_metadata = {"metadata": "test metadata"} + event.execution_metadata = {NodeRunMetadataKey.TOTAL_TOKENS: 100} event.start_at = datetime.now(UTC).replace(tzinfo=None) - # Create a mock node execution - node_execution = MagicMock() - node_execution.node_execution_id = "test-node-execution-id" + # Create a real node execution + from core.workflow.entities.node_execution_entities import NodeExecution, NodeExecutionStatus + + node_execution = NodeExecution( + id="test-node-execution-record-id", + node_execution_id="test-node-execution-id", + workflow_id="test-workflow-id", + workflow_run_id="test-workflow-run-id", + index=1, + node_id="test-node-id", + node_type=NodeType.LLM, + title="Test Node", + created_at=datetime.now(UTC).replace(tzinfo=None), + ) # Mock the repository to return the node execution workflow_cycle_manager._workflow_node_execution_repository.get_by_node_execution_id.return_value = node_execution @@ -297,7 +399,7 @@ def test_handle_workflow_node_execution_success(workflow_cycle_manager): # Verify the result assert result == node_execution - assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED.value + assert result.status == NodeExecutionStatus.SUCCEEDED # Verify save was called workflow_cycle_manager._workflow_node_execution_repository.save.assert_called_once_with(node_execution) @@ -305,11 +407,24 @@ def test_handle_workflow_node_execution_success(workflow_cycle_manager): def test_handle_workflow_run_partial_success(workflow_cycle_manager, mock_workflow_execution_repository): """Test _handle_workflow_run_partial_success method""" - # Create a mock WorkflowExecution - mock_workflow_execution = MagicMock(spec=WorkflowExecution) + # Create a real WorkflowExecution + from datetime import UTC, datetime + + from core.workflow.entities.workflow_execution_entities import WorkflowExecution, WorkflowType + + workflow_execution = WorkflowExecution( + id="test-workflow-run-id", + workflow_id="test-workflow-id", + workflow_version="1.0", + sequence_number=1, + type=WorkflowType.CHAT, + graph={"nodes": [], "edges": []}, + inputs={"query": "test query"}, + started_at=datetime.now(UTC).replace(tzinfo=None), + ) - # Mock _get_workflow_execution_or_raise_error to return the mock_workflow_execution - workflow_cycle_manager._workflow_execution_repository.get.return_value = mock_workflow_execution + # Mock _get_workflow_execution_or_raise_error to return the real workflow_execution + workflow_cycle_manager._workflow_execution_repository.get.return_value = workflow_execution # Call the method result = workflow_cycle_manager._handle_workflow_run_partial_success( @@ -321,7 +436,7 @@ def test_handle_workflow_run_partial_success(workflow_cycle_manager, mock_workfl ) # Verify the result - assert result == mock_workflow_execution + assert result == workflow_execution assert result.status == WorkflowExecutionStatus.PARTIAL_SUCCEEDED assert result.outputs == {"partial_answer": "test partial answer"} assert result.total_tokens == 75 @@ -338,13 +453,24 @@ def test_handle_workflow_node_execution_failed(workflow_cycle_manager): event.inputs = {"input": "test input"} event.process_data = {"process": "test process"} event.outputs = {"output": "test output"} - event.execution_metadata = {"metadata": "test metadata"} + event.execution_metadata = {NodeRunMetadataKey.TOTAL_TOKENS: 100} event.start_at = datetime.now(UTC).replace(tzinfo=None) event.error = "Test error message" - # Create a mock node execution - node_execution = MagicMock() - node_execution.node_execution_id = "test-node-execution-id" + # Create a real node execution + from core.workflow.entities.node_execution_entities import NodeExecution, NodeExecutionStatus + + node_execution = NodeExecution( + id="test-node-execution-record-id", + node_execution_id="test-node-execution-id", + workflow_id="test-workflow-id", + workflow_run_id="test-workflow-run-id", + index=1, + node_id="test-node-id", + node_type=NodeType.LLM, + title="Test Node", + created_at=datetime.now(UTC).replace(tzinfo=None), + ) # Mock the repository to return the node execution workflow_cycle_manager._workflow_node_execution_repository.get_by_node_execution_id.return_value = node_execution @@ -356,7 +482,7 @@ def test_handle_workflow_node_execution_failed(workflow_cycle_manager): # Verify the result assert result == node_execution - assert result.status == WorkflowNodeExecutionStatus.FAILED.value + assert result.status == NodeExecutionStatus.FAILED assert result.error == "Test error message" # Verify save was called