feat: Create a DifyAPIRepositoryFactory to handle workflow node execution operations out of core.

Signed-off-by: -LAN- <laipz8200@outlook.com>
pull/21458/head
-LAN- 11 months ago
parent 733386bc7d
commit b2b4049279
No known key found for this signature in database
GPG Key ID: 6BA0D108DED011FF

@ -542,16 +542,22 @@ class RepositoryConfig(BaseSettings):
Configuration for repository implementations
"""
WORKFLOW_EXECUTION_REPOSITORY: str = Field(
CORE_WORKFLOW_EXECUTION_REPOSITORY: str = Field(
description="Repository implementation for WorkflowExecution. Specify as a module path",
default="core.repositories.sqlalchemy_workflow_execution_repository.SQLAlchemyWorkflowExecutionRepository",
)
WORKFLOW_NODE_EXECUTION_REPOSITORY: str = Field(
CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY: str = Field(
description="Repository implementation for WorkflowNodeExecution. Specify as a module path",
default="core.repositories.sqlalchemy_workflow_node_execution_repository.SQLAlchemyWorkflowNodeExecutionRepository",
)
API_WORKFLOW_NODE_EXECUTION_REPOSITORY: str = Field(
description="Service-layer repository implementation for WorkflowNodeExecutionModel operations. "
"Specify as a module path",
default="repositories.sqlalchemy_api_workflow_node_execution_repository.DifyAPISQLAlchemyWorkflowNodeExecutionRepository",
)
class AuthConfig(BaseSettings):
"""

