test(workflow_service): Fix test

Signed-off-by: -LAN- <laipz8200@outlook.com>
pull/22581/head
-LAN- 11 months ago
parent 8c945b7a08
commit 979c4affa1
No known key found for this signature in database
GPG Key ID: 6BA0D108DED011FF

@ -118,20 +118,11 @@ class DraftVarLoader(VariableLoader):
class WorkflowDraftVariableService:
_session: Session
def __init__(self, session: Session) -> None:
"""
Initialize the WorkflowDraftVariableService with a SQLAlchemy session.
Args:
session (Session): The SQLAlchemy session used to execute database queries.
The provided session must be bound to an `Engine` object, not a specific `Connection`.
Raises:
AssertionError: If the provided session is not bound to an `Engine` object.
"""
def __init__(self, session: Session, session_maker: sessionmaker | None = None) -> None:
self._session = session
if session_maker is None:
session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
self._node_execution_service_repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository(
self._api_node_execution_repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository(
session_maker
)
@ -264,7 +255,7 @@ class WorkflowDraftVariableService:
_logger.warning("draft variable has no node_execution_id, id=%s, name=%s", variable.id, variable.name)
return None
node_exec = self._node_execution_service_repo.get_execution_by_id(variable.node_execution_id)
node_exec = self._api_node_execution_repo.get_execution_by_id(variable.node_execution_id)
if node_exec is None:
_logger.warning(
"Node exectution not found for draft variable, id=%s, name=%s, node_execution_id=%s",

@ -59,8 +59,9 @@ class WorkflowService:
Workflow Service
"""
def __init__(self):
def __init__(self, session_maker: sessionmaker | None = None):
"""Initialize WorkflowService with repository dependencies."""
if session_maker is None:
session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
self._node_execution_service_repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository(
session_maker

@ -1,5 +1,6 @@
import dataclasses
import secrets
from unittest import mock
from unittest.mock import MagicMock, Mock, patch
import pytest
@ -119,7 +120,9 @@ class TestWorkflowDraftVariableService:
def test_reset_conversation_variable(self, mock_session):
"""Test resetting a conversation variable"""
service = WorkflowDraftVariableService(mock_session)
mock_session = Mock(spec=Session)
mock_session_maker = Mock()
service = WorkflowDraftVariableService(mock_session, mock_session_maker)
test_app_id = self._get_test_app_id()
workflow = self._create_test_workflow(test_app_id)
@ -144,7 +147,9 @@ class TestWorkflowDraftVariableService:
def test_reset_node_variable_with_no_execution_id(self, mock_session):
"""Test resetting a node variable with no execution ID - should delete variable"""
service = WorkflowDraftVariableService(mock_session)
mock_session = Mock(spec=Session)
mock_session_maker = Mock()
service = WorkflowDraftVariableService(mock_session, mock_session_maker)
test_app_id = self._get_test_app_id()
workflow = self._create_test_workflow(test_app_id)
@ -175,14 +180,17 @@ class TestWorkflowDraftVariableService:
monkeypatch,
):
"""Test resetting a node variable when execution record doesn't exist"""
mock_repo_session = Mock(spec=Session)
mock_session = Mock(spec=Session)
mock_session_maker = MagicMock()
# Mock the context manager protocol for sessionmaker
mock_session_maker.return_value.__enter__.return_value = mock_repo_session
mock_session_maker.return_value.__enter__.return_value = mock_session
mock_session_maker.return_value.__exit__.return_value = None
monkeypatch.setattr("services.workflow_draft_variable_service.sessionmaker", mock_session_maker)
service = WorkflowDraftVariableService(mock_session)
service = WorkflowDraftVariableService(mock_session, mock_session_maker)
# Mock the repository to return None (no execution record found)
service._api_node_execution_repo = Mock()
service._api_node_execution_repo.get_execution_by_id.return_value = None
# Mock the repository to return None (no execution record found)
service._api_node_execution_repo = Mock()
@ -196,7 +204,7 @@ class TestWorkflowDraftVariableService:
variable = WorkflowDraftVariable.new_node_variable(
app_id=test_app_id, node_id="test_node_id", name="test_var", value=test_value, node_execution_id="exec-id"
)
# Variable is editable by default from factory method
mock_variable.editable = True
result = service._reset_node_var_or_sys_var(workflow, variable)
@ -212,16 +220,22 @@ class TestWorkflowDraftVariableService:
monkeypatch,
):
"""Test resetting a node variable with valid execution record - should restore from execution"""
mock_repo_session = Mock(spec=Session)
mock_session = Mock(spec=Session)
mock_session_maker = MagicMock()
# Mock the context manager protocol for sessionmaker
mock_session_maker.return_value.__enter__.return_value = mock_repo_session
mock_session_maker.return_value.__enter__.return_value = mock_session
mock_session_maker.return_value.__exit__.return_value = None
mock_session_maker = monkeypatch.setattr(
"services.workflow_draft_variable_service.sessionmaker", mock_session_maker
)
service = WorkflowDraftVariableService(mock_session)
service = WorkflowDraftVariableService(mock_session, mock_session_maker)
# Create mock execution record
mock_execution = Mock(spec=WorkflowNodeExecutionModel)
mock_execution.process_data_dict = {"test_var": "process_value"}
mock_execution.outputs_dict = {"test_var": "output_value"}
# Mock the repository to return the execution record
service._api_node_execution_repo = Mock()
service._api_node_execution_repo.get_execution_by_id.return_value = mock_execution
# Create mock execution record
mock_execution = Mock(spec=WorkflowNodeExecutionModel)
@ -239,7 +253,7 @@ class TestWorkflowDraftVariableService:
variable = WorkflowDraftVariable.new_node_variable(
app_id=test_app_id, node_id="test_node_id", name="test_var", value=test_value, node_execution_id="exec-id"
)
# Variable is editable by default from factory method
mock_variable.editable = True
# Mock workflow methods
mock_node_config = {"type": "test_node"}
@ -259,7 +273,9 @@ class TestWorkflowDraftVariableService:
def test_reset_non_editable_system_variable_raises_error(self, mock_session):
"""Test that resetting a non-editable system variable raises an error"""
service = WorkflowDraftVariableService(mock_session)
mock_session = Mock(spec=Session)
mock_session_maker = Mock()
service = WorkflowDraftVariableService(mock_session, mock_session_maker)
test_app_id = self._get_test_app_id()
workflow = self._create_test_workflow(test_app_id)

@ -163,12 +163,18 @@ class TestSQLAlchemyWorkflowNodeExecutionServiceRepository:
# Mock the select query to return some IDs first time, then empty to stop loop
execution_ids = ["id1", "id2"] # Less than batch_size to trigger break
mock_session.execute.return_value.scalars.return_value.all.return_value = execution_ids
# Mock the delete query
mock_query = MagicMock()
mock_session.query.return_value = mock_query
mock_query.filter.return_value.delete.return_value = 2
# Mock execute method to handle both select and delete statements
def mock_execute(stmt):
mock_result = MagicMock()
# For select statements, return execution IDs
if hasattr(stmt, "limit"): # This is our select statement
mock_result.scalars.return_value.all.return_value = execution_ids
else: # This is our delete statement
mock_result.rowcount = 2
return mock_result
mock_session.execute.side_effect = mock_execute
before_date = datetime(2023, 1, 1)
@ -181,8 +187,7 @@ class TestSQLAlchemyWorkflowNodeExecutionServiceRepository:
# Assert
assert result == 2
mock_session.execute.assert_called_once() # One select call
mock_session.query.assert_called_once()
assert mock_session.execute.call_count == 2 # One select call, one delete call
mock_session.commit.assert_called_once()
def test_delete_executions_by_app(self, repository):
@ -193,12 +198,18 @@ class TestSQLAlchemyWorkflowNodeExecutionServiceRepository:
# Mock the select query to return some IDs first time, then empty to stop loop
execution_ids = ["id1", "id2"]
mock_session.execute.return_value.scalars.return_value.all.return_value = execution_ids
# Mock the delete query
mock_query = MagicMock()
mock_session.query.return_value = mock_query
mock_query.filter.return_value.delete.return_value = 2
# Mock execute method to handle both select and delete statements
def mock_execute(stmt):
mock_result = MagicMock()
# For select statements, return execution IDs
if hasattr(stmt, "limit"): # This is our select statement
mock_result.scalars.return_value.all.return_value = execution_ids
else: # This is our delete statement
mock_result.rowcount = 2
return mock_result
mock_session.execute.side_effect = mock_execute
# Act
result = repository.delete_executions_by_app(
@ -209,8 +220,7 @@ class TestSQLAlchemyWorkflowNodeExecutionServiceRepository:
# Assert
assert result == 2
mock_session.execute.assert_called_once() # One select call
mock_session.query.assert_called_once()
assert mock_session.execute.call_count == 2 # One select call, one delete call
mock_session.commit.assert_called_once()
def test_get_expired_executions_batch(self, repository):
@ -248,10 +258,10 @@ class TestSQLAlchemyWorkflowNodeExecutionServiceRepository:
mock_session = MagicMock(spec=Session)
repository._session_maker.return_value.__enter__.return_value = mock_session
# Mock the delete query
mock_query = MagicMock()
mock_session.query.return_value = mock_query
mock_query.filter.return_value.delete.return_value = 3
# Mock the delete query result
mock_result = MagicMock()
mock_result.rowcount = 3
mock_session.execute.return_value = mock_result
execution_ids = ["id1", "id2", "id3"]
@ -260,7 +270,7 @@ class TestSQLAlchemyWorkflowNodeExecutionServiceRepository:
# Assert
assert result == 3
mock_session.query.assert_called_once()
mock_session.execute.assert_called_once()
mock_session.commit.assert_called_once()
def test_delete_executions_by_ids_empty_list(self, repository):

Loading…
Cancel
Save