test(workflow_service): Fix test

Signed-off-by: -LAN- <laipz8200@outlook.com>
pull/21458/head
-LAN- 11 months ago
parent b450acf94b
commit 6a1110511f
No known key found for this signature in database
GPG Key ID: 6BA0D108DED011FF

@ -118,10 +118,11 @@ class DraftVarLoader(VariableLoader):
class WorkflowDraftVariableService:
_session: Session
def __init__(self, session: Session) -> None:
def __init__(self, session: Session, session_maker: sessionmaker | None = None) -> None:
self._session = session
session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
self._node_execution_service_repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository(
if session_maker is None:
session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
self._api_node_execution_repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository(
session_maker
)
@ -254,7 +255,7 @@ class WorkflowDraftVariableService:
_logger.warning("draft variable has no node_execution_id, id=%s, name=%s", variable.id, variable.name)
return None
node_exec = self._node_execution_service_repo.get_execution_by_id(variable.node_execution_id)
node_exec = self._api_node_execution_repo.get_execution_by_id(variable.node_execution_id)
if node_exec is None:
_logger.warning(
"Node exectution not found for draft variable, id=%s, name=%s, node_execution_id=%s",

@ -58,9 +58,10 @@ class WorkflowService:
Workflow Service
"""
def __init__(self):
def __init__(self, session_maker: sessionmaker | None = None):
"""Initialize WorkflowService with repository dependencies."""
session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
if session_maker is None:
session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
self._node_execution_service_repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository(
session_maker
)

@ -10,7 +10,8 @@ from services.workflow_service import DraftWorkflowDeletionError, WorkflowInUseE
@pytest.fixture
def workflow_setup():
workflow_service = WorkflowService()
mock_session_maker = MagicMock()
workflow_service = WorkflowService(mock_session_maker)
session = MagicMock(spec=Session)
tenant_id = "test-tenant-id"
workflow_id = "test-workflow-id"

@ -1,14 +1,14 @@
import dataclasses
import secrets
from unittest import mock
from unittest.mock import Mock, patch
from unittest.mock import MagicMock, Mock, patch
import pytest
from sqlalchemy.orm import Session
from core.variables import StringSegment
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
from core.workflow.nodes import NodeType
from core.workflow.nodes.enums import NodeType
from models.enums import DraftVariableType
from models.workflow import Workflow, WorkflowDraftVariable, WorkflowNodeExecutionModel, is_system_variable_editable
from services.workflow_draft_variable_service import (
@ -108,7 +108,8 @@ class TestWorkflowDraftVariableService:
def test_reset_conversation_variable(self):
"""Test resetting a conversation variable"""
mock_session = Mock(spec=Session)
service = WorkflowDraftVariableService(mock_session)
mock_session_maker = Mock()
service = WorkflowDraftVariableService(mock_session, mock_session_maker)
test_app_id = self._get_test_app_id()
workflow = self._create_test_workflow(test_app_id)
@ -134,7 +135,8 @@ class TestWorkflowDraftVariableService:
def test_reset_node_variable_with_no_execution_id(self):
"""Test resetting a node variable with no execution ID - should delete variable"""
mock_session = Mock(spec=Session)
service = WorkflowDraftVariableService(mock_session)
mock_session_maker = Mock()
service = WorkflowDraftVariableService(mock_session, mock_session_maker)
test_app_id = self._get_test_app_id()
workflow = self._create_test_workflow(test_app_id)
@ -161,7 +163,16 @@ class TestWorkflowDraftVariableService:
def test_reset_node_variable_with_missing_execution_record(self):
"""Test resetting a node variable when execution record doesn't exist"""
mock_session = Mock(spec=Session)
service = WorkflowDraftVariableService(mock_session)
mock_session_maker = MagicMock()
# Mock the context manager protocol for sessionmaker
mock_session_maker.return_value.__enter__.return_value = mock_session
mock_session_maker.return_value.__exit__.return_value = None
service = WorkflowDraftVariableService(mock_session, mock_session_maker)
# Mock the repository to return None (no execution record found)
service._api_node_execution_repo = Mock()
service._api_node_execution_repo.get_execution_by_id.return_value = None
test_app_id = self._get_test_app_id()
workflow = self._create_test_workflow(test_app_id)
@ -171,11 +182,7 @@ class TestWorkflowDraftVariableService:
variable = WorkflowDraftVariable.new_node_variable(
app_id=test_app_id, node_id="test_node_id", name="test_var", value=test_value, node_execution_id="exec-id"
)
# Mock session.scalars to return None (no execution record found)
mock_scalars = Mock()
mock_scalars.first.return_value = None
mock_session.scalars.return_value = mock_scalars
mock_variable.editable = True
result = service._reset_node_var_or_sys_var(workflow, variable)
@ -187,7 +194,21 @@ class TestWorkflowDraftVariableService:
def test_reset_node_variable_with_valid_execution_record(self):
"""Test resetting a node variable with valid execution record - should restore from execution"""
mock_session = Mock(spec=Session)
service = WorkflowDraftVariableService(mock_session)
mock_session_maker = MagicMock()
# Mock the context manager protocol for sessionmaker
mock_session_maker.return_value.__enter__.return_value = mock_session
mock_session_maker.return_value.__exit__.return_value = None
service = WorkflowDraftVariableService(mock_session, mock_session_maker)
# Create mock execution record
mock_execution = Mock(spec=WorkflowNodeExecutionModel)
mock_execution.process_data_dict = {"test_var": "process_value"}
mock_execution.outputs_dict = {"test_var": "output_value"}
# Mock the repository to return the execution record
service._api_node_execution_repo = Mock()
service._api_node_execution_repo.get_execution_by_id.return_value = mock_execution
test_app_id = self._get_test_app_id()
workflow = self._create_test_workflow(test_app_id)
@ -197,16 +218,7 @@ class TestWorkflowDraftVariableService:
variable = WorkflowDraftVariable.new_node_variable(
app_id=test_app_id, node_id="test_node_id", name="test_var", value=test_value, node_execution_id="exec-id"
)
# Create mock execution record
mock_execution = Mock(spec=WorkflowNodeExecutionModel)
mock_execution.process_data_dict = {"test_var": "process_value"}
mock_execution.outputs_dict = {"test_var": "output_value"}
# Mock session.scalars to return the execution record
mock_scalars = Mock()
mock_scalars.first.return_value = mock_execution
mock_session.scalars.return_value = mock_scalars
mock_variable.editable = True
# Mock workflow methods
mock_node_config = {"type": "test_node"}
@ -227,7 +239,8 @@ class TestWorkflowDraftVariableService:
def test_reset_non_editable_system_variable_raises_error(self):
"""Test that resetting a non-editable system variable raises an error"""
mock_session = Mock(spec=Session)
service = WorkflowDraftVariableService(mock_session)
mock_session_maker = Mock()
service = WorkflowDraftVariableService(mock_session, mock_session_maker)
test_app_id = self._get_test_app_id()
workflow = self._create_test_workflow(test_app_id)

@ -163,12 +163,18 @@ class TestSQLAlchemyWorkflowNodeExecutionServiceRepository:
# Mock the select query to return some IDs first time, then empty to stop loop
execution_ids = ["id1", "id2"] # Less than batch_size to trigger break
mock_session.execute.return_value.scalars.return_value.all.return_value = execution_ids
# Mock the delete query
mock_query = MagicMock()
mock_session.query.return_value = mock_query
mock_query.filter.return_value.delete.return_value = 2
# Mock execute method to handle both select and delete statements
def mock_execute(stmt):
mock_result = MagicMock()
# For select statements, return execution IDs
if hasattr(stmt, "limit"): # This is our select statement
mock_result.scalars.return_value.all.return_value = execution_ids
else: # This is our delete statement
mock_result.rowcount = 2
return mock_result
mock_session.execute.side_effect = mock_execute
before_date = datetime(2023, 1, 1)
@ -181,8 +187,7 @@ class TestSQLAlchemyWorkflowNodeExecutionServiceRepository:
# Assert
assert result == 2
mock_session.execute.assert_called_once() # One select call
mock_session.query.assert_called_once()
assert mock_session.execute.call_count == 2 # One select call, one delete call
mock_session.commit.assert_called_once()
def test_delete_executions_by_app(self, repository):
@ -193,12 +198,18 @@ class TestSQLAlchemyWorkflowNodeExecutionServiceRepository:
# Mock the select query to return some IDs first time, then empty to stop loop
execution_ids = ["id1", "id2"]
mock_session.execute.return_value.scalars.return_value.all.return_value = execution_ids
# Mock the delete query
mock_query = MagicMock()
mock_session.query.return_value = mock_query
mock_query.filter.return_value.delete.return_value = 2
# Mock execute method to handle both select and delete statements
def mock_execute(stmt):
mock_result = MagicMock()
# For select statements, return execution IDs
if hasattr(stmt, "limit"): # This is our select statement
mock_result.scalars.return_value.all.return_value = execution_ids
else: # This is our delete statement
mock_result.rowcount = 2
return mock_result
mock_session.execute.side_effect = mock_execute
# Act
result = repository.delete_executions_by_app(
@ -209,8 +220,7 @@ class TestSQLAlchemyWorkflowNodeExecutionServiceRepository:
# Assert
assert result == 2
mock_session.execute.assert_called_once() # One select call
mock_session.query.assert_called_once()
assert mock_session.execute.call_count == 2 # One select call, one delete call
mock_session.commit.assert_called_once()
def test_get_expired_executions_batch(self, repository):
@ -248,10 +258,10 @@ class TestSQLAlchemyWorkflowNodeExecutionServiceRepository:
mock_session = MagicMock(spec=Session)
repository._session_maker.return_value.__enter__.return_value = mock_session
# Mock the delete query
mock_query = MagicMock()
mock_session.query.return_value = mock_query
mock_query.filter.return_value.delete.return_value = 3
# Mock the delete query result
mock_result = MagicMock()
mock_result.rowcount = 3
mock_session.execute.return_value = mock_result
execution_ids = ["id1", "id2", "id3"]
@ -260,7 +270,7 @@ class TestSQLAlchemyWorkflowNodeExecutionServiceRepository:
# Assert
assert result == 3
mock_session.query.assert_called_once()
mock_session.execute.assert_called_once()
mock_session.commit.assert_called_once()
def test_delete_executions_by_ids_empty_list(self, repository):

@ -10,7 +10,8 @@ from services.workflow_service import WorkflowService
class TestWorkflowService:
@pytest.fixture
def workflow_service(self):
return WorkflowService()
mock_session_maker = MagicMock()
return WorkflowService(mock_session_maker)
@pytest.fixture
def mock_app(self):

Loading…
Cancel
Save