@ -25,7 +25,7 @@ from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotA
from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.ops.ops_trace_manager import TraceQueueManager
from core.prompt.utils.get_thread_messages_length import get_thread_messages_length
from core.repositories import RepositoryFactory
from core.repositories import DifyCoreRepositoryFactory
from core.workflow.repositories.draft_variable_repository import (
DraftVariableSaverFactory,
)
@ -182,14 +182,14 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
workflow_triggered_from = WorkflowRunTriggeredFrom.DEBUGGING
else:
workflow_triggered_from = WorkflowRunTriggeredFrom.APP_RUN
workflow_execution_repository = RepositoryFactory.create_workflow_execution_repository(
workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,
triggered_from=workflow_triggered_from,
)
# Create workflow node execution repository
workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository(
workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,
@ -259,14 +259,14 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
# Create session factory
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
# Create workflow execution(aka workflow run) repository
workflow_execution_repository = RepositoryFactory.create_workflow_execution_repository(
workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
)
# Create workflow node execution repository
workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository(
workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,
@ -342,14 +342,14 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
# Create session factory
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
# Create workflow execution(aka workflow run) repository
workflow_execution_repository = RepositoryFactory.create_workflow_execution_repository(
workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
)
# Create workflow node execution repository
workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository(
workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,

@ -23,7 +23,7 @@ from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerat
from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse
from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.ops.ops_trace_manager import TraceQueueManager
from core.repositories import RepositoryFactory
from core.repositories import DifyCoreRepositoryFactory
from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
@ -155,14 +155,14 @@ class WorkflowAppGenerator(BaseAppGenerator):
workflow_triggered_from = WorkflowRunTriggeredFrom.DEBUGGING
else:
workflow_triggered_from = WorkflowRunTriggeredFrom.APP_RUN
workflow_execution_repository = RepositoryFactory.create_workflow_execution_repository(
workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,
triggered_from=workflow_triggered_from,
)
# Create workflow node execution repository
workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository(
workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,
@ -305,14 +305,14 @@ class WorkflowAppGenerator(BaseAppGenerator):
# Create session factory
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
# Create workflow execution(aka workflow run) repository
workflow_execution_repository = RepositoryFactory.create_workflow_execution_repository(
workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
)
# Create workflow node execution repository
workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository(
workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,
@ -387,14 +387,14 @@ class WorkflowAppGenerator(BaseAppGenerator):
# Create session factory
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
# Create workflow execution(aka workflow run) repository
workflow_execution_repository = RepositoryFactory.create_workflow_execution_repository(
workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
)
# Create workflow node execution repository
workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository(
workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,

@ -28,7 +28,7 @@ from core.ops.langfuse_trace.entities.langfuse_trace_entity import (
UnitEnum,
)
from core.ops.utils import filter_none_values
from core.repositories import RepositoryFactory
from core.repositories import DifyCoreRepositoryFactory
from core.workflow.nodes.enums import NodeType
from extensions.ext_database import db
from models import EndUser, WorkflowNodeExecutionTriggeredFrom
@ -123,7 +123,7 @@ class LangFuseDataTrace(BaseTraceInstance):
service_account = self.get_service_account_with_tenant(app_id)
workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository(
workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
session_factory=session_factory,
user=service_account,
app_id=app_id,

@ -27,7 +27,7 @@ from core.ops.langsmith_trace.entities.langsmith_trace_entity import (
LangSmithRunUpdateModel,
)
from core.ops.utils import filter_none_values, generate_dotted_order
from core.repositories import RepositoryFactory
from core.repositories import DifyCoreRepositoryFactory
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
from core.workflow.nodes.enums import NodeType
from extensions.ext_database import db
@ -145,7 +145,7 @@ class LangSmithDataTrace(BaseTraceInstance):
service_account = self.get_service_account_with_tenant(app_id)
workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository(
workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
session_factory=session_factory,
user=service_account,
app_id=app_id,

@ -21,7 +21,7 @@ from core.ops.entities.trace_entity import (
TraceTaskName,
WorkflowTraceInfo,
)
from core.repositories import RepositoryFactory
from core.repositories import DifyCoreRepositoryFactory
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
from core.workflow.nodes.enums import NodeType
from extensions.ext_database import db
@ -160,7 +160,7 @@ class OpikDataTrace(BaseTraceInstance):
service_account = self.get_service_account_with_tenant(app_id)
workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository(
workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
session_factory=session_factory,
user=service_account,
app_id=app_id,

@ -22,7 +22,7 @@ from core.ops.entities.trace_entity import (
WorkflowTraceInfo,
)
from core.ops.weave_trace.entities.weave_trace_entity import WeaveTraceModel
from core.repositories import RepositoryFactory
from core.repositories import DifyCoreRepositoryFactory
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
from core.workflow.nodes.enums import NodeType
from extensions.ext_database import db
@ -144,7 +144,7 @@ class WeaveDataTrace(BaseTraceInstance):
service_account = self.get_service_account_with_tenant(app_id)
workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository(
workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
session_factory=session_factory,
user=service_account,
app_id=app_id,

@ -5,11 +5,11 @@ This package contains concrete implementations of the repository interfaces
defined in the core.workflow.repository package.
"""
from core.repositories.factory import RepositoryFactory, RepositoryImportError
from core.repositories.factory import DifyCoreRepositoryFactory, RepositoryImportError
from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository
__all__ = [
"RepositoryFactory",
"DifyCoreRepositoryFactory",
"RepositoryImportError",
"SQLAlchemyWorkflowNodeExecutionRepository",
]

@ -28,7 +28,7 @@ class RepositoryImportError(Exception):
pass
class RepositoryFactory:
class DifyCoreRepositoryFactory:
"""
Factory for creating repository instances based on configuration.
@ -143,7 +143,7 @@ class RepositoryFactory:
Raises:
RepositoryImportError: If the configured repository cannot be created
"""
class_path = dify_config.WORKFLOW_EXECUTION_REPOSITORY
class_path = dify_config.CORE_WORKFLOW_EXECUTION_REPOSITORY
logger.debug(f"Creating WorkflowExecutionRepository from: {class_path}")
try:
@ -189,7 +189,7 @@ class RepositoryFactory:
Raises:
RepositoryImportError: If the configured repository cannot be created
"""
class_path = dify_config.WORKFLOW_NODE_EXECUTION_REPOSITORY
class_path = dify_config.CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY
logger.debug(f"Creating WorkflowNodeExecutionRepository from: {class_path}")
try:

@ -0,0 +1,196 @@
"""
Service-layer repository protocol for WorkflowNodeExecutionModel operations.
This module provides a protocol interface for service-layer operations on WorkflowNodeExecutionModel
that abstracts database queries currently done directly in service classes. This repository is
specifically designed for service-layer needs and is separate from the core domain repository.
The service repository handles operations that require access to database-specific fields like
tenant_id, app_id, triggered_from, etc., which are not part of the core domain model.
"""
from abc import abstractmethod
from collections.abc import Sequence
from datetime import datetime
from typing import Optional, Protocol
from models.workflow import WorkflowNodeExecutionModel
class DifyAPIWorkflowNodeExecutionRepository(Protocol):
    """
    Service-layer contract for reading and deleting WorkflowNodeExecutionModel rows.

    The protocol describes the queries service classes need: lookups scoped by
    tenant, app, workflow, workflow run or node, plus batch-oriented cleanup.

    Conventions shared by all methods:
        - Results are WorkflowNodeExecutionModel database models, not domain entities.
        - Scoping identifiers (tenant_id, app_id, ...) are passed explicitly by the caller.
        - Delete operations report how many rows they removed.
    """

    @abstractmethod
    def get_node_last_execution(
        self,
        tenant_id: str,
        app_id: str,
        workflow_id: str,
        node_id: str,
    ) -> Optional[WorkflowNodeExecutionModel]:
        """
        Return the newest execution of one node inside a workflow.

        "Newest" means the execution with the latest creation time.

        Args:
            tenant_id: Tenant that owns the execution.
            app_id: Application the workflow belongs to.
            workflow_id: Workflow that contains the node.
            node_id: Node whose latest execution is wanted.

        Returns:
            The latest WorkflowNodeExecutionModel for the node, or None when
            the node has no recorded execution.
        """
        ...

    @abstractmethod
    def get_executions_by_workflow_run(
        self,
        tenant_id: str,
        app_id: str,
        workflow_run_id: str,
    ) -> Sequence[WorkflowNodeExecutionModel]:
        """
        Return every node execution recorded for a single workflow run.

        Args:
            tenant_id: Tenant that owns the run.
            app_id: Application the run belongs to.
            workflow_run_id: Workflow run whose node executions are wanted.

        Returns:
            The run's WorkflowNodeExecutionModel instances, sorted by index
            in descending order.
        """
        ...

    @abstractmethod
    def get_execution_by_id(
        self,
        execution_id: str,
        tenant_id: Optional[str] = None,
    ) -> Optional[WorkflowNodeExecutionModel]:
        """
        Look up one node execution by its primary identifier.

        Args:
            execution_id: Identifier of the execution.
            tenant_id: When given, the execution must also belong to this
                tenant; when omitted, only the identifier is matched.

        Returns:
            The matching WorkflowNodeExecutionModel, or None if there is none.
        """
        ...

    @abstractmethod
    def delete_expired_executions(
        self,
        tenant_id: str,
        before_date: datetime,
        batch_size: int = 1000,
    ) -> int:
        """
        Remove a tenant's node executions created before a cut-off date.

        Rows are removed in batches so that a large cleanup does not hit the
        database with one huge statement.

        Args:
            tenant_id: Tenant whose executions are cleaned up.
            before_date: Executions created strictly before this moment are removed.
            batch_size: Upper bound on the rows removed per batch.

        Returns:
            How many executions were deleted.
        """
        ...

    @abstractmethod
    def delete_executions_by_app(
        self,
        tenant_id: str,
        app_id: str,
        batch_size: int = 1000,
    ) -> int:
        """
        Remove every node execution that belongs to one app.

        Intended for app removal, where all related data goes away. Rows are
        removed in batches.

        Args:
            tenant_id: Tenant that owns the app.
            app_id: Application whose executions are removed.
            batch_size: Upper bound on the rows removed per batch.

        Returns:
            How many executions were deleted in total.
        """
        ...

    @abstractmethod
    def get_expired_executions_batch(
        self,
        tenant_id: str,
        before_date: datetime,
        batch_size: int = 1000,
    ) -> Sequence[WorkflowNodeExecutionModel]:
        """
        Fetch up to one batch of expired node executions without deleting them.

        This lets a caller back the rows up first and delete them afterwards.

        Args:
            tenant_id: Tenant whose executions are fetched.
            before_date: Only executions created before this moment qualify.
            batch_size: Upper bound on the rows returned.

        Returns:
            The fetched WorkflowNodeExecutionModel instances.
        """
        ...

    @abstractmethod
    def delete_executions_by_ids(self, execution_ids: Sequence[str]) -> int:
        """
        Remove the node executions with the given identifiers.

        Typically called once the rows have been backed up.

        Args:
            execution_ids: Identifiers of the executions to remove.

        Returns:
            How many executions were deleted.
        """
        ...

@ -0,0 +1,66 @@
"""
DifyAPI Repository Factory for creating repository instances.
This factory is specifically designed for DifyAPI repositories that handle
service-layer operations with dependency injection patterns.
"""
import logging
from sqlalchemy.orm import sessionmaker
from configs import dify_config
from core.repositories import DifyCoreRepositoryFactory, RepositoryImportError
from repositories.api_workflow_node_execution_repository import DifyAPIWorkflowNodeExecutionRepository
logger = logging.getLogger(__name__)
class DifyAPIRepositoryFactory(DifyCoreRepositoryFactory):
    """
    Factory that builds service-layer (DifyAPI) repositories from configuration.

    The concrete class to instantiate is named by a dotted path in dify_config;
    the sessionmaker the repository should use is injected by the caller.
    """

    @classmethod
    def create_api_workflow_node_execution_repository(
        cls, session_maker: sessionmaker
    ) -> DifyAPIWorkflowNodeExecutionRepository:
        """
        Build the configured DifyAPIWorkflowNodeExecutionRepository implementation.

        The class path is read from dify_config.API_WORKFLOW_NODE_EXECUTION_REPOSITORY.
        The class is imported and checked with the helpers this factory inherits
        (_import_class, _validate_repository_interface, _validate_constructor_signature)
        before being constructed with ``session_maker`` as a keyword argument.

        Args:
            session_maker: SQLAlchemy sessionmaker handed to the repository constructor.

        Returns:
            The constructed repository instance.

        Raises:
            RepositoryImportError: Propagated unchanged when raised during
                import/validation/construction; any other exception raised
                there is logged and wrapped in a RepositoryImportError.
        """
        configured_path = dify_config.API_WORKFLOW_NODE_EXECUTION_REPOSITORY
        logger.debug(f"Creating DifyAPIWorkflowNodeExecutionRepository from: {configured_path}")

        try:
            implementation = cls._import_class(configured_path)
            cls._validate_repository_interface(implementation, DifyAPIWorkflowNodeExecutionRepository)
            # The service-layer repository must accept a `session_maker` argument.
            cls._validate_constructor_signature(implementation, ["session_maker"])
            repository = implementation(session_maker=session_maker)
        except RepositoryImportError:
            # Already the factory's own error type: let it through untouched.
            raise
        except Exception as exc:
            logger.exception("Failed to create DifyAPIWorkflowNodeExecutionRepository")
            raise RepositoryImportError(
                f"Failed to create DifyAPIWorkflowNodeExecutionRepository from '{configured_path}': {exc}"
            ) from exc

        return repository  # type: ignore[no-any-return]

@ -0,0 +1,286 @@
"""
SQLAlchemy implementation of DifyAPIWorkflowNodeExecutionRepository.
This module provides a concrete implementation of the service repository protocol
using SQLAlchemy 2.0 style queries for WorkflowNodeExecutionModel operations.
"""
from collections.abc import Sequence
from datetime import datetime
from typing import Optional

from sqlalchemy import ColumnElement, delete, desc, select
from sqlalchemy.orm import Session, sessionmaker

from models.workflow import WorkflowNodeExecutionModel
from repositories.api_workflow_node_execution_repository import DifyAPIWorkflowNodeExecutionRepository
class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRepository):
    """
    SQLAlchemy implementation of DifyAPIWorkflowNodeExecutionRepository.

    How it works:
        - The sessionmaker is injected through the constructor.
        - Each read opens its own short-lived session; batch deletes open one
          session per batch.
        - Queries use SQLAlchemy 2.0 style ``select()`` / ``delete()`` statements.
        - Results are WorkflowNodeExecutionModel rows, returned without domain conversion.
        - Deletes commit explicitly; batched deletes commit once per batch.
    """

    def __init__(self, session_maker: sessionmaker[Session]):
        """
        Store the sessionmaker used to open sessions.

        Args:
            session_maker: SQLAlchemy sessionmaker for creating database sessions
        """
        self._session_maker = session_maker

    def get_node_last_execution(
        self,
        tenant_id: str,
        app_id: str,
        workflow_id: str,
        node_id: str,
    ) -> Optional[WorkflowNodeExecutionModel]:
        """
        Get the most recent execution for a specific node.

        Matches on tenant, app, workflow and node, and picks the row with the
        greatest ``created_at``.

        Args:
            tenant_id: The tenant identifier
            app_id: The application identifier
            workflow_id: The workflow identifier
            node_id: The node identifier

        Returns:
            The most recent WorkflowNodeExecutionModel for the node, or None if not found
        """
        stmt = (
            select(WorkflowNodeExecutionModel)
            .where(
                WorkflowNodeExecutionModel.tenant_id == tenant_id,
                WorkflowNodeExecutionModel.app_id == app_id,
                WorkflowNodeExecutionModel.workflow_id == workflow_id,
                WorkflowNodeExecutionModel.node_id == node_id,
            )
            .order_by(desc(WorkflowNodeExecutionModel.created_at))
            .limit(1)
        )

        with self._session_maker() as session:
            return session.scalar(stmt)

    def get_executions_by_workflow_run(
        self,
        tenant_id: str,
        app_id: str,
        workflow_run_id: str,
    ) -> Sequence[WorkflowNodeExecutionModel]:
        """
        Get all node executions for a specific workflow run.

        Args:
            tenant_id: The tenant identifier
            app_id: The application identifier
            workflow_run_id: The workflow run identifier

        Returns:
            A sequence of WorkflowNodeExecutionModel instances ordered by index (desc)
        """
        stmt = (
            select(WorkflowNodeExecutionModel)
            .where(
                WorkflowNodeExecutionModel.tenant_id == tenant_id,
                WorkflowNodeExecutionModel.app_id == app_id,
                WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id,
            )
            .order_by(desc(WorkflowNodeExecutionModel.index))
        )

        with self._session_maker() as session:
            return session.execute(stmt).scalars().all()

    def get_execution_by_id(
        self,
        execution_id: str,
        tenant_id: Optional[str] = None,
    ) -> Optional[WorkflowNodeExecutionModel]:
        """
        Get a workflow node execution by its ID.

        Args:
            execution_id: The execution identifier
            tenant_id: Optional tenant identifier for additional filtering

        Returns:
            The WorkflowNodeExecutionModel if found, or None if not found
        """
        stmt = select(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.id == execution_id)

        # Narrow to the tenant only when the caller supplied one.
        if tenant_id is not None:
            stmt = stmt.where(WorkflowNodeExecutionModel.tenant_id == tenant_id)

        with self._session_maker() as session:
            return session.scalar(stmt)

    def _delete_in_batches(self, *criteria: ColumnElement[bool], batch_size: int) -> int:
        """
        Delete every row matching ``criteria``, at most ``batch_size`` rows per round.

        Each round runs in its own session: select up to ``batch_size`` matching
        ids, delete those ids, commit. Because every round is committed
        separately, rounds completed before a failure stay deleted.

        Args:
            *criteria: WHERE conditions selecting the rows to delete
            batch_size: Maximum number of rows to select and delete per round

        Returns:
            The sum of the row counts reported by the DELETE statements
        """
        # The id query is identical for every round, so build it once.
        id_stmt = select(WorkflowNodeExecutionModel.id).where(*criteria).limit(batch_size)
        total_deleted = 0

        while True:
            with self._session_maker() as session:
                execution_ids = session.execute(id_stmt).scalars().all()
                if not execution_ids:
                    break

                delete_stmt = delete(WorkflowNodeExecutionModel).where(
                    WorkflowNodeExecutionModel.id.in_(execution_ids)
                )
                result = session.execute(delete_stmt)
                session.commit()
                total_deleted += result.rowcount

                # A short batch means nothing is left to select; skip the extra query.
                if len(execution_ids) < batch_size:
                    break

        return total_deleted

    def delete_expired_executions(
        self,
        tenant_id: str,
        before_date: datetime,
        batch_size: int = 1000,
    ) -> int:
        """
        Delete workflow node executions that are older than the specified date.

        Args:
            tenant_id: The tenant identifier
            before_date: Delete executions created before this date
            batch_size: Maximum number of executions to delete in one batch

        Returns:
            The number of executions deleted
        """
        return self._delete_in_batches(
            WorkflowNodeExecutionModel.tenant_id == tenant_id,
            WorkflowNodeExecutionModel.created_at < before_date,
            batch_size=batch_size,
        )

    def delete_executions_by_app(
        self,
        tenant_id: str,
        app_id: str,
        batch_size: int = 1000,
    ) -> int:
        """
        Delete all workflow node executions for a specific app.

        Args:
            tenant_id: The tenant identifier
            app_id: The application identifier
            batch_size: Maximum number of executions to delete in one batch

        Returns:
            The total number of executions deleted
        """
        return self._delete_in_batches(
            WorkflowNodeExecutionModel.tenant_id == tenant_id,
            WorkflowNodeExecutionModel.app_id == app_id,
            batch_size=batch_size,
        )

    def get_expired_executions_batch(
        self,
        tenant_id: str,
        before_date: datetime,
        batch_size: int = 1000,
    ) -> Sequence[WorkflowNodeExecutionModel]:
        """
        Get a batch of expired workflow node executions for backup purposes.

        The query has no ORDER BY, so which expired rows make up the batch is
        left to the database.

        Args:
            tenant_id: The tenant identifier
            before_date: Get executions created before this date
            batch_size: Maximum number of executions to retrieve

        Returns:
            A sequence of WorkflowNodeExecutionModel instances
        """
        stmt = (
            select(WorkflowNodeExecutionModel)
            .where(
                WorkflowNodeExecutionModel.tenant_id == tenant_id,
                WorkflowNodeExecutionModel.created_at < before_date,
            )
            .limit(batch_size)
        )

        with self._session_maker() as session:
            return session.execute(stmt).scalars().all()

    def delete_executions_by_ids(
        self,
        execution_ids: Sequence[str],
    ) -> int:
        """
        Delete workflow node executions by their IDs.

        An empty ``execution_ids`` returns 0 without opening a session.

        Args:
            execution_ids: List of execution IDs to delete

        Returns:
            The number of executions deleted
        """
        if not execution_ids:
            return 0

        with self._session_maker() as session:
            stmt = delete(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.id.in_(execution_ids))
            result = session.execute(stmt)
            session.commit()
            return result.rowcount

@ -6,7 +6,7 @@ from concurrent.futures import ThreadPoolExecutor
import click
from flask import Flask, current_app
from sqlalchemy.orm import Session
from sqlalchemy.orm import Session, sessionmaker
from configs import dify_config
from core.model_runtime.utils.encoders import jsonable_encoder
@ -14,7 +14,8 @@ from extensions.ext_database import db
from extensions.ext_storage import storage
from models.account import Tenant
from models.model import App, Conversation, Message
from models.workflow import WorkflowNodeExecutionModel, WorkflowRun
from models.workflow import WorkflowRun
from repositories.factory import DifyAPIRepositoryFactory
from services.billing_service import BillingService
logger = logging.getLogger(__name__)
@ -105,48 +106,52 @@ class ClearFreePlanTenantExpiredLogs:
)
)
while True:
with Session(db.engine).no_autoflush as session:
workflow_node_executions = (
session.query(WorkflowNodeExecutionModel)
.filter(
WorkflowNodeExecutionModel.tenant_id == tenant_id,
WorkflowNodeExecutionModel.created_at
< datetime.datetime.now() - datetime.timedelta(days=days),
)
.limit(batch)
.all()
)
# Process expired workflow node executions with backup
session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
node_execution_repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository(session_maker)
before_date = datetime.datetime.now() - datetime.timedelta(days=days)
total_deleted = 0
if len(workflow_node_executions) == 0:
break
while True:
# Get a batch of expired executions for backup
workflow_node_executions = node_execution_repo.get_expired_executions_batch(
tenant_id=tenant_id,
before_date=before_date,
batch_size=batch,
)
# save workflow node executions
storage.save(
f"free_plan_tenant_expired_logs/"
f"{tenant_id}/workflow_node_executions/{datetime.datetime.now().strftime('%Y-%m-%d')}"
f"-{time.time()}.json",
json.dumps(
jsonable_encoder(workflow_node_executions),
).encode("utf-8"),
)
if len(workflow_node_executions) == 0:
break
# Save workflow node executions to storage
storage.save(
f"free_plan_tenant_expired_logs/"
f"{tenant_id}/workflow_node_executions/{datetime.datetime.now().strftime('%Y-%m-%d')}"
f"-{time.time()}.json",
json.dumps(
jsonable_encoder(workflow_node_executions),
).encode("utf-8"),
)
workflow_node_execution_ids = [
workflow_node_execution.id for workflow_node_execution in workflow_node_executions
]
# Extract IDs for deletion
workflow_node_execution_ids = [
workflow_node_execution.id for workflow_node_execution in workflow_node_executions
]
# delete workflow node executions
session.query(WorkflowNodeExecutionModel).filter(
WorkflowNodeExecutionModel.id.in_(workflow_node_execution_ids),
).delete(synchronize_session=False)
session.commit()
# Delete the backed up executions
deleted_count = node_execution_repo.delete_executions_by_ids(workflow_node_execution_ids)
total_deleted += deleted_count
click.echo(
click.style(
f"[{datetime.datetime.now()}] Processed {len(workflow_node_execution_ids)}"
f" workflow node executions for tenant {tenant_id}"
)
click.echo(
click.style(
f"[{datetime.datetime.now()}] Processed {len(workflow_node_execution_ids)}"
f" workflow node executions for tenant {tenant_id}"
)
)
# If we got fewer than the batch size, we're done
if len(workflow_node_executions) < batch:
break
while True:
with Session(db.engine).no_autoflush as session:

@ -5,9 +5,9 @@ from collections.abc import Mapping, Sequence
from enum import StrEnum
from typing import Any, ClassVar
from sqlalchemy import Engine, orm, select
from sqlalchemy import Engine, orm
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.orm import Session
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.sql.expression import and_, or_
from core.app.entities.app_invoke_entities import InvokeFrom
@ -21,11 +21,13 @@ from core.workflow.enums import SystemVariableKey
from core.workflow.nodes import NodeType
from core.workflow.nodes.variable_assigner.common.helpers import get_updated_variables
from core.workflow.variable_loader import VariableLoader
from extensions.ext_database import db
from factories.file_factory import StorageKeyLoader
from factories.variable_factory import build_segment, segment_to_variable
from models import App, Conversation
from models.enums import DraftVariableType
from models.workflow import Workflow, WorkflowDraftVariable, WorkflowNodeExecutionModel, is_system_variable_editable
from models.workflow import Workflow, WorkflowDraftVariable, is_system_variable_editable
from repositories.factory import DifyAPIRepositoryFactory
_logger = logging.getLogger(__name__)
@ -118,6 +120,10 @@ class WorkflowDraftVariableService:
def __init__(self, session: Session) -> None:
    """Store the caller's session and build the node-execution repository.

    Args:
        session: Session kept on ``self._session`` for this service's own queries.

    The repository does not reuse ``session``; it receives a separate
    ``sessionmaker`` bound to ``db.engine``.
    """
    self._session = session
    # expire_on_commit=False: SQLAlchemy leaves loaded instances un-expired after
    # a commit. NOTE(review): presumably so rows returned by the repository stay
    # readable after its own session is finished - confirm.
    api_session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
    create_repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository
    self._node_execution_service_repo = create_repo(api_session_factory)
def get_variable(self, variable_id: str) -> WorkflowDraftVariable | None:
    """Return the first draft variable whose ``id`` equals ``variable_id``, if any.

    Args:
        variable_id: Value compared against ``WorkflowDraftVariable.id``.

    Returns:
        The first matching ``WorkflowDraftVariable``, or ``None`` when the query
        yields no row.
    """
    matching = self._session.query(WorkflowDraftVariable).filter(WorkflowDraftVariable.id == variable_id)
    return matching.first()
@ -248,8 +254,7 @@ class WorkflowDraftVariableService:
_logger.warning("draft variable has no node_execution_id, id=%s, name=%s", variable.id, variable.name)
return None
query = select(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.id == variable.node_execution_id)
node_exec = self._session.scalars(query).first()
node_exec = self._node_execution_service_repo.get_execution_by_id(variable.node_execution_id)
if node_exec is None:
_logger.warning(
"Node exectution not found for draft variable, id=%s, name=%s, node_execution_id=%s",

@ -2,7 +2,7 @@ import threading
from collections.abc import Sequence
from typing import Optional
from sqlalchemy import desc, select
from sqlalchemy.orm import sessionmaker
import contexts
from extensions.ext_database import db
@ -15,9 +15,17 @@ from models import (
WorkflowRun,
WorkflowRunTriggeredFrom,
)
from repositories.factory import DifyAPIRepositoryFactory
class WorkflowRunService:
def __init__(self):
    """Create the node-execution repository that WorkflowRunService delegates to."""
    # expire_on_commit=False: SQLAlchemy leaves loaded instances un-expired after
    # a commit. NOTE(review): presumably so returned rows remain readable once the
    # repository's session is finished - confirm.
    api_session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
    create_repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository
    self._node_execution_service_repo = create_repo(api_session_factory)
def get_paginate_advanced_chat_workflow_runs(self, app_model: App, args: dict) -> InfiniteScrollPagination:
"""
Get advanced chat app workflow run list
@ -138,17 +146,11 @@ class WorkflowRunService:
# Get tenant_id from user
tenant_id = user.tenant_id if isinstance(user, EndUser) else user.current_tenant_id
if tenant_id is None:
raise ValueError("User tenant_id cannot be None")
# Use SQLAlchemy 2.0 style query directly
stmt = (
select(WorkflowNodeExecutionModel)
.where(
WorkflowNodeExecutionModel.tenant_id == tenant_id,
WorkflowNodeExecutionModel.app_id == app_model.id,
WorkflowNodeExecutionModel.workflow_run_id == run_id,
)
.order_by(desc(WorkflowNodeExecutionModel.index))
return self._node_execution_service_repo.get_executions_by_workflow_run(
tenant_id=tenant_id,
app_id=app_model.id,
workflow_run_id=run_id,
)
workflow_node_executions = db.session.execute(stmt).scalars().all()
return workflow_node_executions

@ -7,13 +7,13 @@ from typing import Any, Optional
from uuid import uuid4
from sqlalchemy import select
from sqlalchemy.orm import Session
from sqlalchemy.orm import Session, sessionmaker
from core.app.app_config.entities import VariableEntityType
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
from core.file import File
from core.repositories import RepositoryFactory
from core.repositories import DifyCoreRepositoryFactory
from core.variables import Variable
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_pool import VariablePool
@ -41,6 +41,7 @@ from models.workflow import (
WorkflowNodeExecutionTriggeredFrom,
WorkflowType,
)
from repositories.factory import DifyAPIRepositoryFactory
from services.errors.app import IsDraftWorkflowError, WorkflowHashNotEqualError
from services.workflow.workflow_converter import WorkflowConverter
@ -57,21 +58,31 @@ class WorkflowService:
Workflow Service
"""
def get_node_last_run(self, app_model: App, workflow: Workflow, node_id: str) -> WorkflowNodeExecutionModel | None:
# TODO(QuantumGhost): This query is not fully covered by index.
criteria = (
WorkflowNodeExecutionModel.tenant_id == app_model.tenant_id,
WorkflowNodeExecutionModel.app_id == app_model.id,
WorkflowNodeExecutionModel.workflow_id == workflow.id,
WorkflowNodeExecutionModel.node_id == node_id,
def __init__(self):
    """Create the node-execution repository that WorkflowService delegates to."""
    # expire_on_commit=False: SQLAlchemy leaves loaded instances un-expired after
    # a commit. NOTE(review): presumably so returned rows remain readable once the
    # repository's session is finished - confirm.
    api_session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
    create_repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository
    self._node_execution_service_repo = create_repo(api_session_factory)
node_exec = (
db.session.query(WorkflowNodeExecutionModel)
.filter(*criteria)
.order_by(WorkflowNodeExecutionModel.created_at.desc())
.first()
def get_node_last_run(self, app_model: App, workflow: Workflow, node_id: str) -> WorkflowNodeExecutionModel | None:
    """
    Look up the latest recorded execution of a single workflow node.

    The lookup is delegated to the node-execution repository held on
    ``self._node_execution_service_repo``; this method only assembles the
    identifying keys.

    Args:
        app_model: Application whose ``tenant_id`` and ``id`` scope the lookup
        workflow: Workflow whose ``id`` scopes the lookup
        node_id: Identifier of the node within the workflow

    Returns:
        Whatever the repository's ``get_node_last_execution`` returns: the most
        recent WorkflowNodeExecutionModel for the node, or None if not found
    """
    fetch_latest = self._node_execution_service_repo.get_node_last_execution
    scope = {
        "tenant_id": app_model.tenant_id,
        "app_id": app_model.id,
        "workflow_id": workflow.id,
        "node_id": node_id,
    }
    return fetch_latest(**scope)
return node_exec
def is_workflow_exist(self, app_model: App) -> bool:
return (
@ -396,7 +407,7 @@ class WorkflowService:
node_execution.workflow_id = draft_workflow.id
# Create repository and save the node execution
repository = RepositoryFactory.create_workflow_node_execution_repository(
repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
session_factory=db.engine,
user=account,
app_id=app_model.id,
@ -404,8 +415,9 @@ class WorkflowService:
)
repository.save(node_execution)
stmt = select(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.id == node_execution.id)
workflow_node_execution = db.session.execute(stmt).scalar_one()
workflow_node_execution = self._node_execution_service_repo.get_execution_by_id(node_execution.id)
if workflow_node_execution is None:
raise ValueError(f"WorkflowNodeExecution with id {node_execution.id} not found after saving")
with Session(bind=db.engine) as session, session.begin():
draft_var_saver = DraftVariableSaver(
@ -418,6 +430,7 @@ class WorkflowService:
)
draft_var_saver.save(process_data=node_execution.process_data, outputs=node_execution.outputs)
session.commit()
return workflow_node_execution
def run_free_workflow_node(
@ -429,7 +442,7 @@ class WorkflowService:
# run draft workflow node
start_at = time.perf_counter()
workflow_node_execution = self._handle_node_run_result(
node_execution = self._handle_node_run_result(
invoke_node_fn=lambda: WorkflowEntry.run_free_node(
node_id=node_id,
node_data=node_data,
@ -441,7 +454,7 @@ class WorkflowService:
node_id=node_id,
)
return workflow_node_execution
return node_execution
def _handle_node_run_result(
self,

@ -6,6 +6,7 @@ import click
from celery import shared_task # type: ignore
from sqlalchemy import delete
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import sessionmaker
from extensions.ext_database import db
from models import (
@ -31,7 +32,8 @@ from models import (
)
from models.tools import WorkflowToolProvider
from models.web import PinnedConversation, SavedMessage
from models.workflow import ConversationVariable, Workflow, WorkflowAppLog, WorkflowNodeExecutionModel, WorkflowRun
from models.workflow import ConversationVariable, Workflow, WorkflowAppLog, WorkflowRun
from repositories.factory import DifyAPIRepositoryFactory
@shared_task(queue="app_deletion", bind=True, max_retries=3)
@ -201,18 +203,18 @@ def _delete_app_workflow_runs(tenant_id: str, app_id: str):
def _delete_app_workflow_node_executions(tenant_id: str, app_id: str):
def del_workflow_node_execution(workflow_node_execution_id: str):
db.session.query(WorkflowNodeExecutionModel).filter(
WorkflowNodeExecutionModel.id == workflow_node_execution_id
).delete(synchronize_session=False)
_delete_records(
"""select id from workflow_node_executions where tenant_id=:tenant_id and app_id=:app_id limit 1000""",
{"tenant_id": tenant_id, "app_id": app_id},
del_workflow_node_execution,
"workflow node execution",
"""Delete all workflow node executions for an app using the service repository."""
session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
node_execution_repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository(session_maker)
deleted_count = node_execution_repo.delete_executions_by_app(
tenant_id=tenant_id,
app_id=app_id,
batch_size=1000,
)
logging.info(f"Deleted {deleted_count} workflow node executions for app {app_id}")
def _delete_app_workflow_app_logs(tenant_id: str, app_id: str):
def del_workflow_app_log(workflow_app_log_id: str):

@ -12,7 +12,7 @@ from pytest_mock import MockerFixture
from sqlalchemy.engine import Engine
from sqlalchemy.orm import sessionmaker
from core.repositories.factory import RepositoryFactory, RepositoryImportError
from core.repositories.factory import DifyCoreRepositoryFactory, RepositoryImportError
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from models import Account, EndUser
@ -27,25 +27,25 @@ class TestRepositoryFactory:
"""Test successful class import."""
# Test importing a real class
class_path = "unittest.mock.MagicMock"
result = RepositoryFactory._import_class(class_path)
result = DifyCoreRepositoryFactory._import_class(class_path)
assert result is MagicMock
def test_import_class_invalid_path(self):
"""Test import with invalid module path."""
with pytest.raises(RepositoryImportError) as exc_info:
RepositoryFactory._import_class("invalid.module.path")
DifyCoreRepositoryFactory._import_class("invalid.module.path")
assert "Cannot import repository class" in str(exc_info.value)
def test_import_class_invalid_class_name(self):
"""Test import with invalid class name."""
with pytest.raises(RepositoryImportError) as exc_info:
RepositoryFactory._import_class("unittest.mock.NonExistentClass")
DifyCoreRepositoryFactory._import_class("unittest.mock.NonExistentClass")
assert "Cannot import repository class" in str(exc_info.value)
def test_import_class_malformed_path(self):
"""Test import with malformed path (no dots)."""
with pytest.raises(RepositoryImportError) as exc_info:
RepositoryFactory._import_class("invalidpath")
DifyCoreRepositoryFactory._import_class("invalidpath")
assert "Cannot import repository class" in str(exc_info.value)
def test_validate_repository_interface_success(self):
@ -68,7 +68,7 @@ class TestRepositoryFactory:
pass
# Should not raise an exception
RepositoryFactory._validate_repository_interface(MockRepository, MockInterface)
DifyCoreRepositoryFactory._validate_repository_interface(MockRepository, MockInterface)
def test_validate_repository_interface_missing_methods(self):
"""Test interface validation with missing methods."""
@ -89,7 +89,7 @@ class TestRepositoryFactory:
pass
with pytest.raises(RepositoryImportError) as exc_info:
RepositoryFactory._validate_repository_interface(IncompleteRepository, MockInterface)
DifyCoreRepositoryFactory._validate_repository_interface(IncompleteRepository, MockInterface)
assert "does not implement required methods" in str(exc_info.value)
assert "get_by_id" in str(exc_info.value)
@ -101,7 +101,7 @@ class TestRepositoryFactory:
pass
# Should not raise an exception
RepositoryFactory._validate_constructor_signature(
DifyCoreRepositoryFactory._validate_constructor_signature(
MockRepository, ["session_factory", "user", "app_id", "triggered_from"]
)
@ -114,7 +114,7 @@ class TestRepositoryFactory:
pass
with pytest.raises(RepositoryImportError) as exc_info:
RepositoryFactory._validate_constructor_signature(
DifyCoreRepositoryFactory._validate_constructor_signature(
IncompleteRepository, ["session_factory", "user", "app_id", "triggered_from"]
)
assert "does not accept required parameters" in str(exc_info.value)
@ -131,7 +131,7 @@ class TestRepositoryFactory:
pass
with pytest.raises(RepositoryImportError) as exc_info:
RepositoryFactory._validate_constructor_signature(MockRepository, ["session_factory"])
DifyCoreRepositoryFactory._validate_constructor_signature(MockRepository, ["session_factory"])
assert "Failed to validate constructor signature" in str(exc_info.value)
@patch("core.repositories.factory.dify_config")
@ -153,11 +153,11 @@ class TestRepositoryFactory:
# Mock the validation methods
with (
patch.object(RepositoryFactory, "_import_class", return_value=mock_repository_class),
patch.object(RepositoryFactory, "_validate_repository_interface"),
patch.object(RepositoryFactory, "_validate_constructor_signature"),
patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class),
patch.object(DifyCoreRepositoryFactory, "_validate_repository_interface"),
patch.object(DifyCoreRepositoryFactory, "_validate_constructor_signature"),
):
result = RepositoryFactory.create_workflow_execution_repository(
result = DifyCoreRepositoryFactory.create_workflow_execution_repository(
session_factory=mock_session_factory,
user=mock_user,
app_id=app_id,
@ -183,7 +183,7 @@ class TestRepositoryFactory:
mock_user = MagicMock(spec=Account)
with pytest.raises(RepositoryImportError) as exc_info:
RepositoryFactory.create_workflow_execution_repository(
DifyCoreRepositoryFactory.create_workflow_execution_repository(
session_factory=mock_session_factory,
user=mock_user,
app_id="test-app-id",
@ -203,15 +203,15 @@ class TestRepositoryFactory:
# Mock import to succeed but validation to fail
mock_repository_class = MagicMock()
with (
patch.object(RepositoryFactory, "_import_class", return_value=mock_repository_class),
patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class),
patch.object(
RepositoryFactory,
DifyCoreRepositoryFactory,
"_validate_repository_interface",
side_effect=RepositoryImportError("Interface validation failed"),
),
):
with pytest.raises(RepositoryImportError) as exc_info:
RepositoryFactory.create_workflow_execution_repository(
DifyCoreRepositoryFactory.create_workflow_execution_repository(
session_factory=mock_session_factory,
user=mock_user,
app_id="test-app-id",
@ -231,12 +231,12 @@ class TestRepositoryFactory:
# Mock import and validation to succeed but instantiation to fail
mock_repository_class = MagicMock(side_effect=Exception("Instantiation failed"))
with (
patch.object(RepositoryFactory, "_import_class", return_value=mock_repository_class),
patch.object(RepositoryFactory, "_validate_repository_interface"),
patch.object(RepositoryFactory, "_validate_constructor_signature"),
patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class),
patch.object(DifyCoreRepositoryFactory, "_validate_repository_interface"),
patch.object(DifyCoreRepositoryFactory, "_validate_constructor_signature"),
):
with pytest.raises(RepositoryImportError) as exc_info:
RepositoryFactory.create_workflow_execution_repository(
DifyCoreRepositoryFactory.create_workflow_execution_repository(
session_factory=mock_session_factory,
user=mock_user,
app_id="test-app-id",
@ -263,11 +263,11 @@ class TestRepositoryFactory:
# Mock the validation methods
with (
patch.object(RepositoryFactory, "_import_class", return_value=mock_repository_class),
patch.object(RepositoryFactory, "_validate_repository_interface"),
patch.object(RepositoryFactory, "_validate_constructor_signature"),
patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class),
patch.object(DifyCoreRepositoryFactory, "_validate_repository_interface"),
patch.object(DifyCoreRepositoryFactory, "_validate_constructor_signature"),
):
result = RepositoryFactory.create_workflow_node_execution_repository(
result = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
session_factory=mock_session_factory,
user=mock_user,
app_id=app_id,
@ -293,7 +293,7 @@ class TestRepositoryFactory:
mock_user = MagicMock(spec=EndUser)
with pytest.raises(RepositoryImportError) as exc_info:
RepositoryFactory.create_workflow_node_execution_repository(
DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
session_factory=mock_session_factory,
user=mock_user,
app_id="test-app-id",
@ -325,11 +325,11 @@ class TestRepositoryFactory:
# Mock the validation methods
with (
patch.object(RepositoryFactory, "_import_class", return_value=mock_repository_class),
patch.object(RepositoryFactory, "_validate_repository_interface"),
patch.object(RepositoryFactory, "_validate_constructor_signature"),
patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class),
patch.object(DifyCoreRepositoryFactory, "_validate_repository_interface"),
patch.object(DifyCoreRepositoryFactory, "_validate_constructor_signature"),
):
result = RepositoryFactory.create_workflow_execution_repository(
result = DifyCoreRepositoryFactory.create_workflow_execution_repository(
session_factory=mock_engine, # Using Engine instead of sessionmaker
user=mock_user,
app_id="test-app-id",
@ -357,15 +357,15 @@ class TestRepositoryFactory:
# Mock import to succeed but validation to fail
mock_repository_class = MagicMock()
with (
patch.object(RepositoryFactory, "_import_class", return_value=mock_repository_class),
patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class),
patch.object(
RepositoryFactory,
DifyCoreRepositoryFactory,
"_validate_repository_interface",
side_effect=RepositoryImportError("Interface validation failed"),
),
):
with pytest.raises(RepositoryImportError) as exc_info:
RepositoryFactory.create_workflow_node_execution_repository(
DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
session_factory=mock_session_factory,
user=mock_user,
app_id="test-app-id",
@ -385,12 +385,12 @@ class TestRepositoryFactory:
# Mock import and validation to succeed but instantiation to fail
mock_repository_class = MagicMock(side_effect=Exception("Instantiation failed"))
with (
patch.object(RepositoryFactory, "_import_class", return_value=mock_repository_class),
patch.object(RepositoryFactory, "_validate_repository_interface"),
patch.object(RepositoryFactory, "_validate_constructor_signature"),
patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class),
patch.object(DifyCoreRepositoryFactory, "_validate_repository_interface"),
patch.object(DifyCoreRepositoryFactory, "_validate_constructor_signature"),
):
with pytest.raises(RepositoryImportError) as exc_info:
RepositoryFactory.create_workflow_node_execution_repository(
DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
session_factory=mock_session_factory,
user=mock_user,
app_id="test-app-id",
@ -424,7 +424,7 @@ class TestRepositoryFactory:
pass
# Should not raise an exception (private methods are ignored)
RepositoryFactory._validate_repository_interface(MockRepository, MockInterface)
DifyCoreRepositoryFactory._validate_repository_interface(MockRepository, MockInterface)
def test_validate_constructor_signature_with_extra_params(self):
"""Test constructor validation with extra parameters (should pass)."""
@ -434,7 +434,7 @@ class TestRepositoryFactory:
pass
# Should not raise an exception (extra parameters are allowed)
RepositoryFactory._validate_constructor_signature(
DifyCoreRepositoryFactory._validate_constructor_signature(
MockRepository, ["session_factory", "user", "app_id", "triggered_from"]
)
@ -447,7 +447,7 @@ class TestRepositoryFactory:
# Current implementation doesn't handle **kwargs, so this should raise an exception
with pytest.raises(RepositoryImportError) as exc_info:
RepositoryFactory._validate_constructor_signature(
DifyCoreRepositoryFactory._validate_constructor_signature(
MockRepository, ["session_factory", "user", "app_id", "triggered_from"]
)
assert "does not accept required parameters" in str(exc_info.value)

@ -0,0 +1,278 @@
from datetime import datetime
from unittest.mock import MagicMock
from uuid import uuid4
import pytest
from sqlalchemy.orm import Session
from models.workflow import WorkflowNodeExecutionModel
from repositories.sqlalchemy_api_workflow_node_execution_repository import (
DifyAPISQLAlchemyWorkflowNodeExecutionRepository,
)
class TestSQLAlchemyWorkflowNodeExecutionServiceRepository:
    """Unit tests for DifyAPISQLAlchemyWorkflowNodeExecutionRepository.

    Every test runs against a ``MagicMock`` session maker, so no database is
    touched; the tests only check what the repository returns and which
    session methods it calls.
    """

    @staticmethod
    def _open_session(repo):
        """Return a Session mock wired as the value yielded by ``with repo._session_maker() as ...``."""
        fake_session = MagicMock(spec=Session)
        repo._session_maker.return_value.__enter__.return_value = fake_session
        return fake_session

    @pytest.fixture
    def repository(self):
        """Repository under test, constructed with a mocked session maker."""
        return DifyAPISQLAlchemyWorkflowNodeExecutionRepository(session_maker=MagicMock())

    @pytest.fixture
    def mock_execution(self):
        """A WorkflowNodeExecutionModel-spec'd mock populated with sample field values."""
        record = MagicMock(spec=WorkflowNodeExecutionModel)
        sample_fields = {
            "id": str(uuid4()),
            "tenant_id": "tenant-123",
            "app_id": "app-456",
            "workflow_id": "workflow-789",
            "workflow_run_id": "run-101",
            "node_id": "node-202",
            "index": 1,
            "created_at": "2023-01-01T00:00:00Z",
        }
        for field_name, field_value in sample_fields.items():
            setattr(record, field_name, field_value)
        return record

    def test_get_node_last_execution_found(self, repository, mock_execution):
        """The row returned by ``session.scalar`` is handed back to the caller."""
        session = self._open_session(repository)
        session.scalar.return_value = mock_execution

        found = repository.get_node_last_execution(
            tenant_id="tenant-123",
            app_id="app-456",
            workflow_id="workflow-789",
            node_id="node-202",
        )

        assert found == mock_execution
        session.scalar.assert_called_once()
        # The single positional argument should be a SQLAlchemy statement.
        statement = session.scalar.call_args[0][0]
        assert hasattr(statement, "compile")

    def test_get_node_last_execution_not_found(self, repository):
        """``None`` from ``session.scalar`` is propagated as ``None``."""
        session = self._open_session(repository)
        session.scalar.return_value = None

        found = repository.get_node_last_execution(
            tenant_id="tenant-123",
            app_id="app-456",
            workflow_id="workflow-789",
            node_id="node-202",
        )

        assert found is None
        session.scalar.assert_called_once()

    def test_get_executions_by_workflow_run(self, repository, mock_execution):
        """All rows produced by ``execute(...).scalars().all()`` are returned."""
        session = self._open_session(repository)
        expected_rows = [mock_execution]
        session.execute.return_value.scalars.return_value.all.return_value = expected_rows

        rows = repository.get_executions_by_workflow_run(
            tenant_id="tenant-123",
            app_id="app-456",
            workflow_run_id="run-101",
        )

        assert rows == expected_rows
        session.execute.assert_called_once()
        # The single positional argument should be a SQLAlchemy statement.
        statement = session.execute.call_args[0][0]
        assert hasattr(statement, "compile")

    def test_get_executions_by_workflow_run_empty(self, repository):
        """An empty result set comes back as an empty list."""
        session = self._open_session(repository)
        session.execute.return_value.scalars.return_value.all.return_value = []

        rows = repository.get_executions_by_workflow_run(
            tenant_id="tenant-123",
            app_id="app-456",
            workflow_run_id="run-101",
        )

        assert rows == []
        session.execute.assert_called_once()

    def test_get_execution_by_id_found(self, repository, mock_execution):
        """Lookup by id returns the row produced by ``session.scalar``."""
        session = self._open_session(repository)
        session.scalar.return_value = mock_execution

        found = repository.get_execution_by_id(mock_execution.id)

        assert found == mock_execution
        session.scalar.assert_called_once()

    def test_get_execution_by_id_not_found(self, repository):
        """Lookup by an unknown id returns ``None``."""
        session = self._open_session(repository)
        session.scalar.return_value = None

        found = repository.get_execution_by_id("non-existent-id")

        assert found is None
        session.scalar.assert_called_once()

    def test_repository_implements_protocol(self, repository):
        """The repository exposes the expected query and deletion methods."""
        query_methods = (
            "get_node_last_execution",
            "get_executions_by_workflow_run",
            "get_execution_by_id",
        )
        deletion_and_batch_methods = (
            "delete_expired_executions",
            "delete_executions_by_app",
            "get_expired_executions_batch",
            "delete_executions_by_ids",
        )

        for method_name in query_methods:
            assert hasattr(repository, method_name)

        for method_name in query_methods + deletion_and_batch_methods:
            assert callable(getattr(repository, method_name))

    def test_delete_expired_executions(self, repository):
        """One select, one delete and one commit happen when a single short batch is found."""
        session = self._open_session(repository)
        # Fewer IDs than batch_size, so the repository is expected to stop after
        # one batch (the single-call assertions below rely on that).
        session.execute.return_value.scalars.return_value.all.return_value = ["id1", "id2"]
        session.query.return_value.filter.return_value.delete.return_value = 2
        cutoff = datetime(2023, 1, 1)

        deleted = repository.delete_expired_executions(
            tenant_id="tenant-123",
            before_date=cutoff,
            batch_size=1000,
        )

        assert deleted == 2
        session.execute.assert_called_once()  # One select call
        session.query.assert_called_once()
        session.commit.assert_called_once()

    def test_delete_executions_by_app(self, repository):
        """One select, one delete and one commit happen when a single short batch is found."""
        session = self._open_session(repository)
        # Fewer IDs than batch_size, so the repository is expected to stop after
        # one batch (the single-call assertions below rely on that).
        session.execute.return_value.scalars.return_value.all.return_value = ["id1", "id2"]
        session.query.return_value.filter.return_value.delete.return_value = 2

        deleted = repository.delete_executions_by_app(
            tenant_id="tenant-123",
            app_id="app-456",
            batch_size=1000,
        )

        assert deleted == 2
        session.execute.assert_called_once()  # One select call
        session.query.assert_called_once()
        session.commit.assert_called_once()

    def test_get_expired_executions_batch(self, repository):
        """The batch getter returns the selected rows in order."""
        session = self._open_session(repository)
        first_row, second_row = MagicMock(), MagicMock()
        first_row.id = "exec-1"
        second_row.id = "exec-2"
        session.execute.return_value.scalars.return_value.all.return_value = [first_row, second_row]
        cutoff = datetime(2023, 1, 1)

        batch = repository.get_expired_executions_batch(
            tenant_id="tenant-123",
            before_date=cutoff,
            batch_size=1000,
        )

        assert len(batch) == 2
        assert batch[0].id == "exec-1"
        assert batch[1].id == "exec-2"
        session.execute.assert_called_once()

    def test_delete_executions_by_ids(self, repository):
        """Deleting by ids returns the delete count and commits once."""
        session = self._open_session(repository)
        session.query.return_value.filter.return_value.delete.return_value = 3
        ids_to_delete = ["id1", "id2", "id3"]

        deleted = repository.delete_executions_by_ids(ids_to_delete)

        assert deleted == 3
        session.query.assert_called_once()
        session.commit.assert_called_once()

    def test_delete_executions_by_ids_empty_list(self, repository):
        """An empty id list returns 0 without querying or committing."""
        session = self._open_session(repository)

        deleted = repository.delete_executions_by_ids([])

        assert deleted == 0
        session.query.assert_not_called()
        session.commit.assert_not_called()
Loading…
Cancel
Save