feat: Create a DifyAPIRepositoryFactory to handle workflow node execution operations out of core.
Signed-off-by: -LAN- <laipz8200@outlook.com>pull/21458/head
parent
733386bc7d
commit
b2b4049279
@ -0,0 +1,196 @@
|
||||
"""
|
||||
Service-layer repository protocol for WorkflowNodeExecutionModel operations.
|
||||
|
||||
This module provides a protocol interface for service-layer operations on WorkflowNodeExecutionModel
|
||||
that abstracts database queries currently done directly in service classes. This repository is
|
||||
specifically designed for service-layer needs and is separate from the core domain repository.
|
||||
|
||||
The service repository handles operations that require access to database-specific fields like
|
||||
tenant_id, app_id, triggered_from, etc., which are not part of the core domain model.
|
||||
"""
|
||||
|
||||
from abc import abstractmethod
|
||||
from collections.abc import Sequence
|
||||
from datetime import datetime
|
||||
from typing import Optional, Protocol
|
||||
|
||||
from models.workflow import WorkflowNodeExecutionModel
|
||||
|
||||
|
||||
class DifyAPIWorkflowNodeExecutionRepository(Protocol):
|
||||
"""
|
||||
Protocol for service-layer operations on WorkflowNodeExecutionModel.
|
||||
|
||||
This repository provides database access patterns specifically needed by service classes,
|
||||
handling queries that involve database-specific fields and multi-tenancy concerns.
|
||||
|
||||
Key responsibilities:
|
||||
- Manages database operations for workflow node executions
|
||||
- Handles multi-tenant data isolation
|
||||
- Provides batch processing capabilities
|
||||
- Supports execution lifecycle management
|
||||
|
||||
Implementation notes:
|
||||
- Returns database models directly (WorkflowNodeExecutionModel)
|
||||
- Handles tenant/app filtering automatically
|
||||
- Provides service-specific query patterns
|
||||
- Focuses on database operations without domain logic
|
||||
- Supports cleanup and maintenance operations
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_node_last_execution(
|
||||
self,
|
||||
tenant_id: str,
|
||||
app_id: str,
|
||||
workflow_id: str,
|
||||
node_id: str,
|
||||
) -> Optional[WorkflowNodeExecutionModel]:
|
||||
"""
|
||||
Get the most recent execution for a specific node.
|
||||
|
||||
This method finds the latest execution of a specific node within a workflow,
|
||||
ordered by creation time. Used primarily for debugging and inspection purposes.
|
||||
|
||||
Args:
|
||||
tenant_id: The tenant identifier
|
||||
app_id: The application identifier
|
||||
workflow_id: The workflow identifier
|
||||
node_id: The node identifier
|
||||
|
||||
Returns:
|
||||
The most recent WorkflowNodeExecutionModel for the node, or None if not found
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def get_executions_by_workflow_run(
|
||||
self,
|
||||
tenant_id: str,
|
||||
app_id: str,
|
||||
workflow_run_id: str,
|
||||
) -> Sequence[WorkflowNodeExecutionModel]:
|
||||
"""
|
||||
Get all node executions for a specific workflow run.
|
||||
|
||||
This method retrieves all node executions that belong to a specific workflow run,
|
||||
ordered by index in descending order for proper trace visualization.
|
||||
|
||||
Args:
|
||||
tenant_id: The tenant identifier
|
||||
app_id: The application identifier
|
||||
workflow_run_id: The workflow run identifier
|
||||
|
||||
Returns:
|
||||
A sequence of WorkflowNodeExecutionModel instances ordered by index (desc)
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def get_execution_by_id(
|
||||
self,
|
||||
execution_id: str,
|
||||
tenant_id: Optional[str] = None,
|
||||
) -> Optional[WorkflowNodeExecutionModel]:
|
||||
"""
|
||||
Get a workflow node execution by its ID.
|
||||
|
||||
This method retrieves a specific execution by its unique identifier.
|
||||
Tenant filtering is optional for cases where the execution ID is globally unique.
|
||||
|
||||
Args:
|
||||
execution_id: The execution identifier
|
||||
tenant_id: Optional tenant identifier for additional filtering
|
||||
|
||||
Returns:
|
||||
The WorkflowNodeExecutionModel if found, or None if not found
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def delete_expired_executions(
|
||||
self,
|
||||
tenant_id: str,
|
||||
before_date: datetime,
|
||||
batch_size: int = 1000,
|
||||
) -> int:
|
||||
"""
|
||||
Delete workflow node executions that are older than the specified date.
|
||||
|
||||
This method is used for cleanup operations to remove expired executions
|
||||
in batches to avoid overwhelming the database.
|
||||
|
||||
Args:
|
||||
tenant_id: The tenant identifier
|
||||
before_date: Delete executions created before this date
|
||||
batch_size: Maximum number of executions to delete in one batch
|
||||
|
||||
Returns:
|
||||
The number of executions deleted
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def delete_executions_by_app(
|
||||
self,
|
||||
tenant_id: str,
|
||||
app_id: str,
|
||||
batch_size: int = 1000,
|
||||
) -> int:
|
||||
"""
|
||||
Delete all workflow node executions for a specific app.
|
||||
|
||||
This method is used when removing an app and all its related data.
|
||||
Executions are deleted in batches to avoid overwhelming the database.
|
||||
|
||||
Args:
|
||||
tenant_id: The tenant identifier
|
||||
app_id: The application identifier
|
||||
batch_size: Maximum number of executions to delete in one batch
|
||||
|
||||
Returns:
|
||||
The total number of executions deleted
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def get_expired_executions_batch(
|
||||
self,
|
||||
tenant_id: str,
|
||||
before_date: datetime,
|
||||
batch_size: int = 1000,
|
||||
) -> Sequence[WorkflowNodeExecutionModel]:
|
||||
"""
|
||||
Get a batch of expired workflow node executions for backup purposes.
|
||||
|
||||
This method retrieves expired executions without deleting them,
|
||||
allowing the caller to backup the data before deletion.
|
||||
|
||||
Args:
|
||||
tenant_id: The tenant identifier
|
||||
before_date: Get executions created before this date
|
||||
batch_size: Maximum number of executions to retrieve
|
||||
|
||||
Returns:
|
||||
A sequence of WorkflowNodeExecutionModel instances
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def delete_executions_by_ids(
|
||||
self,
|
||||
execution_ids: Sequence[str],
|
||||
) -> int:
|
||||
"""
|
||||
Delete workflow node executions by their IDs.
|
||||
|
||||
This method deletes specific executions by their IDs,
|
||||
typically used after backing up the data.
|
||||
|
||||
Args:
|
||||
execution_ids: List of execution IDs to delete
|
||||
|
||||
Returns:
|
||||
The number of executions deleted
|
||||
"""
|
||||
...
|
||||
@ -0,0 +1,66 @@
|
||||
"""
|
||||
DifyAPI Repository Factory for creating repository instances.
|
||||
|
||||
This factory is specifically designed for DifyAPI repositories that handle
|
||||
service-layer operations with dependency injection patterns.
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from configs import dify_config
|
||||
from core.repositories import DifyCoreRepositoryFactory, RepositoryImportError
|
||||
from repositories.api_workflow_node_execution_repository import DifyAPIWorkflowNodeExecutionRepository
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DifyAPIRepositoryFactory(DifyCoreRepositoryFactory):
|
||||
"""
|
||||
Factory for creating DifyAPI repository instances based on configuration.
|
||||
|
||||
This factory handles the creation of repositories that are specifically designed
|
||||
for service-layer operations and use dependency injection with sessionmaker
|
||||
for better testability and separation of concerns.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def create_api_workflow_node_execution_repository(
|
||||
cls, session_maker: sessionmaker
|
||||
) -> DifyAPIWorkflowNodeExecutionRepository:
|
||||
"""
|
||||
Create a DifyAPIWorkflowNodeExecutionRepository instance based on configuration.
|
||||
|
||||
This repository is designed for service-layer operations and uses dependency injection
|
||||
with a sessionmaker for better testability and separation of concerns. It provides
|
||||
database access patterns specifically needed by service classes, handling queries
|
||||
that involve database-specific fields and multi-tenancy concerns.
|
||||
|
||||
Args:
|
||||
session_maker: SQLAlchemy sessionmaker to inject for database session management.
|
||||
|
||||
Returns:
|
||||
Configured DifyAPIWorkflowNodeExecutionRepository instance
|
||||
|
||||
Raises:
|
||||
RepositoryImportError: If the configured repository cannot be imported or instantiated
|
||||
"""
|
||||
class_path = dify_config.API_WORKFLOW_NODE_EXECUTION_REPOSITORY
|
||||
logger.debug(f"Creating DifyAPIWorkflowNodeExecutionRepository from: {class_path}")
|
||||
|
||||
try:
|
||||
repository_class = cls._import_class(class_path)
|
||||
cls._validate_repository_interface(repository_class, DifyAPIWorkflowNodeExecutionRepository)
|
||||
# Service repository requires session_maker parameter
|
||||
cls._validate_constructor_signature(repository_class, ["session_maker"])
|
||||
|
||||
return repository_class(session_maker=session_maker) # type: ignore[no-any-return]
|
||||
except RepositoryImportError:
|
||||
# Re-raise our custom errors as-is
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception("Failed to create DifyAPIWorkflowNodeExecutionRepository")
|
||||
raise RepositoryImportError(
|
||||
f"Failed to create DifyAPIWorkflowNodeExecutionRepository from '{class_path}': {e}"
|
||||
) from e
|
||||
@ -0,0 +1,286 @@
|
||||
"""
|
||||
SQLAlchemy implementation of WorkflowNodeExecutionServiceRepository.
|
||||
|
||||
This module provides a concrete implementation of the service repository protocol
|
||||
using SQLAlchemy 2.0 style queries for WorkflowNodeExecutionModel operations.
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import delete, desc, select
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from models.workflow import WorkflowNodeExecutionModel
|
||||
from repositories.api_workflow_node_execution_repository import DifyAPIWorkflowNodeExecutionRepository
|
||||
|
||||
|
||||
class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRepository):
|
||||
"""
|
||||
SQLAlchemy implementation of DifyAPIWorkflowNodeExecutionRepository.
|
||||
|
||||
This repository provides service-layer database operations for WorkflowNodeExecutionModel
|
||||
using SQLAlchemy 2.0 style queries. It implements the DifyAPIWorkflowNodeExecutionRepository
|
||||
protocol with the following features:
|
||||
|
||||
- Multi-tenancy data isolation through tenant_id filtering
|
||||
- Direct database model operations without domain conversion
|
||||
- Batch processing for efficient large-scale operations
|
||||
- Optimized query patterns for common access patterns
|
||||
- Dependency injection for better testability and maintainability
|
||||
- Session management and transaction handling with proper cleanup
|
||||
- Maintenance operations for data lifecycle management
|
||||
- Thread-safe database operations using session-per-request pattern
|
||||
"""
|
||||
|
||||
def __init__(self, session_maker: sessionmaker[Session]):
|
||||
"""
|
||||
Initialize the repository with a sessionmaker.
|
||||
|
||||
Args:
|
||||
session_maker: SQLAlchemy sessionmaker for creating database sessions
|
||||
"""
|
||||
self._session_maker = session_maker
|
||||
|
||||
def get_node_last_execution(
|
||||
self,
|
||||
tenant_id: str,
|
||||
app_id: str,
|
||||
workflow_id: str,
|
||||
node_id: str,
|
||||
) -> Optional[WorkflowNodeExecutionModel]:
|
||||
"""
|
||||
Get the most recent execution for a specific node.
|
||||
|
||||
This method replicates the query pattern from WorkflowService.get_node_last_run()
|
||||
using SQLAlchemy 2.0 style syntax.
|
||||
|
||||
Args:
|
||||
tenant_id: The tenant identifier
|
||||
app_id: The application identifier
|
||||
workflow_id: The workflow identifier
|
||||
node_id: The node identifier
|
||||
|
||||
Returns:
|
||||
The most recent WorkflowNodeExecutionModel for the node, or None if not found
|
||||
"""
|
||||
stmt = (
|
||||
select(WorkflowNodeExecutionModel)
|
||||
.where(
|
||||
WorkflowNodeExecutionModel.tenant_id == tenant_id,
|
||||
WorkflowNodeExecutionModel.app_id == app_id,
|
||||
WorkflowNodeExecutionModel.workflow_id == workflow_id,
|
||||
WorkflowNodeExecutionModel.node_id == node_id,
|
||||
)
|
||||
.order_by(desc(WorkflowNodeExecutionModel.created_at))
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
with self._session_maker() as session:
|
||||
return session.scalar(stmt)
|
||||
|
||||
def get_executions_by_workflow_run(
|
||||
self,
|
||||
tenant_id: str,
|
||||
app_id: str,
|
||||
workflow_run_id: str,
|
||||
) -> Sequence[WorkflowNodeExecutionModel]:
|
||||
"""
|
||||
Get all node executions for a specific workflow run.
|
||||
|
||||
This method replicates the query pattern from WorkflowRunService.get_workflow_run_node_executions()
|
||||
using SQLAlchemy 2.0 style syntax.
|
||||
|
||||
Args:
|
||||
tenant_id: The tenant identifier
|
||||
app_id: The application identifier
|
||||
workflow_run_id: The workflow run identifier
|
||||
|
||||
Returns:
|
||||
A sequence of WorkflowNodeExecutionModel instances ordered by index (desc)
|
||||
"""
|
||||
stmt = (
|
||||
select(WorkflowNodeExecutionModel)
|
||||
.where(
|
||||
WorkflowNodeExecutionModel.tenant_id == tenant_id,
|
||||
WorkflowNodeExecutionModel.app_id == app_id,
|
||||
WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id,
|
||||
)
|
||||
.order_by(desc(WorkflowNodeExecutionModel.index))
|
||||
)
|
||||
|
||||
with self._session_maker() as session:
|
||||
return session.execute(stmt).scalars().all()
|
||||
|
||||
def get_execution_by_id(
|
||||
self,
|
||||
execution_id: str,
|
||||
tenant_id: Optional[str] = None,
|
||||
) -> Optional[WorkflowNodeExecutionModel]:
|
||||
"""
|
||||
Get a workflow node execution by its ID.
|
||||
|
||||
This method replicates the query pattern from WorkflowDraftVariableService
|
||||
and WorkflowService.single_step_run_workflow_node() using SQLAlchemy 2.0 style syntax.
|
||||
|
||||
Args:
|
||||
execution_id: The execution identifier
|
||||
tenant_id: Optional tenant identifier for additional filtering
|
||||
|
||||
Returns:
|
||||
The WorkflowNodeExecutionModel if found, or None if not found
|
||||
"""
|
||||
stmt = select(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.id == execution_id)
|
||||
|
||||
# Add tenant filtering if provided
|
||||
if tenant_id is not None:
|
||||
stmt = stmt.where(WorkflowNodeExecutionModel.tenant_id == tenant_id)
|
||||
|
||||
with self._session_maker() as session:
|
||||
return session.scalar(stmt)
|
||||
|
||||
def delete_expired_executions(
|
||||
self,
|
||||
tenant_id: str,
|
||||
before_date: datetime,
|
||||
batch_size: int = 1000,
|
||||
) -> int:
|
||||
"""
|
||||
Delete workflow node executions that are older than the specified date.
|
||||
|
||||
Args:
|
||||
tenant_id: The tenant identifier
|
||||
before_date: Delete executions created before this date
|
||||
batch_size: Maximum number of executions to delete in one batch
|
||||
|
||||
Returns:
|
||||
The number of executions deleted
|
||||
"""
|
||||
total_deleted = 0
|
||||
|
||||
while True:
|
||||
with self._session_maker() as session:
|
||||
# Find executions to delete in batches
|
||||
stmt = (
|
||||
select(WorkflowNodeExecutionModel.id)
|
||||
.where(
|
||||
WorkflowNodeExecutionModel.tenant_id == tenant_id,
|
||||
WorkflowNodeExecutionModel.created_at < before_date,
|
||||
)
|
||||
.limit(batch_size)
|
||||
)
|
||||
|
||||
execution_ids = session.execute(stmt).scalars().all()
|
||||
if not execution_ids:
|
||||
break
|
||||
|
||||
# Delete the batch
|
||||
delete_stmt = delete(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.id.in_(execution_ids))
|
||||
result = session.execute(delete_stmt)
|
||||
session.commit()
|
||||
total_deleted += result.rowcount
|
||||
|
||||
# If we deleted fewer than the batch size, we're done
|
||||
if len(execution_ids) < batch_size:
|
||||
break
|
||||
|
||||
return total_deleted
|
||||
|
||||
def delete_executions_by_app(
|
||||
self,
|
||||
tenant_id: str,
|
||||
app_id: str,
|
||||
batch_size: int = 1000,
|
||||
) -> int:
|
||||
"""
|
||||
Delete all workflow node executions for a specific app.
|
||||
|
||||
Args:
|
||||
tenant_id: The tenant identifier
|
||||
app_id: The application identifier
|
||||
batch_size: Maximum number of executions to delete in one batch
|
||||
|
||||
Returns:
|
||||
The total number of executions deleted
|
||||
"""
|
||||
total_deleted = 0
|
||||
|
||||
while True:
|
||||
with self._session_maker() as session:
|
||||
# Find executions to delete in batches
|
||||
stmt = (
|
||||
select(WorkflowNodeExecutionModel.id)
|
||||
.where(
|
||||
WorkflowNodeExecutionModel.tenant_id == tenant_id,
|
||||
WorkflowNodeExecutionModel.app_id == app_id,
|
||||
)
|
||||
.limit(batch_size)
|
||||
)
|
||||
|
||||
execution_ids = session.execute(stmt).scalars().all()
|
||||
if not execution_ids:
|
||||
break
|
||||
|
||||
# Delete the batch
|
||||
delete_stmt = delete(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.id.in_(execution_ids))
|
||||
result = session.execute(delete_stmt)
|
||||
session.commit()
|
||||
total_deleted += result.rowcount
|
||||
|
||||
# If we deleted fewer than the batch size, we're done
|
||||
if len(execution_ids) < batch_size:
|
||||
break
|
||||
|
||||
return total_deleted
|
||||
|
||||
def get_expired_executions_batch(
|
||||
self,
|
||||
tenant_id: str,
|
||||
before_date: datetime,
|
||||
batch_size: int = 1000,
|
||||
) -> Sequence[WorkflowNodeExecutionModel]:
|
||||
"""
|
||||
Get a batch of expired workflow node executions for backup purposes.
|
||||
|
||||
Args:
|
||||
tenant_id: The tenant identifier
|
||||
before_date: Get executions created before this date
|
||||
batch_size: Maximum number of executions to retrieve
|
||||
|
||||
Returns:
|
||||
A sequence of WorkflowNodeExecutionModel instances
|
||||
"""
|
||||
stmt = (
|
||||
select(WorkflowNodeExecutionModel)
|
||||
.where(
|
||||
WorkflowNodeExecutionModel.tenant_id == tenant_id,
|
||||
WorkflowNodeExecutionModel.created_at < before_date,
|
||||
)
|
||||
.limit(batch_size)
|
||||
)
|
||||
|
||||
with self._session_maker() as session:
|
||||
return session.execute(stmt).scalars().all()
|
||||
|
||||
def delete_executions_by_ids(
|
||||
self,
|
||||
execution_ids: Sequence[str],
|
||||
) -> int:
|
||||
"""
|
||||
Delete workflow node executions by their IDs.
|
||||
|
||||
Args:
|
||||
execution_ids: List of execution IDs to delete
|
||||
|
||||
Returns:
|
||||
The number of executions deleted
|
||||
"""
|
||||
if not execution_ids:
|
||||
return 0
|
||||
|
||||
with self._session_maker() as session:
|
||||
stmt = delete(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.id.in_(execution_ids))
|
||||
result = session.execute(stmt)
|
||||
session.commit()
|
||||
return result.rowcount
|
||||
@ -0,0 +1,278 @@
|
||||
from datetime import datetime
|
||||
from unittest.mock import MagicMock
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from models.workflow import WorkflowNodeExecutionModel
|
||||
from repositories.sqlalchemy_api_workflow_node_execution_repository import (
|
||||
DifyAPISQLAlchemyWorkflowNodeExecutionRepository,
|
||||
)
|
||||
|
||||
|
||||
class TestSQLAlchemyWorkflowNodeExecutionServiceRepository:
|
||||
@pytest.fixture
|
||||
def repository(self):
|
||||
mock_session_maker = MagicMock()
|
||||
return DifyAPISQLAlchemyWorkflowNodeExecutionRepository(session_maker=mock_session_maker)
|
||||
|
||||
@pytest.fixture
|
||||
def mock_execution(self):
|
||||
execution = MagicMock(spec=WorkflowNodeExecutionModel)
|
||||
execution.id = str(uuid4())
|
||||
execution.tenant_id = "tenant-123"
|
||||
execution.app_id = "app-456"
|
||||
execution.workflow_id = "workflow-789"
|
||||
execution.workflow_run_id = "run-101"
|
||||
execution.node_id = "node-202"
|
||||
execution.index = 1
|
||||
execution.created_at = "2023-01-01T00:00:00Z"
|
||||
return execution
|
||||
|
||||
def test_get_node_last_execution_found(self, repository, mock_execution):
|
||||
"""Test getting the last execution for a node when it exists."""
|
||||
# Arrange
|
||||
mock_session = MagicMock(spec=Session)
|
||||
repository._session_maker.return_value.__enter__.return_value = mock_session
|
||||
mock_session.scalar.return_value = mock_execution
|
||||
|
||||
# Act
|
||||
result = repository.get_node_last_execution(
|
||||
tenant_id="tenant-123",
|
||||
app_id="app-456",
|
||||
workflow_id="workflow-789",
|
||||
node_id="node-202",
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result == mock_execution
|
||||
mock_session.scalar.assert_called_once()
|
||||
# Verify the query was constructed correctly
|
||||
call_args = mock_session.scalar.call_args[0][0]
|
||||
assert hasattr(call_args, "compile") # It's a SQLAlchemy statement
|
||||
|
||||
def test_get_node_last_execution_not_found(self, repository):
|
||||
"""Test getting the last execution for a node when it doesn't exist."""
|
||||
# Arrange
|
||||
mock_session = MagicMock(spec=Session)
|
||||
repository._session_maker.return_value.__enter__.return_value = mock_session
|
||||
mock_session.scalar.return_value = None
|
||||
|
||||
# Act
|
||||
result = repository.get_node_last_execution(
|
||||
tenant_id="tenant-123",
|
||||
app_id="app-456",
|
||||
workflow_id="workflow-789",
|
||||
node_id="node-202",
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
mock_session.scalar.assert_called_once()
|
||||
|
||||
def test_get_executions_by_workflow_run(self, repository, mock_execution):
|
||||
"""Test getting all executions for a workflow run."""
|
||||
# Arrange
|
||||
mock_session = MagicMock(spec=Session)
|
||||
repository._session_maker.return_value.__enter__.return_value = mock_session
|
||||
executions = [mock_execution]
|
||||
mock_session.execute.return_value.scalars.return_value.all.return_value = executions
|
||||
|
||||
# Act
|
||||
result = repository.get_executions_by_workflow_run(
|
||||
tenant_id="tenant-123",
|
||||
app_id="app-456",
|
||||
workflow_run_id="run-101",
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result == executions
|
||||
mock_session.execute.assert_called_once()
|
||||
# Verify the query was constructed correctly
|
||||
call_args = mock_session.execute.call_args[0][0]
|
||||
assert hasattr(call_args, "compile") # It's a SQLAlchemy statement
|
||||
|
||||
def test_get_executions_by_workflow_run_empty(self, repository):
|
||||
"""Test getting executions for a workflow run when none exist."""
|
||||
# Arrange
|
||||
mock_session = MagicMock(spec=Session)
|
||||
repository._session_maker.return_value.__enter__.return_value = mock_session
|
||||
mock_session.execute.return_value.scalars.return_value.all.return_value = []
|
||||
|
||||
# Act
|
||||
result = repository.get_executions_by_workflow_run(
|
||||
tenant_id="tenant-123",
|
||||
app_id="app-456",
|
||||
workflow_run_id="run-101",
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result == []
|
||||
mock_session.execute.assert_called_once()
|
||||
|
||||
def test_get_execution_by_id_found(self, repository, mock_execution):
|
||||
"""Test getting execution by ID when it exists."""
|
||||
# Arrange
|
||||
mock_session = MagicMock(spec=Session)
|
||||
repository._session_maker.return_value.__enter__.return_value = mock_session
|
||||
mock_session.scalar.return_value = mock_execution
|
||||
|
||||
# Act
|
||||
result = repository.get_execution_by_id(mock_execution.id)
|
||||
|
||||
# Assert
|
||||
assert result == mock_execution
|
||||
mock_session.scalar.assert_called_once()
|
||||
|
||||
def test_get_execution_by_id_not_found(self, repository):
|
||||
"""Test getting execution by ID when it doesn't exist."""
|
||||
# Arrange
|
||||
mock_session = MagicMock(spec=Session)
|
||||
repository._session_maker.return_value.__enter__.return_value = mock_session
|
||||
mock_session.scalar.return_value = None
|
||||
|
||||
# Act
|
||||
result = repository.get_execution_by_id("non-existent-id")
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
mock_session.scalar.assert_called_once()
|
||||
|
||||
def test_repository_implements_protocol(self, repository):
|
||||
"""Test that the repository implements the required protocol methods."""
|
||||
# Verify all protocol methods are implemented
|
||||
assert hasattr(repository, "get_node_last_execution")
|
||||
assert hasattr(repository, "get_executions_by_workflow_run")
|
||||
assert hasattr(repository, "get_execution_by_id")
|
||||
|
||||
# Verify methods are callable
|
||||
assert callable(repository.get_node_last_execution)
|
||||
assert callable(repository.get_executions_by_workflow_run)
|
||||
assert callable(repository.get_execution_by_id)
|
||||
assert callable(repository.delete_expired_executions)
|
||||
assert callable(repository.delete_executions_by_app)
|
||||
assert callable(repository.get_expired_executions_batch)
|
||||
assert callable(repository.delete_executions_by_ids)
|
||||
|
||||
def test_delete_expired_executions(self, repository):
|
||||
"""Test deleting expired executions."""
|
||||
# Arrange
|
||||
mock_session = MagicMock(spec=Session)
|
||||
repository._session_maker.return_value.__enter__.return_value = mock_session
|
||||
|
||||
# Mock the select query to return some IDs first time, then empty to stop loop
|
||||
execution_ids = ["id1", "id2"] # Less than batch_size to trigger break
|
||||
mock_session.execute.return_value.scalars.return_value.all.return_value = execution_ids
|
||||
|
||||
# Mock the delete query
|
||||
mock_query = MagicMock()
|
||||
mock_session.query.return_value = mock_query
|
||||
mock_query.filter.return_value.delete.return_value = 2
|
||||
|
||||
before_date = datetime(2023, 1, 1)
|
||||
|
||||
# Act
|
||||
result = repository.delete_expired_executions(
|
||||
tenant_id="tenant-123",
|
||||
before_date=before_date,
|
||||
batch_size=1000,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result == 2
|
||||
mock_session.execute.assert_called_once() # One select call
|
||||
mock_session.query.assert_called_once()
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
def test_delete_executions_by_app(self, repository):
|
||||
"""Test deleting executions by app."""
|
||||
# Arrange
|
||||
mock_session = MagicMock(spec=Session)
|
||||
repository._session_maker.return_value.__enter__.return_value = mock_session
|
||||
|
||||
# Mock the select query to return some IDs first time, then empty to stop loop
|
||||
execution_ids = ["id1", "id2"]
|
||||
mock_session.execute.return_value.scalars.return_value.all.return_value = execution_ids
|
||||
|
||||
# Mock the delete query
|
||||
mock_query = MagicMock()
|
||||
mock_session.query.return_value = mock_query
|
||||
mock_query.filter.return_value.delete.return_value = 2
|
||||
|
||||
# Act
|
||||
result = repository.delete_executions_by_app(
|
||||
tenant_id="tenant-123",
|
||||
app_id="app-456",
|
||||
batch_size=1000,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result == 2
|
||||
mock_session.execute.assert_called_once() # One select call
|
||||
mock_session.query.assert_called_once()
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
def test_get_expired_executions_batch(self, repository):
|
||||
"""Test getting expired executions batch for backup."""
|
||||
# Arrange
|
||||
mock_session = MagicMock(spec=Session)
|
||||
repository._session_maker.return_value.__enter__.return_value = mock_session
|
||||
|
||||
# Create mock execution objects
|
||||
mock_execution1 = MagicMock()
|
||||
mock_execution1.id = "exec-1"
|
||||
mock_execution2 = MagicMock()
|
||||
mock_execution2.id = "exec-2"
|
||||
|
||||
mock_session.execute.return_value.scalars.return_value.all.return_value = [mock_execution1, mock_execution2]
|
||||
|
||||
before_date = datetime(2023, 1, 1)
|
||||
|
||||
# Act
|
||||
result = repository.get_expired_executions_batch(
|
||||
tenant_id="tenant-123",
|
||||
before_date=before_date,
|
||||
batch_size=1000,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert len(result) == 2
|
||||
assert result[0].id == "exec-1"
|
||||
assert result[1].id == "exec-2"
|
||||
mock_session.execute.assert_called_once()
|
||||
|
||||
def test_delete_executions_by_ids(self, repository):
|
||||
"""Test deleting executions by IDs."""
|
||||
# Arrange
|
||||
mock_session = MagicMock(spec=Session)
|
||||
repository._session_maker.return_value.__enter__.return_value = mock_session
|
||||
|
||||
# Mock the delete query
|
||||
mock_query = MagicMock()
|
||||
mock_session.query.return_value = mock_query
|
||||
mock_query.filter.return_value.delete.return_value = 3
|
||||
|
||||
execution_ids = ["id1", "id2", "id3"]
|
||||
|
||||
# Act
|
||||
result = repository.delete_executions_by_ids(execution_ids)
|
||||
|
||||
# Assert
|
||||
assert result == 3
|
||||
mock_session.query.assert_called_once()
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
def test_delete_executions_by_ids_empty_list(self, repository):
|
||||
"""Test deleting executions with empty ID list."""
|
||||
# Arrange
|
||||
mock_session = MagicMock(spec=Session)
|
||||
repository._session_maker.return_value.__enter__.return_value = mock_session
|
||||
|
||||
# Act
|
||||
result = repository.delete_executions_by_ids([])
|
||||
|
||||
# Assert
|
||||
assert result == 0
|
||||
mock_session.query.assert_not_called()
|
||||
mock_session.commit.assert_not_called()
|
||||
Loading…
Reference in New Issue