test(workflow_service): Fix test

Signed-off-by: -LAN- <laipz8200@outlook.com>
pull/21458/head
-LAN- 11 months ago
parent b450acf94b
commit 6a1110511f
No known key found for this signature in database
GPG Key ID: 6BA0D108DED011FF

@ -118,10 +118,11 @@ class DraftVarLoader(VariableLoader):
class WorkflowDraftVariableService: class WorkflowDraftVariableService:
_session: Session _session: Session
def __init__(self, session: Session) -> None: def __init__(self, session: Session, session_maker: sessionmaker | None = None) -> None:
self._session = session self._session = session
if session_maker is None:
session_maker = sessionmaker(bind=db.engine, expire_on_commit=False) session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
self._node_execution_service_repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository( self._api_node_execution_repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository(
session_maker session_maker
) )
@ -254,7 +255,7 @@ class WorkflowDraftVariableService:
_logger.warning("draft variable has no node_execution_id, id=%s, name=%s", variable.id, variable.name) _logger.warning("draft variable has no node_execution_id, id=%s, name=%s", variable.id, variable.name)
return None return None
node_exec = self._node_execution_service_repo.get_execution_by_id(variable.node_execution_id) node_exec = self._api_node_execution_repo.get_execution_by_id(variable.node_execution_id)
if node_exec is None: if node_exec is None:
_logger.warning( _logger.warning(
"Node exectution not found for draft variable, id=%s, name=%s, node_execution_id=%s", "Node exectution not found for draft variable, id=%s, name=%s, node_execution_id=%s",

@ -58,8 +58,9 @@ class WorkflowService:
Workflow Service Workflow Service
""" """
def __init__(self): def __init__(self, session_maker: sessionmaker | None = None):
"""Initialize WorkflowService with repository dependencies.""" """Initialize WorkflowService with repository dependencies."""
if session_maker is None:
session_maker = sessionmaker(bind=db.engine, expire_on_commit=False) session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
self._node_execution_service_repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository( self._node_execution_service_repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository(
session_maker session_maker

@ -10,7 +10,8 @@ from services.workflow_service import DraftWorkflowDeletionError, WorkflowInUseE
@pytest.fixture @pytest.fixture
def workflow_setup(): def workflow_setup():
workflow_service = WorkflowService() mock_session_maker = MagicMock()
workflow_service = WorkflowService(mock_session_maker)
session = MagicMock(spec=Session) session = MagicMock(spec=Session)
tenant_id = "test-tenant-id" tenant_id = "test-tenant-id"
workflow_id = "test-workflow-id" workflow_id = "test-workflow-id"

@ -1,14 +1,14 @@
import dataclasses import dataclasses
import secrets import secrets
from unittest import mock from unittest import mock
from unittest.mock import Mock, patch from unittest.mock import MagicMock, Mock, patch
import pytest import pytest
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from core.variables import StringSegment from core.variables import StringSegment
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
from core.workflow.nodes import NodeType from core.workflow.nodes.enums import NodeType
from models.enums import DraftVariableType from models.enums import DraftVariableType
from models.workflow import Workflow, WorkflowDraftVariable, WorkflowNodeExecutionModel, is_system_variable_editable from models.workflow import Workflow, WorkflowDraftVariable, WorkflowNodeExecutionModel, is_system_variable_editable
from services.workflow_draft_variable_service import ( from services.workflow_draft_variable_service import (
@ -108,7 +108,8 @@ class TestWorkflowDraftVariableService:
def test_reset_conversation_variable(self): def test_reset_conversation_variable(self):
"""Test resetting a conversation variable""" """Test resetting a conversation variable"""
mock_session = Mock(spec=Session) mock_session = Mock(spec=Session)
service = WorkflowDraftVariableService(mock_session) mock_session_maker = Mock()
service = WorkflowDraftVariableService(mock_session, mock_session_maker)
test_app_id = self._get_test_app_id() test_app_id = self._get_test_app_id()
workflow = self._create_test_workflow(test_app_id) workflow = self._create_test_workflow(test_app_id)
@ -134,7 +135,8 @@ class TestWorkflowDraftVariableService:
def test_reset_node_variable_with_no_execution_id(self): def test_reset_node_variable_with_no_execution_id(self):
"""Test resetting a node variable with no execution ID - should delete variable""" """Test resetting a node variable with no execution ID - should delete variable"""
mock_session = Mock(spec=Session) mock_session = Mock(spec=Session)
service = WorkflowDraftVariableService(mock_session) mock_session_maker = Mock()
service = WorkflowDraftVariableService(mock_session, mock_session_maker)
test_app_id = self._get_test_app_id() test_app_id = self._get_test_app_id()
workflow = self._create_test_workflow(test_app_id) workflow = self._create_test_workflow(test_app_id)
@ -161,7 +163,16 @@ class TestWorkflowDraftVariableService:
def test_reset_node_variable_with_missing_execution_record(self): def test_reset_node_variable_with_missing_execution_record(self):
"""Test resetting a node variable when execution record doesn't exist""" """Test resetting a node variable when execution record doesn't exist"""
mock_session = Mock(spec=Session) mock_session = Mock(spec=Session)
service = WorkflowDraftVariableService(mock_session) mock_session_maker = MagicMock()
# Mock the context manager protocol for sessionmaker
mock_session_maker.return_value.__enter__.return_value = mock_session
mock_session_maker.return_value.__exit__.return_value = None
service = WorkflowDraftVariableService(mock_session, mock_session_maker)
# Mock the repository to return None (no execution record found)
service._api_node_execution_repo = Mock()
service._api_node_execution_repo.get_execution_by_id.return_value = None
test_app_id = self._get_test_app_id() test_app_id = self._get_test_app_id()
workflow = self._create_test_workflow(test_app_id) workflow = self._create_test_workflow(test_app_id)
@ -171,11 +182,7 @@ class TestWorkflowDraftVariableService:
variable = WorkflowDraftVariable.new_node_variable( variable = WorkflowDraftVariable.new_node_variable(
app_id=test_app_id, node_id="test_node_id", name="test_var", value=test_value, node_execution_id="exec-id" app_id=test_app_id, node_id="test_node_id", name="test_var", value=test_value, node_execution_id="exec-id"
) )
mock_variable.editable = True
# Mock session.scalars to return None (no execution record found)
mock_scalars = Mock()
mock_scalars.first.return_value = None
mock_session.scalars.return_value = mock_scalars
result = service._reset_node_var_or_sys_var(workflow, variable) result = service._reset_node_var_or_sys_var(workflow, variable)
@ -187,7 +194,21 @@ class TestWorkflowDraftVariableService:
def test_reset_node_variable_with_valid_execution_record(self): def test_reset_node_variable_with_valid_execution_record(self):
"""Test resetting a node variable with valid execution record - should restore from execution""" """Test resetting a node variable with valid execution record - should restore from execution"""
mock_session = Mock(spec=Session) mock_session = Mock(spec=Session)
service = WorkflowDraftVariableService(mock_session) mock_session_maker = MagicMock()
# Mock the context manager protocol for sessionmaker
mock_session_maker.return_value.__enter__.return_value = mock_session
mock_session_maker.return_value.__exit__.return_value = None
service = WorkflowDraftVariableService(mock_session, mock_session_maker)
# Create mock execution record
mock_execution = Mock(spec=WorkflowNodeExecutionModel)
mock_execution.process_data_dict = {"test_var": "process_value"}
mock_execution.outputs_dict = {"test_var": "output_value"}
# Mock the repository to return the execution record
service._api_node_execution_repo = Mock()
service._api_node_execution_repo.get_execution_by_id.return_value = mock_execution
test_app_id = self._get_test_app_id() test_app_id = self._get_test_app_id()
workflow = self._create_test_workflow(test_app_id) workflow = self._create_test_workflow(test_app_id)
@ -197,16 +218,7 @@ class TestWorkflowDraftVariableService:
variable = WorkflowDraftVariable.new_node_variable( variable = WorkflowDraftVariable.new_node_variable(
app_id=test_app_id, node_id="test_node_id", name="test_var", value=test_value, node_execution_id="exec-id" app_id=test_app_id, node_id="test_node_id", name="test_var", value=test_value, node_execution_id="exec-id"
) )
mock_variable.editable = True
# Create mock execution record
mock_execution = Mock(spec=WorkflowNodeExecutionModel)
mock_execution.process_data_dict = {"test_var": "process_value"}
mock_execution.outputs_dict = {"test_var": "output_value"}
# Mock session.scalars to return the execution record
mock_scalars = Mock()
mock_scalars.first.return_value = mock_execution
mock_session.scalars.return_value = mock_scalars
# Mock workflow methods # Mock workflow methods
mock_node_config = {"type": "test_node"} mock_node_config = {"type": "test_node"}
@ -227,7 +239,8 @@ class TestWorkflowDraftVariableService:
def test_reset_non_editable_system_variable_raises_error(self): def test_reset_non_editable_system_variable_raises_error(self):
"""Test that resetting a non-editable system variable raises an error""" """Test that resetting a non-editable system variable raises an error"""
mock_session = Mock(spec=Session) mock_session = Mock(spec=Session)
service = WorkflowDraftVariableService(mock_session) mock_session_maker = Mock()
service = WorkflowDraftVariableService(mock_session, mock_session_maker)
test_app_id = self._get_test_app_id() test_app_id = self._get_test_app_id()
workflow = self._create_test_workflow(test_app_id) workflow = self._create_test_workflow(test_app_id)

@ -163,12 +163,18 @@ class TestSQLAlchemyWorkflowNodeExecutionServiceRepository:
# Mock the select query to return some IDs first time, then empty to stop loop # Mock the select query to return some IDs first time, then empty to stop loop
execution_ids = ["id1", "id2"] # Less than batch_size to trigger break execution_ids = ["id1", "id2"] # Less than batch_size to trigger break
mock_session.execute.return_value.scalars.return_value.all.return_value = execution_ids
# Mock the delete query # Mock execute method to handle both select and delete statements
mock_query = MagicMock() def mock_execute(stmt):
mock_session.query.return_value = mock_query mock_result = MagicMock()
mock_query.filter.return_value.delete.return_value = 2 # For select statements, return execution IDs
if hasattr(stmt, "limit"): # This is our select statement
mock_result.scalars.return_value.all.return_value = execution_ids
else: # This is our delete statement
mock_result.rowcount = 2
return mock_result
mock_session.execute.side_effect = mock_execute
before_date = datetime(2023, 1, 1) before_date = datetime(2023, 1, 1)
@ -181,8 +187,7 @@ class TestSQLAlchemyWorkflowNodeExecutionServiceRepository:
# Assert # Assert
assert result == 2 assert result == 2
mock_session.execute.assert_called_once() # One select call assert mock_session.execute.call_count == 2 # One select call, one delete call
mock_session.query.assert_called_once()
mock_session.commit.assert_called_once() mock_session.commit.assert_called_once()
def test_delete_executions_by_app(self, repository): def test_delete_executions_by_app(self, repository):
@ -193,12 +198,18 @@ class TestSQLAlchemyWorkflowNodeExecutionServiceRepository:
# Mock the select query to return some IDs first time, then empty to stop loop # Mock the select query to return some IDs first time, then empty to stop loop
execution_ids = ["id1", "id2"] execution_ids = ["id1", "id2"]
mock_session.execute.return_value.scalars.return_value.all.return_value = execution_ids
# Mock the delete query # Mock execute method to handle both select and delete statements
mock_query = MagicMock() def mock_execute(stmt):
mock_session.query.return_value = mock_query mock_result = MagicMock()
mock_query.filter.return_value.delete.return_value = 2 # For select statements, return execution IDs
if hasattr(stmt, "limit"): # This is our select statement
mock_result.scalars.return_value.all.return_value = execution_ids
else: # This is our delete statement
mock_result.rowcount = 2
return mock_result
mock_session.execute.side_effect = mock_execute
# Act # Act
result = repository.delete_executions_by_app( result = repository.delete_executions_by_app(
@ -209,8 +220,7 @@ class TestSQLAlchemyWorkflowNodeExecutionServiceRepository:
# Assert # Assert
assert result == 2 assert result == 2
mock_session.execute.assert_called_once() # One select call assert mock_session.execute.call_count == 2 # One select call, one delete call
mock_session.query.assert_called_once()
mock_session.commit.assert_called_once() mock_session.commit.assert_called_once()
def test_get_expired_executions_batch(self, repository): def test_get_expired_executions_batch(self, repository):
@ -248,10 +258,10 @@ class TestSQLAlchemyWorkflowNodeExecutionServiceRepository:
mock_session = MagicMock(spec=Session) mock_session = MagicMock(spec=Session)
repository._session_maker.return_value.__enter__.return_value = mock_session repository._session_maker.return_value.__enter__.return_value = mock_session
# Mock the delete query # Mock the delete query result
mock_query = MagicMock() mock_result = MagicMock()
mock_session.query.return_value = mock_query mock_result.rowcount = 3
mock_query.filter.return_value.delete.return_value = 3 mock_session.execute.return_value = mock_result
execution_ids = ["id1", "id2", "id3"] execution_ids = ["id1", "id2", "id3"]
@ -260,7 +270,7 @@ class TestSQLAlchemyWorkflowNodeExecutionServiceRepository:
# Assert # Assert
assert result == 3 assert result == 3
mock_session.query.assert_called_once() mock_session.execute.assert_called_once()
mock_session.commit.assert_called_once() mock_session.commit.assert_called_once()
def test_delete_executions_by_ids_empty_list(self, repository): def test_delete_executions_by_ids_empty_list(self, repository):

@ -10,7 +10,8 @@ from services.workflow_service import WorkflowService
class TestWorkflowService: class TestWorkflowService:
@pytest.fixture @pytest.fixture
def workflow_service(self): def workflow_service(self):
return WorkflowService() mock_session_maker = MagicMock()
return WorkflowService(mock_session_maker)
@pytest.fixture @pytest.fixture
def mock_app(self): def mock_app(self):

Loading…
Cancel
Save