From 6a1110511f9d5f7da307b1db3a962753f42141bf Mon Sep 17 00:00:00 2001 From: -LAN- Date: Fri, 27 Jun 2025 13:38:04 +0800 Subject: [PATCH] test(workflow_service): Fix test Signed-off-by: -LAN- --- .../workflow_draft_variable_service.py | 9 +-- api/services/workflow_service.py | 5 +- .../workflow/test_workflow_deletion.py | 3 +- .../test_workflow_draft_variable_service.py | 57 ++++++++++++------- ...kflow_node_execution_service_repository.py | 48 +++++++++------- .../workflow/test_workflow_service.py | 3 +- 6 files changed, 76 insertions(+), 49 deletions(-) diff --git a/api/services/workflow_draft_variable_service.py b/api/services/workflow_draft_variable_service.py index 0cb8c5574b..f2c52be9cf 100644 --- a/api/services/workflow_draft_variable_service.py +++ b/api/services/workflow_draft_variable_service.py @@ -118,10 +118,11 @@ class DraftVarLoader(VariableLoader): class WorkflowDraftVariableService: _session: Session - def __init__(self, session: Session) -> None: + def __init__(self, session: Session, session_maker: sessionmaker | None = None) -> None: self._session = session - session_maker = sessionmaker(bind=db.engine, expire_on_commit=False) - self._node_execution_service_repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository( + if session_maker is None: + session_maker = sessionmaker(bind=db.engine, expire_on_commit=False) + self._api_node_execution_repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository( session_maker ) @@ -254,7 +255,7 @@ class WorkflowDraftVariableService: _logger.warning("draft variable has no node_execution_id, id=%s, name=%s", variable.id, variable.name) return None - node_exec = self._node_execution_service_repo.get_execution_by_id(variable.node_execution_id) + node_exec = self._api_node_execution_repo.get_execution_by_id(variable.node_execution_id) if node_exec is None: _logger.warning( "Node exectution not found for draft variable, id=%s, name=%s, node_execution_id=%s", diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 8122505592..0149d50346 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -58,9 +58,10 @@ class WorkflowService: Workflow Service """ - def __init__(self): + def __init__(self, session_maker: sessionmaker | None = None): """Initialize WorkflowService with repository dependencies.""" - session_maker = sessionmaker(bind=db.engine, expire_on_commit=False) + if session_maker is None: + session_maker = sessionmaker(bind=db.engine, expire_on_commit=False) self._node_execution_service_repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository( session_maker ) diff --git a/api/tests/unit_tests/services/workflow/test_workflow_deletion.py b/api/tests/unit_tests/services/workflow/test_workflow_deletion.py index 223020c2c5..2c87eaf805 100644 --- a/api/tests/unit_tests/services/workflow/test_workflow_deletion.py +++ b/api/tests/unit_tests/services/workflow/test_workflow_deletion.py @@ -10,7 +10,8 @@ from services.workflow_service import DraftWorkflowDeletionError, WorkflowInUseE @pytest.fixture def workflow_setup(): - workflow_service = WorkflowService() + mock_session_maker = MagicMock() + workflow_service = WorkflowService(mock_session_maker) session = MagicMock(spec=Session) tenant_id = "test-tenant-id" workflow_id = "test-workflow-id" diff --git a/api/tests/unit_tests/services/workflow/test_workflow_draft_variable_service.py b/api/tests/unit_tests/services/workflow/test_workflow_draft_variable_service.py index c5c9cf1050..929e2cd6b8 100644 --- a/api/tests/unit_tests/services/workflow/test_workflow_draft_variable_service.py +++ b/api/tests/unit_tests/services/workflow/test_workflow_draft_variable_service.py @@ -1,14 +1,14 @@ import dataclasses import secrets from unittest import mock -from unittest.mock import Mock, patch +from unittest.mock import MagicMock, Mock, patch import pytest from sqlalchemy.orm import Session from core.variables import StringSegment from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID -from core.workflow.nodes import NodeType +from core.workflow.nodes.enums import NodeType from models.enums import DraftVariableType from models.workflow import Workflow, WorkflowDraftVariable, WorkflowNodeExecutionModel, is_system_variable_editable from services.workflow_draft_variable_service import ( @@ -108,7 +108,8 @@ class TestWorkflowDraftVariableService: def test_reset_conversation_variable(self): """Test resetting a conversation variable""" mock_session = Mock(spec=Session) - service = WorkflowDraftVariableService(mock_session) + mock_session_maker = Mock() + service = WorkflowDraftVariableService(mock_session, mock_session_maker) test_app_id = self._get_test_app_id() workflow = self._create_test_workflow(test_app_id) @@ -134,7 +135,8 @@ class TestWorkflowDraftVariableService: def test_reset_node_variable_with_no_execution_id(self): """Test resetting a node variable with no execution ID - should delete variable""" mock_session = Mock(spec=Session) - service = WorkflowDraftVariableService(mock_session) + mock_session_maker = Mock() + service = WorkflowDraftVariableService(mock_session, mock_session_maker) test_app_id = self._get_test_app_id() workflow = self._create_test_workflow(test_app_id) @@ -161,7 +163,16 @@ class TestWorkflowDraftVariableService: def test_reset_node_variable_with_missing_execution_record(self): """Test resetting a node variable when execution record doesn't exist""" mock_session = Mock(spec=Session) - service = WorkflowDraftVariableService(mock_session) + mock_session_maker = MagicMock() + # Mock the context manager protocol for sessionmaker + mock_session_maker.return_value.__enter__.return_value = mock_session + mock_session_maker.return_value.__exit__.return_value = None + service = WorkflowDraftVariableService(mock_session, mock_session_maker) + + # Mock the repository to return None (no execution record found) + service._api_node_execution_repo = Mock() + service._api_node_execution_repo.get_execution_by_id.return_value = None + test_app_id = self._get_test_app_id() workflow = self._create_test_workflow(test_app_id) @@ -171,11 +182,7 @@ class TestWorkflowDraftVariableService: variable = WorkflowDraftVariable.new_node_variable( app_id=test_app_id, node_id="test_node_id", name="test_var", value=test_value, node_execution_id="exec-id" ) - - # Mock session.scalars to return None (no execution record found) - mock_scalars = Mock() - mock_scalars.first.return_value = None - mock_session.scalars.return_value = mock_scalars + mock_variable.editable = True result = service._reset_node_var_or_sys_var(workflow, variable) @@ -187,7 +194,21 @@ class TestWorkflowDraftVariableService: def test_reset_node_variable_with_valid_execution_record(self): """Test resetting a node variable with valid execution record - should restore from execution""" mock_session = Mock(spec=Session) - service = WorkflowDraftVariableService(mock_session) + mock_session_maker = MagicMock() + # Mock the context manager protocol for sessionmaker + mock_session_maker.return_value.__enter__.return_value = mock_session + mock_session_maker.return_value.__exit__.return_value = None + service = WorkflowDraftVariableService(mock_session, mock_session_maker) + + # Create mock execution record + mock_execution = Mock(spec=WorkflowNodeExecutionModel) + mock_execution.process_data_dict = {"test_var": "process_value"} + mock_execution.outputs_dict = {"test_var": "output_value"} + + # Mock the repository to return the execution record + service._api_node_execution_repo = Mock() + service._api_node_execution_repo.get_execution_by_id.return_value = mock_execution + test_app_id = self._get_test_app_id() workflow = self._create_test_workflow(test_app_id) @@ -197,16 +218,7 @@ class TestWorkflowDraftVariableService: variable = WorkflowDraftVariable.new_node_variable( app_id=test_app_id, node_id="test_node_id", name="test_var", value=test_value, node_execution_id="exec-id" ) - - # Create mock execution record - mock_execution = Mock(spec=WorkflowNodeExecutionModel) - mock_execution.process_data_dict = {"test_var": "process_value"} - mock_execution.outputs_dict = {"test_var": "output_value"} - - # Mock session.scalars to return the execution record - mock_scalars = Mock() - mock_scalars.first.return_value = mock_execution - mock_session.scalars.return_value = mock_scalars + mock_variable.editable = True # Mock workflow methods mock_node_config = {"type": "test_node"} @@ -227,7 +239,8 @@ class TestWorkflowDraftVariableService: def test_reset_non_editable_system_variable_raises_error(self): """Test that resetting a non-editable system variable raises an error""" mock_session = Mock(spec=Session) - service = WorkflowDraftVariableService(mock_session) + mock_session_maker = Mock() + service = WorkflowDraftVariableService(mock_session, mock_session_maker) test_app_id = self._get_test_app_id() workflow = self._create_test_workflow(test_app_id) diff --git a/api/tests/unit_tests/services/workflow/test_workflow_node_execution_service_repository.py b/api/tests/unit_tests/services/workflow/test_workflow_node_execution_service_repository.py index 96f9139804..32d2f8b7e0 100644 --- a/api/tests/unit_tests/services/workflow/test_workflow_node_execution_service_repository.py +++ b/api/tests/unit_tests/services/workflow/test_workflow_node_execution_service_repository.py @@ -163,12 +163,18 @@ class TestSQLAlchemyWorkflowNodeExecutionServiceRepository: # Mock the select query to return some IDs first time, then empty to stop loop execution_ids = ["id1", "id2"] # Less than batch_size to trigger break - mock_session.execute.return_value.scalars.return_value.all.return_value = execution_ids - # Mock the delete query - mock_query = MagicMock() - mock_session.query.return_value = mock_query - mock_query.filter.return_value.delete.return_value = 2 + # Mock execute method to handle both select and delete statements + def mock_execute(stmt): + mock_result = MagicMock() + # For select statements, return execution IDs + if hasattr(stmt, "limit"): # This is our select statement + mock_result.scalars.return_value.all.return_value = execution_ids + else: # This is our delete statement + mock_result.rowcount = 2 + return mock_result + + mock_session.execute.side_effect = mock_execute before_date = datetime(2023, 1, 1) @@ -181,8 +187,7 @@ class TestSQLAlchemyWorkflowNodeExecutionServiceRepository: # Assert assert result == 2 - mock_session.execute.assert_called_once() # One select call - mock_session.query.assert_called_once() + assert mock_session.execute.call_count == 2 # One select call, one delete call mock_session.commit.assert_called_once() def test_delete_executions_by_app(self, repository): @@ -193,12 +198,18 @@ class TestSQLAlchemyWorkflowNodeExecutionServiceRepository: # Mock the select query to return some IDs first time, then empty to stop loop execution_ids = ["id1", "id2"] - mock_session.execute.return_value.scalars.return_value.all.return_value = execution_ids - # Mock the delete query - mock_query = MagicMock() - mock_session.query.return_value = mock_query - mock_query.filter.return_value.delete.return_value = 2 + # Mock execute method to handle both select and delete statements + def mock_execute(stmt): + mock_result = MagicMock() + # For select statements, return execution IDs + if hasattr(stmt, "limit"): # This is our select statement + mock_result.scalars.return_value.all.return_value = execution_ids + else: # This is our delete statement + mock_result.rowcount = 2 + return mock_result + + mock_session.execute.side_effect = mock_execute # Act result = repository.delete_executions_by_app( @@ -209,8 +220,7 @@ class TestSQLAlchemyWorkflowNodeExecutionServiceRepository: # Assert assert result == 2 - mock_session.execute.assert_called_once() # One select call - mock_session.query.assert_called_once() + assert mock_session.execute.call_count == 2 # One select call, one delete call mock_session.commit.assert_called_once() def test_get_expired_executions_batch(self, repository): @@ -248,10 +258,10 @@ class TestSQLAlchemyWorkflowNodeExecutionServiceRepository: mock_session = MagicMock(spec=Session) repository._session_maker.return_value.__enter__.return_value = mock_session - # Mock the delete query - mock_query = MagicMock() - mock_session.query.return_value = mock_query - mock_query.filter.return_value.delete.return_value = 3 + # Mock the delete query result + mock_result = MagicMock() + mock_result.rowcount = 3 + mock_session.execute.return_value = mock_result execution_ids = ["id1", "id2", "id3"] @@ -260,7 +270,7 @@ class TestSQLAlchemyWorkflowNodeExecutionServiceRepository: # Assert assert result == 3 - mock_session.query.assert_called_once() + mock_session.execute.assert_called_once() mock_session.commit.assert_called_once() def test_delete_executions_by_ids_empty_list(self, repository): diff --git a/api/tests/unit_tests/services/workflow/test_workflow_service.py b/api/tests/unit_tests/services/workflow/test_workflow_service.py index 13393668ea..9700cbaf0e 100644 --- a/api/tests/unit_tests/services/workflow/test_workflow_service.py +++ b/api/tests/unit_tests/services/workflow/test_workflow_service.py @@ -10,7 +10,8 @@ from services.workflow_service import WorkflowService class TestWorkflowService: @pytest.fixture def workflow_service(self): - return WorkflowService() + mock_session_maker = MagicMock() + return WorkflowService(mock_session_maker) @pytest.fixture def mock_app(self):