From 979c4affa1d01747dd053b320b338de53c4ef7b4 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Fri, 27 Jun 2025 13:38:04 +0800 Subject: [PATCH] test(workflow_service): Fix test Signed-off-by: -LAN- --- .../workflow_draft_variable_service.py | 19 ++----- api/services/workflow_service.py | 5 +- .../test_workflow_draft_variable_service.py | 50 ++++++++++++------- ...kflow_node_execution_service_repository.py | 48 +++++++++++------- 4 files changed, 70 insertions(+), 52 deletions(-) diff --git a/api/services/workflow_draft_variable_service.py b/api/services/workflow_draft_variable_service.py index bec69f399d..9d198b3641 100644 --- a/api/services/workflow_draft_variable_service.py +++ b/api/services/workflow_draft_variable_service.py @@ -118,20 +118,11 @@ class DraftVarLoader(VariableLoader): class WorkflowDraftVariableService: _session: Session - def __init__(self, session: Session) -> None: - """ - Initialize the WorkflowDraftVariableService with a SQLAlchemy session. - - Args: - session (Session): The SQLAlchemy session used to execute database queries. - The provided session must be bound to an `Engine` object, not a specific `Connection`. - - Raises: - AssertionError: If the provided session is not bound to an `Engine` object. - """ + def __init__(self, session: Session, session_maker: sessionmaker | None = None) -> None: self._session = session - session_maker = sessionmaker(bind=db.engine, expire_on_commit=False) - self._node_execution_service_repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository( + if session_maker is None: + session_maker = sessionmaker(bind=db.engine, expire_on_commit=False) + self._api_node_execution_repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository( session_maker ) @@ -264,7 +255,7 @@ class WorkflowDraftVariableService: _logger.warning("draft variable has no node_execution_id, id=%s, name=%s", variable.id, variable.name) return None - node_exec = self._node_execution_service_repo.get_execution_by_id(variable.node_execution_id) + node_exec = self._api_node_execution_repo.get_execution_by_id(variable.node_execution_id) if node_exec is None: _logger.warning( "Node exectution not found for draft variable, id=%s, name=%s, node_execution_id=%s", diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 464e42ff98..677bc74237 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -59,9 +59,10 @@ class WorkflowService: Workflow Service """ - def __init__(self): + def __init__(self, session_maker: sessionmaker | None = None): """Initialize WorkflowService with repository dependencies.""" - session_maker = sessionmaker(bind=db.engine, expire_on_commit=False) + if session_maker is None: + session_maker = sessionmaker(bind=db.engine, expire_on_commit=False) self._node_execution_service_repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository( session_maker ) diff --git a/api/tests/unit_tests/services/workflow/test_workflow_draft_variable_service.py b/api/tests/unit_tests/services/workflow/test_workflow_draft_variable_service.py index 8b1348b75b..f07a18bc32 100644 --- a/api/tests/unit_tests/services/workflow/test_workflow_draft_variable_service.py +++ b/api/tests/unit_tests/services/workflow/test_workflow_draft_variable_service.py @@ -1,5 +1,6 @@ import dataclasses import secrets +from unittest import mock from unittest.mock import MagicMock, Mock, patch import pytest @@ -119,7 +120,9 @@ class TestWorkflowDraftVariableService: def test_reset_conversation_variable(self, mock_session): """Test resetting a conversation variable""" - service = WorkflowDraftVariableService(mock_session) + mock_session = Mock(spec=Session) + mock_session_maker = Mock() + service = WorkflowDraftVariableService(mock_session, mock_session_maker) test_app_id = self._get_test_app_id() workflow = self._create_test_workflow(test_app_id) @@ -144,7 +147,9 @@ class TestWorkflowDraftVariableService: def test_reset_node_variable_with_no_execution_id(self, mock_session): """Test resetting a node variable with no execution ID - should delete variable""" - service = WorkflowDraftVariableService(mock_session) + mock_session = Mock(spec=Session) + mock_session_maker = Mock() + service = WorkflowDraftVariableService(mock_session, mock_session_maker) test_app_id = self._get_test_app_id() workflow = self._create_test_workflow(test_app_id) @@ -175,14 +180,17 @@ class TestWorkflowDraftVariableService: monkeypatch, ): """Test resetting a node variable when execution record doesn't exist""" - mock_repo_session = Mock(spec=Session) - + mock_session = Mock(spec=Session) mock_session_maker = MagicMock() # Mock the context manager protocol for sessionmaker - mock_session_maker.return_value.__enter__.return_value = mock_repo_session + mock_session_maker.return_value.__enter__.return_value = mock_session mock_session_maker.return_value.__exit__.return_value = None - monkeypatch.setattr("services.workflow_draft_variable_service.sessionmaker", mock_session_maker) - service = WorkflowDraftVariableService(mock_session) + service = WorkflowDraftVariableService(mock_session, mock_session_maker) + + # Mock the repository to return None (no execution record found) + service._api_node_execution_repo = Mock() + service._api_node_execution_repo.get_execution_by_id.return_value = None + # Mock the repository to return None (no execution record found) service._api_node_execution_repo = Mock() @@ -196,7 +204,7 @@ class TestWorkflowDraftVariableService: variable = WorkflowDraftVariable.new_node_variable( app_id=test_app_id, node_id="test_node_id", name="test_var", value=test_value, node_execution_id="exec-id" ) - # Variable is editable by default from factory method + mock_variable.editable = True result = service._reset_node_var_or_sys_var(workflow, variable) @@ -212,16 +220,22 @@ class TestWorkflowDraftVariableService: monkeypatch, ): """Test resetting a node variable with valid execution record - should restore from execution""" - mock_repo_session = Mock(spec=Session) - + mock_session = Mock(spec=Session) mock_session_maker = MagicMock() # Mock the context manager protocol for sessionmaker - mock_session_maker.return_value.__enter__.return_value = mock_repo_session + mock_session_maker.return_value.__enter__.return_value = mock_session mock_session_maker.return_value.__exit__.return_value = None - mock_session_maker = monkeypatch.setattr( - "services.workflow_draft_variable_service.sessionmaker", mock_session_maker - ) - service = WorkflowDraftVariableService(mock_session) + service = WorkflowDraftVariableService(mock_session, mock_session_maker) + + # Create mock execution record + mock_execution = Mock(spec=WorkflowNodeExecutionModel) + mock_execution.process_data_dict = {"test_var": "process_value"} + mock_execution.outputs_dict = {"test_var": "output_value"} + + # Mock the repository to return the execution record + service._api_node_execution_repo = Mock() + service._api_node_execution_repo.get_execution_by_id.return_value = mock_execution + # Create mock execution record mock_execution = Mock(spec=WorkflowNodeExecutionModel) @@ -239,7 +253,7 @@ class TestWorkflowDraftVariableService: variable = WorkflowDraftVariable.new_node_variable( app_id=test_app_id, node_id="test_node_id", name="test_var", value=test_value, node_execution_id="exec-id" ) - # Variable is editable by default from factory method + mock_variable.editable = True # Mock workflow methods mock_node_config = {"type": "test_node"} @@ -259,7 +273,9 @@ class TestWorkflowDraftVariableService: def test_reset_non_editable_system_variable_raises_error(self, mock_session): """Test that resetting a non-editable system variable raises an error""" - service = WorkflowDraftVariableService(mock_session) + mock_session = Mock(spec=Session) + mock_session_maker = Mock() + service = WorkflowDraftVariableService(mock_session, mock_session_maker) test_app_id = self._get_test_app_id() workflow = self._create_test_workflow(test_app_id) diff --git a/api/tests/unit_tests/services/workflow/test_workflow_node_execution_service_repository.py b/api/tests/unit_tests/services/workflow/test_workflow_node_execution_service_repository.py index 96f9139804..32d2f8b7e0 100644 --- a/api/tests/unit_tests/services/workflow/test_workflow_node_execution_service_repository.py +++ b/api/tests/unit_tests/services/workflow/test_workflow_node_execution_service_repository.py @@ -163,12 +163,18 @@ class TestSQLAlchemyWorkflowNodeExecutionServiceRepository: # Mock the select query to return some IDs first time, then empty to stop loop execution_ids = ["id1", "id2"] # Less than batch_size to trigger break - mock_session.execute.return_value.scalars.return_value.all.return_value = execution_ids - # Mock the delete query - mock_query = MagicMock() - mock_session.query.return_value = mock_query - mock_query.filter.return_value.delete.return_value = 2 + # Mock execute method to handle both select and delete statements + def mock_execute(stmt): + mock_result = MagicMock() + # For select statements, return execution IDs + if hasattr(stmt, "limit"): # This is our select statement + mock_result.scalars.return_value.all.return_value = execution_ids + else: # This is our delete statement + mock_result.rowcount = 2 + return mock_result + + mock_session.execute.side_effect = mock_execute before_date = datetime(2023, 1, 1) @@ -181,8 +187,7 @@ class TestSQLAlchemyWorkflowNodeExecutionServiceRepository: # Assert assert result == 2 - mock_session.execute.assert_called_once() # One select call - mock_session.query.assert_called_once() + assert mock_session.execute.call_count == 2 # One select call, one delete call mock_session.commit.assert_called_once() def test_delete_executions_by_app(self, repository): @@ -193,12 +198,18 @@ class TestSQLAlchemyWorkflowNodeExecutionServiceRepository: # Mock the select query to return some IDs first time, then empty to stop loop execution_ids = ["id1", "id2"] - mock_session.execute.return_value.scalars.return_value.all.return_value = execution_ids - # Mock the delete query - mock_query = MagicMock() - mock_session.query.return_value = mock_query - mock_query.filter.return_value.delete.return_value = 2 + # Mock execute method to handle both select and delete statements + def mock_execute(stmt): + mock_result = MagicMock() + # For select statements, return execution IDs + if hasattr(stmt, "limit"): # This is our select statement + mock_result.scalars.return_value.all.return_value = execution_ids + else: # This is our delete statement + mock_result.rowcount = 2 + return mock_result + + mock_session.execute.side_effect = mock_execute # Act result = repository.delete_executions_by_app( @@ -209,8 +220,7 @@ class TestSQLAlchemyWorkflowNodeExecutionServiceRepository: # Assert assert result == 2 - mock_session.execute.assert_called_once() # One select call - mock_session.query.assert_called_once() + assert mock_session.execute.call_count == 2 # One select call, one delete call mock_session.commit.assert_called_once() def test_get_expired_executions_batch(self, repository): @@ -248,10 +258,10 @@ class TestSQLAlchemyWorkflowNodeExecutionServiceRepository: mock_session = MagicMock(spec=Session) repository._session_maker.return_value.__enter__.return_value = mock_session - # Mock the delete query - mock_query = MagicMock() - mock_session.query.return_value = mock_query - mock_query.filter.return_value.delete.return_value = 3 + # Mock the delete query result + mock_result = MagicMock() + mock_result.rowcount = 3 + mock_session.execute.return_value = mock_result execution_ids = ["id1", "id2", "id3"] @@ -260,7 +270,7 @@ class TestSQLAlchemyWorkflowNodeExecutionServiceRepository: # Assert assert result == 3 - mock_session.query.assert_called_once() + mock_session.execute.assert_called_once() mock_session.commit.assert_called_once() def test_delete_executions_by_ids_empty_list(self, repository):