feat: add a abstract layer for WorkflowNodeExcetion (#18026)
parent
77fde04ef7
commit
6d9dd3109e
@ -0,0 +1,15 @@
|
||||
"""
|
||||
Repository interfaces for data access.
|
||||
|
||||
This package contains repository interfaces that define the contract
|
||||
for accessing and manipulating data, regardless of the underlying
|
||||
storage mechanism.
|
||||
"""
|
||||
|
||||
from core.repository.repository_factory import RepositoryFactory
|
||||
from core.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||
|
||||
__all__ = [
|
||||
"RepositoryFactory",
|
||||
"WorkflowNodeExecutionRepository",
|
||||
]
|
||||
@ -0,0 +1,97 @@
|
||||
"""
|
||||
Repository factory for creating repository instances.
|
||||
|
||||
This module provides a simple factory interface for creating repository instances.
|
||||
It does not contain any implementation details or dependencies on specific repositories.
|
||||
"""
|
||||
|
||||
from collections.abc import Callable, Mapping
|
||||
from typing import Any, Literal, Optional, cast
|
||||
|
||||
from core.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||
|
||||
# Type for factory functions - takes a dict of parameters and returns any repository type
|
||||
RepositoryFactoryFunc = Callable[[Mapping[str, Any]], Any]
|
||||
|
||||
# Type for workflow node execution factory function
|
||||
WorkflowNodeExecutionFactoryFunc = Callable[[Mapping[str, Any]], WorkflowNodeExecutionRepository]
|
||||
|
||||
# Repository type literals
|
||||
RepositoryType = Literal["workflow_node_execution"]
|
||||
|
||||
|
||||
class RepositoryFactory:
|
||||
"""
|
||||
Factory class for creating repository instances.
|
||||
|
||||
This factory delegates the actual repository creation to implementation-specific
|
||||
factory functions that are registered with the factory at runtime.
|
||||
"""
|
||||
|
||||
# Dictionary to store factory functions
|
||||
_factory_functions: dict[str, RepositoryFactoryFunc] = {}
|
||||
|
||||
@classmethod
|
||||
def _register_factory(cls, repository_type: RepositoryType, factory_func: RepositoryFactoryFunc) -> None:
|
||||
"""
|
||||
Register a factory function for a specific repository type.
|
||||
This is a private method and should not be called directly.
|
||||
|
||||
Args:
|
||||
repository_type: The type of repository (e.g., 'workflow_node_execution')
|
||||
factory_func: A function that takes parameters and returns a repository instance
|
||||
"""
|
||||
cls._factory_functions[repository_type] = factory_func
|
||||
|
||||
@classmethod
|
||||
def _create_repository(cls, repository_type: RepositoryType, params: Optional[Mapping[str, Any]] = None) -> Any:
|
||||
"""
|
||||
Create a new repository instance with the provided parameters.
|
||||
This is a private method and should not be called directly.
|
||||
|
||||
Args:
|
||||
repository_type: The type of repository to create
|
||||
params: A dictionary of parameters to pass to the factory function
|
||||
|
||||
Returns:
|
||||
A new instance of the requested repository
|
||||
|
||||
Raises:
|
||||
ValueError: If no factory function is registered for the repository type
|
||||
"""
|
||||
if repository_type not in cls._factory_functions:
|
||||
raise ValueError(f"No factory function registered for repository type '{repository_type}'")
|
||||
|
||||
# Use empty dict if params is None
|
||||
params = params or {}
|
||||
|
||||
return cls._factory_functions[repository_type](params)
|
||||
|
||||
@classmethod
|
||||
def register_workflow_node_execution_factory(cls, factory_func: WorkflowNodeExecutionFactoryFunc) -> None:
|
||||
"""
|
||||
Register a factory function for the workflow node execution repository.
|
||||
|
||||
Args:
|
||||
factory_func: A function that takes parameters and returns a WorkflowNodeExecutionRepository instance
|
||||
"""
|
||||
cls._register_factory("workflow_node_execution", factory_func)
|
||||
|
||||
@classmethod
|
||||
def create_workflow_node_execution_repository(
|
||||
cls, params: Optional[Mapping[str, Any]] = None
|
||||
) -> WorkflowNodeExecutionRepository:
|
||||
"""
|
||||
Create a new WorkflowNodeExecutionRepository instance with the provided parameters.
|
||||
|
||||
Args:
|
||||
params: A dictionary of parameters to pass to the factory function
|
||||
|
||||
Returns:
|
||||
A new instance of the WorkflowNodeExecutionRepository
|
||||
|
||||
Raises:
|
||||
ValueError: If no factory function is registered for the workflow_node_execution repository type
|
||||
"""
|
||||
# We can safely cast here because we've registered a WorkflowNodeExecutionFactoryFunc
|
||||
return cast(WorkflowNodeExecutionRepository, cls._create_repository("workflow_node_execution", params))
|
||||
@ -0,0 +1,88 @@
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from typing import Literal, Optional, Protocol
|
||||
|
||||
from models.workflow import WorkflowNodeExecution
|
||||
|
||||
|
||||
@dataclass
|
||||
class OrderConfig:
|
||||
"""Configuration for ordering WorkflowNodeExecution instances."""
|
||||
|
||||
order_by: list[str]
|
||||
order_direction: Optional[Literal["asc", "desc"]] = None
|
||||
|
||||
|
||||
class WorkflowNodeExecutionRepository(Protocol):
|
||||
"""
|
||||
Repository interface for WorkflowNodeExecution.
|
||||
|
||||
This interface defines the contract for accessing and manipulating
|
||||
WorkflowNodeExecution data, regardless of the underlying storage mechanism.
|
||||
|
||||
Note: Domain-specific concepts like multi-tenancy (tenant_id), application context (app_id),
|
||||
and trigger sources (triggered_from) should be handled at the implementation level, not in
|
||||
the core interface. This keeps the core domain model clean and independent of specific
|
||||
application domains or deployment scenarios.
|
||||
"""
|
||||
|
||||
def save(self, execution: WorkflowNodeExecution) -> None:
|
||||
"""
|
||||
Save a WorkflowNodeExecution instance.
|
||||
|
||||
Args:
|
||||
execution: The WorkflowNodeExecution instance to save
|
||||
"""
|
||||
...
|
||||
|
||||
def get_by_node_execution_id(self, node_execution_id: str) -> Optional[WorkflowNodeExecution]:
|
||||
"""
|
||||
Retrieve a WorkflowNodeExecution by its node_execution_id.
|
||||
|
||||
Args:
|
||||
node_execution_id: The node execution ID
|
||||
|
||||
Returns:
|
||||
The WorkflowNodeExecution instance if found, None otherwise
|
||||
"""
|
||||
...
|
||||
|
||||
def get_by_workflow_run(
|
||||
self,
|
||||
workflow_run_id: str,
|
||||
order_config: Optional[OrderConfig] = None,
|
||||
) -> Sequence[WorkflowNodeExecution]:
|
||||
"""
|
||||
Retrieve all WorkflowNodeExecution instances for a specific workflow run.
|
||||
|
||||
Args:
|
||||
workflow_run_id: The workflow run ID
|
||||
order_config: Optional configuration for ordering results
|
||||
order_config.order_by: List of fields to order by (e.g., ["index", "created_at"])
|
||||
order_config.order_direction: Direction to order ("asc" or "desc")
|
||||
|
||||
Returns:
|
||||
A list of WorkflowNodeExecution instances
|
||||
"""
|
||||
...
|
||||
|
||||
def get_running_executions(self, workflow_run_id: str) -> Sequence[WorkflowNodeExecution]:
|
||||
"""
|
||||
Retrieve all running WorkflowNodeExecution instances for a specific workflow run.
|
||||
|
||||
Args:
|
||||
workflow_run_id: The workflow run ID
|
||||
|
||||
Returns:
|
||||
A list of running WorkflowNodeExecution instances
|
||||
"""
|
||||
...
|
||||
|
||||
def update(self, execution: WorkflowNodeExecution) -> None:
|
||||
"""
|
||||
Update an existing WorkflowNodeExecution instance.
|
||||
|
||||
Args:
|
||||
execution: The WorkflowNodeExecution instance to update
|
||||
"""
|
||||
...
|
||||
@ -0,0 +1,18 @@
|
||||
"""
|
||||
Extension for initializing repositories.
|
||||
|
||||
This extension registers repository implementations with the RepositoryFactory.
|
||||
"""
|
||||
|
||||
from dify_app import DifyApp
|
||||
from repositories.repository_registry import register_repositories
|
||||
|
||||
|
||||
def init_app(_app: DifyApp) -> None:
|
||||
"""
|
||||
Initialize repository implementations.
|
||||
|
||||
Args:
|
||||
_app: The Flask application instance (unused)
|
||||
"""
|
||||
register_repositories()
|
||||
@ -0,0 +1,6 @@
|
||||
"""
|
||||
Repository implementations for data access.
|
||||
|
||||
This package contains concrete implementations of the repository interfaces
|
||||
defined in the core.repository package.
|
||||
"""
|
||||
@ -0,0 +1,87 @@
|
||||
"""
|
||||
Registry for repository implementations.
|
||||
|
||||
This module is responsible for registering factory functions with the repository factory.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from configs import dify_config
|
||||
from core.repository.repository_factory import RepositoryFactory
|
||||
from extensions.ext_database import db
|
||||
from repositories.workflow_node_execution import SQLAlchemyWorkflowNodeExecutionRepository
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Storage type constants
|
||||
STORAGE_TYPE_RDBMS = "rdbms"
|
||||
STORAGE_TYPE_HYBRID = "hybrid"
|
||||
|
||||
|
||||
def register_repositories() -> None:
|
||||
"""
|
||||
Register repository factory functions with the RepositoryFactory.
|
||||
|
||||
This function reads configuration settings to determine which repository
|
||||
implementations to register.
|
||||
"""
|
||||
# Configure WorkflowNodeExecutionRepository factory based on configuration
|
||||
workflow_node_execution_storage = dify_config.WORKFLOW_NODE_EXECUTION_STORAGE
|
||||
|
||||
# Check storage type and register appropriate implementation
|
||||
if workflow_node_execution_storage == STORAGE_TYPE_RDBMS:
|
||||
# Register SQLAlchemy implementation for RDBMS storage
|
||||
logger.info("Registering WorkflowNodeExecution repository with RDBMS storage")
|
||||
RepositoryFactory.register_workflow_node_execution_factory(create_workflow_node_execution_repository)
|
||||
elif workflow_node_execution_storage == STORAGE_TYPE_HYBRID:
|
||||
# Hybrid storage is not yet implemented
|
||||
raise NotImplementedError("Hybrid storage for WorkflowNodeExecution repository is not yet implemented")
|
||||
else:
|
||||
# Unknown storage type
|
||||
raise ValueError(
|
||||
f"Unknown storage type '{workflow_node_execution_storage}' for WorkflowNodeExecution repository. "
|
||||
f"Supported types: {STORAGE_TYPE_RDBMS}"
|
||||
)
|
||||
|
||||
|
||||
def create_workflow_node_execution_repository(params: Mapping[str, Any]) -> SQLAlchemyWorkflowNodeExecutionRepository:
|
||||
"""
|
||||
Create a WorkflowNodeExecutionRepository instance using SQLAlchemy implementation.
|
||||
|
||||
This factory function creates a repository for the RDBMS storage type.
|
||||
|
||||
Args:
|
||||
params: Parameters for creating the repository, including:
|
||||
- tenant_id: Required. The tenant ID for multi-tenancy.
|
||||
- app_id: Optional. The application ID for filtering.
|
||||
- session_factory: Optional. A SQLAlchemy sessionmaker instance. If not provided,
|
||||
a new sessionmaker will be created using the global database engine.
|
||||
|
||||
Returns:
|
||||
A WorkflowNodeExecutionRepository instance
|
||||
|
||||
Raises:
|
||||
ValueError: If required parameters are missing
|
||||
"""
|
||||
# Extract required parameters
|
||||
tenant_id = params.get("tenant_id")
|
||||
if tenant_id is None:
|
||||
raise ValueError("tenant_id is required for WorkflowNodeExecution repository with RDBMS storage")
|
||||
|
||||
# Extract optional parameters
|
||||
app_id = params.get("app_id")
|
||||
|
||||
# Use the session_factory from params if provided, otherwise create one using the global db engine
|
||||
session_factory = params.get("session_factory")
|
||||
if session_factory is None:
|
||||
# Create a sessionmaker using the same engine as the global db session
|
||||
session_factory = sessionmaker(bind=db.engine)
|
||||
|
||||
# Create and return the repository
|
||||
return SQLAlchemyWorkflowNodeExecutionRepository(
|
||||
session_factory=session_factory, tenant_id=tenant_id, app_id=app_id
|
||||
)
|
||||
@ -0,0 +1,9 @@
|
||||
"""
|
||||
WorkflowNodeExecution repository implementations.
|
||||
"""
|
||||
|
||||
from repositories.workflow_node_execution.sqlalchemy_repository import SQLAlchemyWorkflowNodeExecutionRepository
|
||||
|
||||
__all__ = [
|
||||
"SQLAlchemyWorkflowNodeExecutionRepository",
|
||||
]
|
||||
@ -0,0 +1,170 @@
|
||||
"""
|
||||
SQLAlchemy implementation of the WorkflowNodeExecutionRepository.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from collections.abc import Sequence
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import UnaryExpression, asc, desc, select
|
||||
from sqlalchemy.engine import Engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from core.repository.workflow_node_execution_repository import OrderConfig
|
||||
from models.workflow import WorkflowNodeExecution, WorkflowNodeExecutionStatus, WorkflowNodeExecutionTriggeredFrom
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SQLAlchemyWorkflowNodeExecutionRepository:
|
||||
"""
|
||||
SQLAlchemy implementation of the WorkflowNodeExecutionRepository interface.
|
||||
|
||||
This implementation supports multi-tenancy by filtering operations based on tenant_id.
|
||||
Each method creates its own session, handles the transaction, and commits changes
|
||||
to the database. This prevents long-running connections in the workflow core.
|
||||
"""
|
||||
|
||||
def __init__(self, session_factory: sessionmaker | Engine, tenant_id: str, app_id: Optional[str] = None):
|
||||
"""
|
||||
Initialize the repository with a SQLAlchemy sessionmaker or engine and tenant context.
|
||||
|
||||
Args:
|
||||
session_factory: SQLAlchemy sessionmaker or engine for creating sessions
|
||||
tenant_id: Tenant ID for multi-tenancy
|
||||
app_id: Optional app ID for filtering by application
|
||||
"""
|
||||
# If an engine is provided, create a sessionmaker from it
|
||||
if isinstance(session_factory, Engine):
|
||||
self._session_factory = sessionmaker(bind=session_factory)
|
||||
else:
|
||||
self._session_factory = session_factory
|
||||
|
||||
self._tenant_id = tenant_id
|
||||
self._app_id = app_id
|
||||
|
||||
def save(self, execution: WorkflowNodeExecution) -> None:
|
||||
"""
|
||||
Save a WorkflowNodeExecution instance and commit changes to the database.
|
||||
|
||||
Args:
|
||||
execution: The WorkflowNodeExecution instance to save
|
||||
"""
|
||||
with self._session_factory() as session:
|
||||
# Ensure tenant_id is set
|
||||
if not execution.tenant_id:
|
||||
execution.tenant_id = self._tenant_id
|
||||
|
||||
# Set app_id if provided and not already set
|
||||
if self._app_id and not execution.app_id:
|
||||
execution.app_id = self._app_id
|
||||
|
||||
session.add(execution)
|
||||
session.commit()
|
||||
|
||||
def get_by_node_execution_id(self, node_execution_id: str) -> Optional[WorkflowNodeExecution]:
|
||||
"""
|
||||
Retrieve a WorkflowNodeExecution by its node_execution_id.
|
||||
|
||||
Args:
|
||||
node_execution_id: The node execution ID
|
||||
|
||||
Returns:
|
||||
The WorkflowNodeExecution instance if found, None otherwise
|
||||
"""
|
||||
with self._session_factory() as session:
|
||||
stmt = select(WorkflowNodeExecution).where(
|
||||
WorkflowNodeExecution.node_execution_id == node_execution_id,
|
||||
WorkflowNodeExecution.tenant_id == self._tenant_id,
|
||||
)
|
||||
|
||||
if self._app_id:
|
||||
stmt = stmt.where(WorkflowNodeExecution.app_id == self._app_id)
|
||||
|
||||
return session.scalar(stmt)
|
||||
|
||||
def get_by_workflow_run(
|
||||
self,
|
||||
workflow_run_id: str,
|
||||
order_config: Optional[OrderConfig] = None,
|
||||
) -> Sequence[WorkflowNodeExecution]:
|
||||
"""
|
||||
Retrieve all WorkflowNodeExecution instances for a specific workflow run.
|
||||
|
||||
Args:
|
||||
workflow_run_id: The workflow run ID
|
||||
order_config: Optional configuration for ordering results
|
||||
order_config.order_by: List of fields to order by (e.g., ["index", "created_at"])
|
||||
order_config.order_direction: Direction to order ("asc" or "desc")
|
||||
|
||||
Returns:
|
||||
A list of WorkflowNodeExecution instances
|
||||
"""
|
||||
with self._session_factory() as session:
|
||||
stmt = select(WorkflowNodeExecution).where(
|
||||
WorkflowNodeExecution.workflow_run_id == workflow_run_id,
|
||||
WorkflowNodeExecution.tenant_id == self._tenant_id,
|
||||
WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
|
||||
if self._app_id:
|
||||
stmt = stmt.where(WorkflowNodeExecution.app_id == self._app_id)
|
||||
|
||||
# Apply ordering if provided
|
||||
if order_config and order_config.order_by:
|
||||
order_columns: list[UnaryExpression] = []
|
||||
for field in order_config.order_by:
|
||||
column = getattr(WorkflowNodeExecution, field, None)
|
||||
if not column:
|
||||
continue
|
||||
if order_config.order_direction == "desc":
|
||||
order_columns.append(desc(column))
|
||||
else:
|
||||
order_columns.append(asc(column))
|
||||
|
||||
if order_columns:
|
||||
stmt = stmt.order_by(*order_columns)
|
||||
|
||||
return session.scalars(stmt).all()
|
||||
|
||||
def get_running_executions(self, workflow_run_id: str) -> Sequence[WorkflowNodeExecution]:
|
||||
"""
|
||||
Retrieve all running WorkflowNodeExecution instances for a specific workflow run.
|
||||
|
||||
Args:
|
||||
workflow_run_id: The workflow run ID
|
||||
|
||||
Returns:
|
||||
A list of running WorkflowNodeExecution instances
|
||||
"""
|
||||
with self._session_factory() as session:
|
||||
stmt = select(WorkflowNodeExecution).where(
|
||||
WorkflowNodeExecution.workflow_run_id == workflow_run_id,
|
||||
WorkflowNodeExecution.tenant_id == self._tenant_id,
|
||||
WorkflowNodeExecution.status == WorkflowNodeExecutionStatus.RUNNING,
|
||||
WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
|
||||
if self._app_id:
|
||||
stmt = stmt.where(WorkflowNodeExecution.app_id == self._app_id)
|
||||
|
||||
return session.scalars(stmt).all()
|
||||
|
||||
def update(self, execution: WorkflowNodeExecution) -> None:
|
||||
"""
|
||||
Update an existing WorkflowNodeExecution instance and commit changes to the database.
|
||||
|
||||
Args:
|
||||
execution: The WorkflowNodeExecution instance to update
|
||||
"""
|
||||
with self._session_factory() as session:
|
||||
# Ensure tenant_id is set
|
||||
if not execution.tenant_id:
|
||||
execution.tenant_id = self._tenant_id
|
||||
|
||||
# Set app_id if provided and not already set
|
||||
if self._app_id and not execution.app_id:
|
||||
execution.app_id = self._app_id
|
||||
|
||||
session.merge(execution)
|
||||
session.commit()
|
||||
@ -0,0 +1,3 @@
|
||||
"""
|
||||
Unit tests for repositories.
|
||||
"""
|
||||
@ -0,0 +1,3 @@
|
||||
"""
|
||||
Unit tests for workflow_node_execution repositories.
|
||||
"""
|
||||
@ -0,0 +1,154 @@
|
||||
"""
|
||||
Unit tests for the SQLAlchemy implementation of WorkflowNodeExecutionRepository.
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from pytest_mock import MockerFixture
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from core.repository.workflow_node_execution_repository import OrderConfig
|
||||
from models.workflow import WorkflowNodeExecution
|
||||
from repositories.workflow_node_execution.sqlalchemy_repository import SQLAlchemyWorkflowNodeExecutionRepository
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def session():
|
||||
"""Create a mock SQLAlchemy session."""
|
||||
session = MagicMock(spec=Session)
|
||||
# Configure the session to be used as a context manager
|
||||
session.__enter__ = MagicMock(return_value=session)
|
||||
session.__exit__ = MagicMock(return_value=None)
|
||||
|
||||
# Configure the session factory to return the session
|
||||
session_factory = MagicMock(spec=sessionmaker)
|
||||
session_factory.return_value = session
|
||||
return session, session_factory
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def repository(session):
|
||||
"""Create a repository instance with test data."""
|
||||
_, session_factory = session
|
||||
tenant_id = "test-tenant"
|
||||
app_id = "test-app"
|
||||
return SQLAlchemyWorkflowNodeExecutionRepository(
|
||||
session_factory=session_factory, tenant_id=tenant_id, app_id=app_id
|
||||
)
|
||||
|
||||
|
||||
def test_save(repository, session):
|
||||
"""Test save method."""
|
||||
session_obj, _ = session
|
||||
# Create a mock execution
|
||||
execution = MagicMock(spec=WorkflowNodeExecution)
|
||||
execution.tenant_id = None
|
||||
execution.app_id = None
|
||||
|
||||
# Call save method
|
||||
repository.save(execution)
|
||||
|
||||
# Assert tenant_id and app_id are set
|
||||
assert execution.tenant_id == repository._tenant_id
|
||||
assert execution.app_id == repository._app_id
|
||||
|
||||
# Assert session.add was called
|
||||
session_obj.add.assert_called_once_with(execution)
|
||||
|
||||
|
||||
def test_save_with_existing_tenant_id(repository, session):
|
||||
"""Test save method with existing tenant_id."""
|
||||
session_obj, _ = session
|
||||
# Create a mock execution with existing tenant_id
|
||||
execution = MagicMock(spec=WorkflowNodeExecution)
|
||||
execution.tenant_id = "existing-tenant"
|
||||
execution.app_id = None
|
||||
|
||||
# Call save method
|
||||
repository.save(execution)
|
||||
|
||||
# Assert tenant_id is not changed and app_id is set
|
||||
assert execution.tenant_id == "existing-tenant"
|
||||
assert execution.app_id == repository._app_id
|
||||
|
||||
# Assert session.add was called
|
||||
session_obj.add.assert_called_once_with(execution)
|
||||
|
||||
|
||||
def test_get_by_node_execution_id(repository, session, mocker: MockerFixture):
|
||||
"""Test get_by_node_execution_id method."""
|
||||
session_obj, _ = session
|
||||
# Set up mock
|
||||
mock_select = mocker.patch("repositories.workflow_node_execution.sqlalchemy_repository.select")
|
||||
mock_stmt = mocker.MagicMock()
|
||||
mock_select.return_value = mock_stmt
|
||||
mock_stmt.where.return_value = mock_stmt
|
||||
session_obj.scalar.return_value = mocker.MagicMock(spec=WorkflowNodeExecution)
|
||||
|
||||
# Call method
|
||||
result = repository.get_by_node_execution_id("test-node-execution-id")
|
||||
|
||||
# Assert select was called with correct parameters
|
||||
mock_select.assert_called_once()
|
||||
session_obj.scalar.assert_called_once_with(mock_stmt)
|
||||
assert result is not None
|
||||
|
||||
|
||||
def test_get_by_workflow_run(repository, session, mocker: MockerFixture):
|
||||
"""Test get_by_workflow_run method."""
|
||||
session_obj, _ = session
|
||||
# Set up mock
|
||||
mock_select = mocker.patch("repositories.workflow_node_execution.sqlalchemy_repository.select")
|
||||
mock_stmt = mocker.MagicMock()
|
||||
mock_select.return_value = mock_stmt
|
||||
mock_stmt.where.return_value = mock_stmt
|
||||
mock_stmt.order_by.return_value = mock_stmt
|
||||
session_obj.scalars.return_value.all.return_value = [mocker.MagicMock(spec=WorkflowNodeExecution)]
|
||||
|
||||
# Call method
|
||||
order_config = OrderConfig(order_by=["index"], order_direction="desc")
|
||||
result = repository.get_by_workflow_run(workflow_run_id="test-workflow-run-id", order_config=order_config)
|
||||
|
||||
# Assert select was called with correct parameters
|
||||
mock_select.assert_called_once()
|
||||
session_obj.scalars.assert_called_once_with(mock_stmt)
|
||||
assert len(result) == 1
|
||||
|
||||
|
||||
def test_get_running_executions(repository, session, mocker: MockerFixture):
|
||||
"""Test get_running_executions method."""
|
||||
session_obj, _ = session
|
||||
# Set up mock
|
||||
mock_select = mocker.patch("repositories.workflow_node_execution.sqlalchemy_repository.select")
|
||||
mock_stmt = mocker.MagicMock()
|
||||
mock_select.return_value = mock_stmt
|
||||
mock_stmt.where.return_value = mock_stmt
|
||||
session_obj.scalars.return_value.all.return_value = [mocker.MagicMock(spec=WorkflowNodeExecution)]
|
||||
|
||||
# Call method
|
||||
result = repository.get_running_executions("test-workflow-run-id")
|
||||
|
||||
# Assert select was called with correct parameters
|
||||
mock_select.assert_called_once()
|
||||
session_obj.scalars.assert_called_once_with(mock_stmt)
|
||||
assert len(result) == 1
|
||||
|
||||
|
||||
def test_update(repository, session):
|
||||
"""Test update method."""
|
||||
session_obj, _ = session
|
||||
# Create a mock execution
|
||||
execution = MagicMock(spec=WorkflowNodeExecution)
|
||||
execution.tenant_id = None
|
||||
execution.app_id = None
|
||||
|
||||
# Call update method
|
||||
repository.update(execution)
|
||||
|
||||
# Assert tenant_id and app_id are set
|
||||
assert execution.tenant_id == repository._tenant_id
|
||||
assert execution.app_id == repository._app_id
|
||||
|
||||
# Assert session.merge was called
|
||||
session_obj.merge.assert_called_once_with(execution)
|
||||
Loading…
Reference in New Issue