From 198373219f942c42ca39a05148876b7da14a7ba8 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Wed, 21 May 2025 20:44:30 +0800 Subject: [PATCH] test(test_workflow_cycle_manager): Refactors workflow execution handling in tests Refactors workflow execution handling in tests Signed-off-by: -LAN- --- .../workflow/test_workflow_cycle_manager.py | 243 ++++++++++-------- 1 file changed, 135 insertions(+), 108 deletions(-) diff --git a/api/tests/unit_tests/core/workflow/test_workflow_cycle_manager.py b/api/tests/unit_tests/core/workflow/test_workflow_cycle_manager.py index 94b9d3e2c6..0f95b8f02b 100644 --- a/api/tests/unit_tests/core/workflow/test_workflow_cycle_manager.py +++ b/api/tests/unit_tests/core/workflow/test_workflow_cycle_manager.py @@ -1,7 +1,6 @@ import json -import time from datetime import UTC, datetime -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock import pytest from sqlalchemy.orm import Session @@ -12,8 +11,11 @@ from core.app.entities.queue_entities import ( QueueNodeStartedEvent, QueueNodeSucceededEvent, ) +from core.workflow.entities.node_execution_entities import NodeExecutionStatus +from core.workflow.entities.workflow_execution_entities import WorkflowExecution, WorkflowExecutionStatus from core.workflow.enums import SystemVariableKey from core.workflow.nodes import NodeType +from core.workflow.repository.workflow_execution_repository import WorkflowExecutionRepository from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository from core.workflow.workflow_cycle_manager import WorkflowCycleManager from models.enums import CreatorUserRole @@ -59,10 +61,23 @@ def mock_node_execution_repository(): @pytest.fixture -def workflow_cycle_manager(mock_app_generate_entity, mock_workflow_system_variables, mock_node_execution_repository): +def mock_workflow_execution_repository(): + repo = MagicMock(spec=WorkflowExecutionRepository) + repo.get.return_value = None + return repo + + +@pytest.fixture +def workflow_cycle_manager( + mock_app_generate_entity, + mock_workflow_system_variables, + mock_workflow_execution_repository, + mock_node_execution_repository, +): return WorkflowCycleManager( application_generate_entity=mock_app_generate_entity, workflow_system_variables=mock_workflow_system_variables, + workflow_execution_repository=mock_workflow_execution_repository, workflow_node_execution_repository=mock_node_execution_repository, ) @@ -82,6 +97,7 @@ def mock_workflow(): workflow.type = "chat" workflow.version = "1.0" workflow.graph = json.dumps({"nodes": [], "edges": []}) + workflow.graph_dict = {"nodes": [], "edges": []} return workflow @@ -102,12 +118,16 @@ def mock_workflow_run(): def test_init( - workflow_cycle_manager, mock_app_generate_entity, mock_workflow_system_variables, mock_node_execution_repository + workflow_cycle_manager, + mock_app_generate_entity, + mock_workflow_system_variables, + mock_workflow_execution_repository, + mock_node_execution_repository, ): """Test initialization of WorkflowCycleManager""" - assert workflow_cycle_manager._workflow_run is None assert workflow_cycle_manager._application_generate_entity == mock_app_generate_entity assert workflow_cycle_manager._workflow_system_variables == mock_workflow_system_variables + assert workflow_cycle_manager._workflow_execution_repository == mock_workflow_execution_repository assert workflow_cycle_manager._workflow_node_execution_repository == mock_node_execution_repository @@ -117,78 +137,83 @@ def test_handle_workflow_run_start(workflow_cycle_manager, mock_session, mock_wo mock_session.scalar.side_effect = [mock_workflow, 5] # Call the method - workflow_run = workflow_cycle_manager._handle_workflow_run_start( + workflow_execution = workflow_cycle_manager._handle_workflow_run_start( session=mock_session, workflow_id="test-workflow-id", - user_id="test-user-id", - created_by_role=CreatorUserRole.ACCOUNT, ) # Verify the result - assert workflow_run.tenant_id == mock_workflow.tenant_id - assert workflow_run.app_id == mock_workflow.app_id - assert workflow_run.workflow_id == mock_workflow.id - assert workflow_run.sequence_number == 6 # max_sequence + 1 - assert workflow_run.status == WorkflowRunStatus.RUNNING - assert workflow_run.created_by_role == CreatorUserRole.ACCOUNT - assert workflow_run.created_by == "test-user-id" + assert workflow_execution.workflow_id == mock_workflow.id + assert workflow_execution.sequence_number == 6 # max_sequence + 1 - # Verify session.add was called - mock_session.add.assert_called_once_with(workflow_run) + # Verify the workflow_execution_repository.save was called + workflow_cycle_manager._workflow_execution_repository.save.assert_called_once_with(workflow_execution) -def test_handle_workflow_run_success(workflow_cycle_manager, mock_session, mock_workflow_run): +def test_handle_workflow_run_success(workflow_cycle_manager, mock_workflow_execution_repository): """Test _handle_workflow_run_success method""" - # Mock _get_workflow_run to return the mock_workflow_run - with patch.object(workflow_cycle_manager, "_get_workflow_run", return_value=mock_workflow_run): - # Call the method - result = workflow_cycle_manager._handle_workflow_run_success( - session=mock_session, - workflow_run_id="test-workflow-run-id", - start_at=time.perf_counter() - 10, # 10 seconds ago - total_tokens=100, - total_steps=5, - outputs={"answer": "test answer"}, - ) - - # Verify the result - assert result == mock_workflow_run - assert result.status == WorkflowRunStatus.SUCCEEDED - assert result.outputs == json.dumps({"answer": "test answer"}) - assert result.total_tokens == 100 - assert result.total_steps == 5 - assert result.finished_at is not None - - -def test_handle_workflow_run_failed(workflow_cycle_manager, mock_session, mock_workflow_run): + # Create a mock WorkflowExecution + mock_workflow_execution = MagicMock() + + # Mock _get_workflow_execution_or_raise_error to return the mock_workflow_execution + workflow_cycle_manager._workflow_execution_repository.get.return_value = mock_workflow_execution + + # Call the method + result = workflow_cycle_manager._handle_workflow_run_success( + workflow_run_id="test-workflow-run-id", + total_tokens=100, + total_steps=5, + outputs={"answer": "test answer"}, + ) + + # Verify the result + assert result == mock_workflow_execution + assert result.status == WorkflowExecutionStatus.SUCCEEDED + assert result.outputs == {"answer": "test answer"} + assert result.total_tokens == 100 + assert result.total_steps == 5 + assert result.finished_at is not None + + +def test_handle_workflow_run_failed(workflow_cycle_manager, mock_workflow_execution_repository): """Test _handle_workflow_run_failed method""" - # Mock _get_workflow_run to return the mock_workflow_run - with patch.object(workflow_cycle_manager, "_get_workflow_run", return_value=mock_workflow_run): - # Mock get_running_executions to return an empty list - workflow_cycle_manager._workflow_node_execution_repository.get_running_executions.return_value = [] - - # Call the method - result = workflow_cycle_manager._handle_workflow_run_failed( - session=mock_session, - workflow_run_id="test-workflow-run-id", - start_at=time.perf_counter() - 10, # 10 seconds ago - total_tokens=50, - total_steps=3, - status=WorkflowRunStatus.FAILED, - error="Test error message", - ) - - # Verify the result - assert result == mock_workflow_run - assert result.status == WorkflowRunStatus.FAILED.value - assert result.error == "Test error message" - assert result.total_tokens == 50 - assert result.total_steps == 3 - assert result.finished_at is not None - - -def test_handle_node_execution_start(workflow_cycle_manager, mock_workflow_run): + # Create a mock WorkflowExecution + mock_workflow_execution = MagicMock() + + # Mock _get_workflow_execution_or_raise_error to return the mock_workflow_execution + workflow_cycle_manager._workflow_execution_repository.get.return_value = mock_workflow_execution + + # Mock get_running_executions to return an empty list + workflow_cycle_manager._workflow_node_execution_repository.get_running_executions.return_value = [] + + # Call the method + result = workflow_cycle_manager._handle_workflow_run_failed( + workflow_run_id="test-workflow-run-id", + total_tokens=50, + total_steps=3, + status=WorkflowRunStatus.FAILED, + error_message="Test error message", + ) + + # Verify the result + assert result == mock_workflow_execution + assert result.status == WorkflowExecutionStatus(WorkflowRunStatus.FAILED.value) + assert result.error_message == "Test error message" + assert result.total_tokens == 50 + assert result.total_steps == 3 + assert result.finished_at is not None + + +def test_handle_node_execution_start(workflow_cycle_manager, mock_workflow_execution_repository): """Test _handle_node_execution_start method""" + # Create a mock WorkflowExecution + mock_workflow_execution = MagicMock() + mock_workflow_execution.id = "test-workflow-execution-id" + mock_workflow_execution.workflow_id = "test-workflow-id" + + # Mock _get_workflow_execution_or_raise_error to return the mock_workflow_execution + workflow_cycle_manager._workflow_execution_repository.get.return_value = mock_workflow_execution + # Create a mock event event = MagicMock(spec=QueueNodeStartedEvent) event.node_execution_id = "test-node-execution-id" @@ -208,43 +233,43 @@ def test_handle_node_execution_start(workflow_cycle_manager, mock_workflow_run): # Call the method result = workflow_cycle_manager._handle_node_execution_start( - workflow_run=mock_workflow_run, + workflow_execution_id=mock_workflow_execution.id, event=event, ) # Verify the result - # NodeExecution doesn't have tenant_id attribute, it's handled at repository level - # assert result.tenant_id == mock_workflow_run.tenant_id - # assert result.app_id == mock_workflow_run.app_id - assert result.workflow_id == mock_workflow_run.workflow_id - assert result.workflow_run_id == mock_workflow_run.id + assert result.workflow_id == mock_workflow_execution.workflow_id + assert result.workflow_run_id == mock_workflow_execution.id assert result.node_execution_id == event.node_execution_id assert result.node_id == event.node_id assert result.node_type == event.node_type assert result.title == event.node_data.title - assert result.status == WorkflowNodeExecutionStatus.RUNNING.value - # NodeExecution doesn't have created_by_role and created_by attributes, they're handled at repository level - # assert result.created_by_role == mock_workflow_run.created_by_role - # assert result.created_by == mock_workflow_run.created_by + assert result.status == NodeExecutionStatus.RUNNING # Verify save was called workflow_cycle_manager._workflow_node_execution_repository.save.assert_called_once_with(result) -def test_get_workflow_run(workflow_cycle_manager, mock_session, mock_workflow_run): - """Test _get_workflow_run method""" - # Mock session.scalar to return the workflow run - mock_session.scalar.return_value = mock_workflow_run +def test_get_workflow_execution_or_raise_error(workflow_cycle_manager, mock_workflow_execution_repository): + """Test _get_workflow_execution_or_raise_error method""" + # Create a mock WorkflowExecution + mock_workflow_execution = MagicMock(spec=WorkflowExecution) + + # Mock the repository get method to return the mock execution + workflow_cycle_manager._workflow_execution_repository.get.return_value = mock_workflow_execution # Call the method - result = workflow_cycle_manager._get_workflow_run( - session=mock_session, - workflow_run_id="test-workflow-run-id", - ) + result = workflow_cycle_manager._get_workflow_execution_or_raise_error("test-workflow-run-id") # Verify the result - assert result == mock_workflow_run - assert workflow_cycle_manager._workflow_run == mock_workflow_run + assert result == mock_workflow_execution + + # Test error case + workflow_cycle_manager._workflow_execution_repository.get.return_value = None + + # Expect an error when execution is not found + with pytest.raises(ValueError): + workflow_cycle_manager._get_workflow_execution_or_raise_error("non-existent-id") def test_handle_workflow_node_execution_success(workflow_cycle_manager): @@ -278,29 +303,31 @@ def test_handle_workflow_node_execution_success(workflow_cycle_manager): workflow_cycle_manager._workflow_node_execution_repository.save.assert_called_once_with(node_execution) -def test_handle_workflow_run_partial_success(workflow_cycle_manager, mock_session, mock_workflow_run): +def test_handle_workflow_run_partial_success(workflow_cycle_manager, mock_workflow_execution_repository): """Test _handle_workflow_run_partial_success method""" - # Mock _get_workflow_run to return the mock_workflow_run - with patch.object(workflow_cycle_manager, "_get_workflow_run", return_value=mock_workflow_run): - # Call the method - result = workflow_cycle_manager._handle_workflow_run_partial_success( - session=mock_session, - workflow_run_id="test-workflow-run-id", - start_at=time.perf_counter() - 10, # 10 seconds ago - total_tokens=75, - total_steps=4, - outputs={"partial_answer": "test partial answer"}, - exceptions_count=2, - ) - - # Verify the result - assert result == mock_workflow_run - assert result.status == WorkflowRunStatus.PARTIAL_SUCCEEDED.value - assert result.outputs == json.dumps({"partial_answer": "test partial answer"}) - assert result.total_tokens == 75 - assert result.total_steps == 4 - assert result.exceptions_count == 2 - assert result.finished_at is not None + # Create a mock WorkflowExecution + mock_workflow_execution = MagicMock(spec=WorkflowExecution) + + # Mock _get_workflow_execution_or_raise_error to return the mock_workflow_execution + workflow_cycle_manager._workflow_execution_repository.get.return_value = mock_workflow_execution + + # Call the method + result = workflow_cycle_manager._handle_workflow_run_partial_success( + workflow_run_id="test-workflow-run-id", + total_tokens=75, + total_steps=4, + outputs={"partial_answer": "test partial answer"}, + exceptions_count=2, + ) + + # Verify the result + assert result == mock_workflow_execution + assert result.status == WorkflowExecutionStatus.PARTIAL_SUCCEEDED + assert result.outputs == {"partial_answer": "test partial answer"} + assert result.total_tokens == 75 + assert result.total_steps == 4 + assert result.exceptions_count == 2 + assert result.finished_at is not None def test_handle_workflow_node_execution_failed(workflow_cycle_manager):