test(test_workflow_cycle_manager): Refactors workflow execution handling in tests

Refactors workflow execution handling in tests

Signed-off-by: -LAN- <laipz8200@outlook.com>
pull/20067/head
-LAN- 1 year ago
parent b727e4a84f
commit 198373219f
No known key found for this signature in database
GPG Key ID: 6BA0D108DED011FF

@ -1,7 +1,6 @@
import json
import time
from datetime import UTC, datetime
from unittest.mock import MagicMock, patch
from unittest.mock import MagicMock
import pytest
from sqlalchemy.orm import Session
@ -12,8 +11,11 @@ from core.app.entities.queue_entities import (
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
)
from core.workflow.entities.node_execution_entities import NodeExecutionStatus
from core.workflow.entities.workflow_execution_entities import WorkflowExecution, WorkflowExecutionStatus
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes import NodeType
from core.workflow.repository.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.workflow_cycle_manager import WorkflowCycleManager
from models.enums import CreatorUserRole
@ -59,10 +61,23 @@ def mock_node_execution_repository():
@pytest.fixture
def workflow_cycle_manager(mock_app_generate_entity, mock_workflow_system_variables, mock_node_execution_repository):
def mock_workflow_execution_repository():
repo = MagicMock(spec=WorkflowExecutionRepository)
repo.get.return_value = None
return repo
@pytest.fixture
def workflow_cycle_manager(
mock_app_generate_entity,
mock_workflow_system_variables,
mock_workflow_execution_repository,
mock_node_execution_repository,
):
return WorkflowCycleManager(
application_generate_entity=mock_app_generate_entity,
workflow_system_variables=mock_workflow_system_variables,
workflow_execution_repository=mock_workflow_execution_repository,
workflow_node_execution_repository=mock_node_execution_repository,
)
@ -82,6 +97,7 @@ def mock_workflow():
workflow.type = "chat"
workflow.version = "1.0"
workflow.graph = json.dumps({"nodes": [], "edges": []})
workflow.graph_dict = {"nodes": [], "edges": []}
return workflow
@ -102,12 +118,16 @@ def mock_workflow_run():
def test_init(
workflow_cycle_manager, mock_app_generate_entity, mock_workflow_system_variables, mock_node_execution_repository
workflow_cycle_manager,
mock_app_generate_entity,
mock_workflow_system_variables,
mock_workflow_execution_repository,
mock_node_execution_repository,
):
"""Test initialization of WorkflowCycleManager"""
assert workflow_cycle_manager._workflow_run is None
assert workflow_cycle_manager._application_generate_entity == mock_app_generate_entity
assert workflow_cycle_manager._workflow_system_variables == mock_workflow_system_variables
assert workflow_cycle_manager._workflow_execution_repository == mock_workflow_execution_repository
assert workflow_cycle_manager._workflow_node_execution_repository == mock_node_execution_repository
@ -117,78 +137,83 @@ def test_handle_workflow_run_start(workflow_cycle_manager, mock_session, mock_wo
mock_session.scalar.side_effect = [mock_workflow, 5]
# Call the method
workflow_run = workflow_cycle_manager._handle_workflow_run_start(
workflow_execution = workflow_cycle_manager._handle_workflow_run_start(
session=mock_session,
workflow_id="test-workflow-id",
user_id="test-user-id",
created_by_role=CreatorUserRole.ACCOUNT,
)
# Verify the result
assert workflow_run.tenant_id == mock_workflow.tenant_id
assert workflow_run.app_id == mock_workflow.app_id
assert workflow_run.workflow_id == mock_workflow.id
assert workflow_run.sequence_number == 6 # max_sequence + 1
assert workflow_run.status == WorkflowRunStatus.RUNNING
assert workflow_run.created_by_role == CreatorUserRole.ACCOUNT
assert workflow_run.created_by == "test-user-id"
assert workflow_execution.workflow_id == mock_workflow.id
assert workflow_execution.sequence_number == 6 # max_sequence + 1
# Verify session.add was called
mock_session.add.assert_called_once_with(workflow_run)
# Verify the workflow_execution_repository.save was called
workflow_cycle_manager._workflow_execution_repository.save.assert_called_once_with(workflow_execution)
def test_handle_workflow_run_success(workflow_cycle_manager, mock_session, mock_workflow_run):
def test_handle_workflow_run_success(workflow_cycle_manager, mock_workflow_execution_repository):
"""Test _handle_workflow_run_success method"""
# Mock _get_workflow_run to return the mock_workflow_run
with patch.object(workflow_cycle_manager, "_get_workflow_run", return_value=mock_workflow_run):
# Create a mock WorkflowExecution
mock_workflow_execution = MagicMock()
# Mock _get_workflow_execution_or_raise_error to return the mock_workflow_execution
workflow_cycle_manager._workflow_execution_repository.get.return_value = mock_workflow_execution
# Call the method
result = workflow_cycle_manager._handle_workflow_run_success(
session=mock_session,
workflow_run_id="test-workflow-run-id",
start_at=time.perf_counter() - 10, # 10 seconds ago
total_tokens=100,
total_steps=5,
outputs={"answer": "test answer"},
)
# Verify the result
assert result == mock_workflow_run
assert result.status == WorkflowRunStatus.SUCCEEDED
assert result.outputs == json.dumps({"answer": "test answer"})
assert result == mock_workflow_execution
assert result.status == WorkflowExecutionStatus.SUCCEEDED
assert result.outputs == {"answer": "test answer"}
assert result.total_tokens == 100
assert result.total_steps == 5
assert result.finished_at is not None
def test_handle_workflow_run_failed(workflow_cycle_manager, mock_session, mock_workflow_run):
def test_handle_workflow_run_failed(workflow_cycle_manager, mock_workflow_execution_repository):
"""Test _handle_workflow_run_failed method"""
# Mock _get_workflow_run to return the mock_workflow_run
with patch.object(workflow_cycle_manager, "_get_workflow_run", return_value=mock_workflow_run):
# Create a mock WorkflowExecution
mock_workflow_execution = MagicMock()
# Mock _get_workflow_execution_or_raise_error to return the mock_workflow_execution
workflow_cycle_manager._workflow_execution_repository.get.return_value = mock_workflow_execution
# Mock get_running_executions to return an empty list
workflow_cycle_manager._workflow_node_execution_repository.get_running_executions.return_value = []
# Call the method
result = workflow_cycle_manager._handle_workflow_run_failed(
session=mock_session,
workflow_run_id="test-workflow-run-id",
start_at=time.perf_counter() - 10, # 10 seconds ago
total_tokens=50,
total_steps=3,
status=WorkflowRunStatus.FAILED,
error="Test error message",
error_message="Test error message",
)
# Verify the result
assert result == mock_workflow_run
assert result.status == WorkflowRunStatus.FAILED.value
assert result.error == "Test error message"
assert result == mock_workflow_execution
assert result.status == WorkflowExecutionStatus(WorkflowRunStatus.FAILED.value)
assert result.error_message == "Test error message"
assert result.total_tokens == 50
assert result.total_steps == 3
assert result.finished_at is not None
def test_handle_node_execution_start(workflow_cycle_manager, mock_workflow_run):
def test_handle_node_execution_start(workflow_cycle_manager, mock_workflow_execution_repository):
"""Test _handle_node_execution_start method"""
# Create a mock WorkflowExecution
mock_workflow_execution = MagicMock()
mock_workflow_execution.id = "test-workflow-execution-id"
mock_workflow_execution.workflow_id = "test-workflow-id"
# Mock _get_workflow_execution_or_raise_error to return the mock_workflow_execution
workflow_cycle_manager._workflow_execution_repository.get.return_value = mock_workflow_execution
# Create a mock event
event = MagicMock(spec=QueueNodeStartedEvent)
event.node_execution_id = "test-node-execution-id"
@ -208,43 +233,43 @@ def test_handle_node_execution_start(workflow_cycle_manager, mock_workflow_run):
# Call the method
result = workflow_cycle_manager._handle_node_execution_start(
workflow_run=mock_workflow_run,
workflow_execution_id=mock_workflow_execution.id,
event=event,
)
# Verify the result
# NodeExecution doesn't have tenant_id attribute, it's handled at repository level
# assert result.tenant_id == mock_workflow_run.tenant_id
# assert result.app_id == mock_workflow_run.app_id
assert result.workflow_id == mock_workflow_run.workflow_id
assert result.workflow_run_id == mock_workflow_run.id
assert result.workflow_id == mock_workflow_execution.workflow_id
assert result.workflow_run_id == mock_workflow_execution.id
assert result.node_execution_id == event.node_execution_id
assert result.node_id == event.node_id
assert result.node_type == event.node_type
assert result.title == event.node_data.title
assert result.status == WorkflowNodeExecutionStatus.RUNNING.value
# NodeExecution doesn't have created_by_role and created_by attributes, they're handled at repository level
# assert result.created_by_role == mock_workflow_run.created_by_role
# assert result.created_by == mock_workflow_run.created_by
assert result.status == NodeExecutionStatus.RUNNING
# Verify save was called
workflow_cycle_manager._workflow_node_execution_repository.save.assert_called_once_with(result)
def test_get_workflow_run(workflow_cycle_manager, mock_session, mock_workflow_run):
"""Test _get_workflow_run method"""
# Mock session.scalar to return the workflow run
mock_session.scalar.return_value = mock_workflow_run
def test_get_workflow_execution_or_raise_error(workflow_cycle_manager, mock_workflow_execution_repository):
"""Test _get_workflow_execution_or_raise_error method"""
# Create a mock WorkflowExecution
mock_workflow_execution = MagicMock(spec=WorkflowExecution)
# Mock the repository get method to return the mock execution
workflow_cycle_manager._workflow_execution_repository.get.return_value = mock_workflow_execution
# Call the method
result = workflow_cycle_manager._get_workflow_run(
session=mock_session,
workflow_run_id="test-workflow-run-id",
)
result = workflow_cycle_manager._get_workflow_execution_or_raise_error("test-workflow-run-id")
# Verify the result
assert result == mock_workflow_run
assert workflow_cycle_manager._workflow_run == mock_workflow_run
assert result == mock_workflow_execution
# Test error case
workflow_cycle_manager._workflow_execution_repository.get.return_value = None
# Expect an error when execution is not found
with pytest.raises(ValueError):
workflow_cycle_manager._get_workflow_execution_or_raise_error("non-existent-id")
def test_handle_workflow_node_execution_success(workflow_cycle_manager):
@ -278,15 +303,17 @@ def test_handle_workflow_node_execution_success(workflow_cycle_manager):
workflow_cycle_manager._workflow_node_execution_repository.save.assert_called_once_with(node_execution)
def test_handle_workflow_run_partial_success(workflow_cycle_manager, mock_session, mock_workflow_run):
def test_handle_workflow_run_partial_success(workflow_cycle_manager, mock_workflow_execution_repository):
"""Test _handle_workflow_run_partial_success method"""
# Mock _get_workflow_run to return the mock_workflow_run
with patch.object(workflow_cycle_manager, "_get_workflow_run", return_value=mock_workflow_run):
# Create a mock WorkflowExecution
mock_workflow_execution = MagicMock(spec=WorkflowExecution)
# Mock _get_workflow_execution_or_raise_error to return the mock_workflow_execution
workflow_cycle_manager._workflow_execution_repository.get.return_value = mock_workflow_execution
# Call the method
result = workflow_cycle_manager._handle_workflow_run_partial_success(
session=mock_session,
workflow_run_id="test-workflow-run-id",
start_at=time.perf_counter() - 10, # 10 seconds ago
total_tokens=75,
total_steps=4,
outputs={"partial_answer": "test partial answer"},
@ -294,9 +321,9 @@ def test_handle_workflow_run_partial_success(workflow_cycle_manager, mock_sessio
)
# Verify the result
assert result == mock_workflow_run
assert result.status == WorkflowRunStatus.PARTIAL_SUCCEEDED.value
assert result.outputs == json.dumps({"partial_answer": "test partial answer"})
assert result == mock_workflow_execution
assert result.status == WorkflowExecutionStatus.PARTIAL_SUCCEEDED
assert result.outputs == {"partial_answer": "test partial answer"}
assert result.total_tokens == 75
assert result.total_steps == 4
assert result.exceptions_count == 2

Loading…
Cancel
Save