From b2b40492793842ab116e387d5fe8ab055bdad575 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Fri, 27 Jun 2025 11:17:34 +0800 Subject: [PATCH] feat: Create a DifyAPIRepositoryFactory to handle workflow node execution operations out of core. Signed-off-by: -LAN- --- api/configs/feature/__init__.py | 10 +- .../app/apps/advanced_chat/app_generator.py | 14 +- api/core/app/apps/workflow/app_generator.py | 14 +- api/core/ops/langfuse_trace/langfuse_trace.py | 4 +- .../ops/langsmith_trace/langsmith_trace.py | 4 +- api/core/ops/opik_trace/opik_trace.py | 4 +- api/core/ops/weave_trace/weave_trace.py | 4 +- api/core/repositories/__init__.py | 4 +- api/core/repositories/factory.py | 6 +- api/repositories/__init__.py | 0 .../api_workflow_node_execution_repository.py | 196 ++++++++++++ api/repositories/factory.py | 66 ++++ ..._api_workflow_node_execution_repository.py | 286 ++++++++++++++++++ .../clear_free_plan_tenant_expired_logs.py | 81 ++--- .../workflow_draft_variable_service.py | 15 +- api/services/workflow_run_service.py | 28 +- api/services/workflow_service.py | 53 ++-- api/tasks/remove_app_and_related_data_task.py | 24 +- .../core/repositories/test_factory.py | 82 ++--- ...kflow_node_execution_service_repository.py | 278 +++++++++++++++++ 20 files changed, 1016 insertions(+), 157 deletions(-) create mode 100644 api/repositories/__init__.py create mode 100644 api/repositories/api_workflow_node_execution_repository.py create mode 100644 api/repositories/factory.py create mode 100644 api/repositories/sqlalchemy_api_workflow_node_execution_repository.py create mode 100644 api/tests/unit_tests/services/workflow/test_workflow_node_execution_service_repository.py diff --git a/api/configs/feature/__init__.py b/api/configs/feature/__init__.py index cca9f252c9..7e1ada988c 100644 --- a/api/configs/feature/__init__.py +++ b/api/configs/feature/__init__.py @@ -542,16 +542,22 @@ class RepositoryConfig(BaseSettings): Configuration for repository implementations """ - 
WORKFLOW_EXECUTION_REPOSITORY: str = Field( + CORE_WORKFLOW_EXECUTION_REPOSITORY: str = Field( description="Repository implementation for WorkflowExecution. Specify as a module path", default="core.repositories.sqlalchemy_workflow_execution_repository.SQLAlchemyWorkflowExecutionRepository", ) - WORKFLOW_NODE_EXECUTION_REPOSITORY: str = Field( + CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY: str = Field( description="Repository implementation for WorkflowNodeExecution. Specify as a module path", default="core.repositories.sqlalchemy_workflow_node_execution_repository.SQLAlchemyWorkflowNodeExecutionRepository", ) + API_WORKFLOW_NODE_EXECUTION_REPOSITORY: str = Field( + description="Service-layer repository implementation for WorkflowNodeExecutionModel operations. " + "Specify as a module path", + default="repositories.sqlalchemy_api_workflow_node_execution_repository.DifyAPISQLAlchemyWorkflowNodeExecutionRepository", + ) + class AuthConfig(BaseSettings): """ diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py index 649a7172d4..4b8f5ebe27 100644 --- a/api/core/app/apps/advanced_chat/app_generator.py +++ b/api/core/app/apps/advanced_chat/app_generator.py @@ -25,7 +25,7 @@ from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotA from core.model_runtime.errors.invoke import InvokeAuthorizationError from core.ops.ops_trace_manager import TraceQueueManager from core.prompt.utils.get_thread_messages_length import get_thread_messages_length -from core.repositories import RepositoryFactory +from core.repositories import DifyCoreRepositoryFactory from core.workflow.repositories.draft_variable_repository import ( DraftVariableSaverFactory, ) @@ -182,14 +182,14 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): workflow_triggered_from = WorkflowRunTriggeredFrom.DEBUGGING else: workflow_triggered_from = WorkflowRunTriggeredFrom.APP_RUN - workflow_execution_repository = 
RepositoryFactory.create_workflow_execution_repository( + workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository( session_factory=session_factory, user=user, app_id=application_generate_entity.app_config.app_id, triggered_from=workflow_triggered_from, ) # Create workflow node execution repository - workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository( + workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository( session_factory=session_factory, user=user, app_id=application_generate_entity.app_config.app_id, @@ -259,14 +259,14 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): # Create session factory session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) # Create workflow execution(aka workflow run) repository - workflow_execution_repository = RepositoryFactory.create_workflow_execution_repository( + workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository( session_factory=session_factory, user=user, app_id=application_generate_entity.app_config.app_id, triggered_from=WorkflowRunTriggeredFrom.DEBUGGING, ) # Create workflow node execution repository - workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository( + workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository( session_factory=session_factory, user=user, app_id=application_generate_entity.app_config.app_id, @@ -342,14 +342,14 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): # Create session factory session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) # Create workflow execution(aka workflow run) repository - workflow_execution_repository = RepositoryFactory.create_workflow_execution_repository( + workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository( 
session_factory=session_factory, user=user, app_id=application_generate_entity.app_config.app_id, triggered_from=WorkflowRunTriggeredFrom.DEBUGGING, ) # Create workflow node execution repository - workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository( + workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository( session_factory=session_factory, user=user, app_id=application_generate_entity.app_config.app_id, diff --git a/api/core/app/apps/workflow/app_generator.py b/api/core/app/apps/workflow/app_generator.py index f1203dfa4a..2f9632e97d 100644 --- a/api/core/app/apps/workflow/app_generator.py +++ b/api/core/app/apps/workflow/app_generator.py @@ -23,7 +23,7 @@ from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerat from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse from core.model_runtime.errors.invoke import InvokeAuthorizationError from core.ops.ops_trace_manager import TraceQueueManager -from core.repositories import RepositoryFactory +from core.repositories import DifyCoreRepositoryFactory from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository @@ -155,14 +155,14 @@ class WorkflowAppGenerator(BaseAppGenerator): workflow_triggered_from = WorkflowRunTriggeredFrom.DEBUGGING else: workflow_triggered_from = WorkflowRunTriggeredFrom.APP_RUN - workflow_execution_repository = RepositoryFactory.create_workflow_execution_repository( + workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository( session_factory=session_factory, user=user, app_id=application_generate_entity.app_config.app_id, triggered_from=workflow_triggered_from, ) # Create 
workflow node execution repository - workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository( + workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository( session_factory=session_factory, user=user, app_id=application_generate_entity.app_config.app_id, @@ -305,14 +305,14 @@ class WorkflowAppGenerator(BaseAppGenerator): # Create session factory session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) # Create workflow execution(aka workflow run) repository - workflow_execution_repository = RepositoryFactory.create_workflow_execution_repository( + workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository( session_factory=session_factory, user=user, app_id=application_generate_entity.app_config.app_id, triggered_from=WorkflowRunTriggeredFrom.DEBUGGING, ) # Create workflow node execution repository - workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository( + workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository( session_factory=session_factory, user=user, app_id=application_generate_entity.app_config.app_id, @@ -387,14 +387,14 @@ class WorkflowAppGenerator(BaseAppGenerator): # Create session factory session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) # Create workflow execution(aka workflow run) repository - workflow_execution_repository = RepositoryFactory.create_workflow_execution_repository( + workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository( session_factory=session_factory, user=user, app_id=application_generate_entity.app_config.app_id, triggered_from=WorkflowRunTriggeredFrom.DEBUGGING, ) # Create workflow node execution repository - workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository( + workflow_node_execution_repository = 
DifyCoreRepositoryFactory.create_workflow_node_execution_repository( session_factory=session_factory, user=user, app_id=application_generate_entity.app_config.app_id, diff --git a/api/core/ops/langfuse_trace/langfuse_trace.py b/api/core/ops/langfuse_trace/langfuse_trace.py index d0b228f4ba..4a7e66d27c 100644 --- a/api/core/ops/langfuse_trace/langfuse_trace.py +++ b/api/core/ops/langfuse_trace/langfuse_trace.py @@ -28,7 +28,7 @@ from core.ops.langfuse_trace.entities.langfuse_trace_entity import ( UnitEnum, ) from core.ops.utils import filter_none_values -from core.repositories import RepositoryFactory +from core.repositories import DifyCoreRepositoryFactory from core.workflow.nodes.enums import NodeType from extensions.ext_database import db from models import EndUser, WorkflowNodeExecutionTriggeredFrom @@ -123,7 +123,7 @@ class LangFuseDataTrace(BaseTraceInstance): service_account = self.get_service_account_with_tenant(app_id) - workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository( + workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository( session_factory=session_factory, user=service_account, app_id=app_id, diff --git a/api/core/ops/langsmith_trace/langsmith_trace.py b/api/core/ops/langsmith_trace/langsmith_trace.py index f3f08d74b8..8a559c4929 100644 --- a/api/core/ops/langsmith_trace/langsmith_trace.py +++ b/api/core/ops/langsmith_trace/langsmith_trace.py @@ -27,7 +27,7 @@ from core.ops.langsmith_trace.entities.langsmith_trace_entity import ( LangSmithRunUpdateModel, ) from core.ops.utils import filter_none_values, generate_dotted_order -from core.repositories import RepositoryFactory +from core.repositories import DifyCoreRepositoryFactory from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey from core.workflow.nodes.enums import NodeType from extensions.ext_database import db @@ -145,7 +145,7 @@ class 
LangSmithDataTrace(BaseTraceInstance): service_account = self.get_service_account_with_tenant(app_id) - workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository( + workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository( session_factory=session_factory, user=service_account, app_id=app_id, diff --git a/api/core/ops/opik_trace/opik_trace.py b/api/core/ops/opik_trace/opik_trace.py index 31ce6fe6c8..fcbbc70fc3 100644 --- a/api/core/ops/opik_trace/opik_trace.py +++ b/api/core/ops/opik_trace/opik_trace.py @@ -21,7 +21,7 @@ from core.ops.entities.trace_entity import ( TraceTaskName, WorkflowTraceInfo, ) -from core.repositories import RepositoryFactory +from core.repositories import DifyCoreRepositoryFactory from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey from core.workflow.nodes.enums import NodeType from extensions.ext_database import db @@ -160,7 +160,7 @@ class OpikDataTrace(BaseTraceInstance): service_account = self.get_service_account_with_tenant(app_id) - workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository( + workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository( session_factory=session_factory, user=service_account, app_id=app_id, diff --git a/api/core/ops/weave_trace/weave_trace.py b/api/core/ops/weave_trace/weave_trace.py index 95cb0dd621..445c6a8741 100644 --- a/api/core/ops/weave_trace/weave_trace.py +++ b/api/core/ops/weave_trace/weave_trace.py @@ -22,7 +22,7 @@ from core.ops.entities.trace_entity import ( WorkflowTraceInfo, ) from core.ops.weave_trace.entities.weave_trace_entity import WeaveTraceModel -from core.repositories import RepositoryFactory +from core.repositories import DifyCoreRepositoryFactory from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey from core.workflow.nodes.enums import 
NodeType from extensions.ext_database import db @@ -144,7 +144,7 @@ class WeaveDataTrace(BaseTraceInstance): service_account = self.get_service_account_with_tenant(app_id) - workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository( + workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository( session_factory=session_factory, user=service_account, app_id=app_id, diff --git a/api/core/repositories/__init__.py b/api/core/repositories/__init__.py index bb5b3224ff..052ba1c2cb 100644 --- a/api/core/repositories/__init__.py +++ b/api/core/repositories/__init__.py @@ -5,11 +5,11 @@ This package contains concrete implementations of the repository interfaces defined in the core.workflow.repository package. """ -from core.repositories.factory import RepositoryFactory, RepositoryImportError +from core.repositories.factory import DifyCoreRepositoryFactory, RepositoryImportError from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository __all__ = [ - "RepositoryFactory", + "DifyCoreRepositoryFactory", "RepositoryImportError", "SQLAlchemyWorkflowNodeExecutionRepository", ] diff --git a/api/core/repositories/factory.py b/api/core/repositories/factory.py index 646b587244..a3eaa0ee5c 100644 --- a/api/core/repositories/factory.py +++ b/api/core/repositories/factory.py @@ -28,7 +28,7 @@ class RepositoryImportError(Exception): pass -class RepositoryFactory: +class DifyCoreRepositoryFactory: """ Factory for creating repository instances based on configuration. 
@@ -143,7 +143,7 @@ class RepositoryFactory: Raises: RepositoryImportError: If the configured repository cannot be created """ - class_path = dify_config.WORKFLOW_EXECUTION_REPOSITORY + class_path = dify_config.CORE_WORKFLOW_EXECUTION_REPOSITORY logger.debug(f"Creating WorkflowExecutionRepository from: {class_path}") try: @@ -189,7 +189,7 @@ class RepositoryFactory: Raises: RepositoryImportError: If the configured repository cannot be created """ - class_path = dify_config.WORKFLOW_NODE_EXECUTION_REPOSITORY + class_path = dify_config.CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY logger.debug(f"Creating WorkflowNodeExecutionRepository from: {class_path}") try: diff --git a/api/repositories/__init__.py b/api/repositories/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/repositories/api_workflow_node_execution_repository.py b/api/repositories/api_workflow_node_execution_repository.py new file mode 100644 index 0000000000..9e36e0ef7e --- /dev/null +++ b/api/repositories/api_workflow_node_execution_repository.py @@ -0,0 +1,196 @@ +""" +Service-layer repository protocol for WorkflowNodeExecutionModel operations. + +This module provides a protocol interface for service-layer operations on WorkflowNodeExecutionModel +that abstracts database queries currently done directly in service classes. This repository is +specifically designed for service-layer needs and is separate from the core domain repository. + +The service repository handles operations that require access to database-specific fields like +tenant_id, app_id, triggered_from, etc., which are not part of the core domain model. +""" + +from abc import abstractmethod +from collections.abc import Sequence +from datetime import datetime +from typing import Optional, Protocol + +from models.workflow import WorkflowNodeExecutionModel + + +class DifyAPIWorkflowNodeExecutionRepository(Protocol): + """ + Protocol for service-layer operations on WorkflowNodeExecutionModel. 
+ + This repository provides database access patterns specifically needed by service classes, + handling queries that involve database-specific fields and multi-tenancy concerns. + + Key responsibilities: + - Manages database operations for workflow node executions + - Handles multi-tenant data isolation + - Provides batch processing capabilities + - Supports execution lifecycle management + + Implementation notes: + - Returns database models directly (WorkflowNodeExecutionModel) + - Handles tenant/app filtering automatically + - Provides service-specific query patterns + - Focuses on database operations without domain logic + - Supports cleanup and maintenance operations + """ + + @abstractmethod + def get_node_last_execution( + self, + tenant_id: str, + app_id: str, + workflow_id: str, + node_id: str, + ) -> Optional[WorkflowNodeExecutionModel]: + """ + Get the most recent execution for a specific node. + + This method finds the latest execution of a specific node within a workflow, + ordered by creation time. Used primarily for debugging and inspection purposes. + + Args: + tenant_id: The tenant identifier + app_id: The application identifier + workflow_id: The workflow identifier + node_id: The node identifier + + Returns: + The most recent WorkflowNodeExecutionModel for the node, or None if not found + """ + ... + + @abstractmethod + def get_executions_by_workflow_run( + self, + tenant_id: str, + app_id: str, + workflow_run_id: str, + ) -> Sequence[WorkflowNodeExecutionModel]: + """ + Get all node executions for a specific workflow run. + + This method retrieves all node executions that belong to a specific workflow run, + ordered by index in descending order for proper trace visualization. + + Args: + tenant_id: The tenant identifier + app_id: The application identifier + workflow_run_id: The workflow run identifier + + Returns: + A sequence of WorkflowNodeExecutionModel instances ordered by index (desc) + """ + ... 
+ + @abstractmethod + def get_execution_by_id( + self, + execution_id: str, + tenant_id: Optional[str] = None, + ) -> Optional[WorkflowNodeExecutionModel]: + """ + Get a workflow node execution by its ID. + + This method retrieves a specific execution by its unique identifier. + Tenant filtering is optional for cases where the execution ID is globally unique. + + Args: + execution_id: The execution identifier + tenant_id: Optional tenant identifier for additional filtering + + Returns: + The WorkflowNodeExecutionModel if found, or None if not found + """ + ... + + @abstractmethod + def delete_expired_executions( + self, + tenant_id: str, + before_date: datetime, + batch_size: int = 1000, + ) -> int: + """ + Delete workflow node executions that are older than the specified date. + + This method is used for cleanup operations to remove expired executions + in batches to avoid overwhelming the database. + + Args: + tenant_id: The tenant identifier + before_date: Delete executions created before this date + batch_size: Maximum number of executions to delete in one batch + + Returns: + The number of executions deleted + """ + ... + + @abstractmethod + def delete_executions_by_app( + self, + tenant_id: str, + app_id: str, + batch_size: int = 1000, + ) -> int: + """ + Delete all workflow node executions for a specific app. + + This method is used when removing an app and all its related data. + Executions are deleted in batches to avoid overwhelming the database. + + Args: + tenant_id: The tenant identifier + app_id: The application identifier + batch_size: Maximum number of executions to delete in one batch + + Returns: + The total number of executions deleted + """ + ... + + @abstractmethod + def get_expired_executions_batch( + self, + tenant_id: str, + before_date: datetime, + batch_size: int = 1000, + ) -> Sequence[WorkflowNodeExecutionModel]: + """ + Get a batch of expired workflow node executions for backup purposes. 
+ + This method retrieves expired executions without deleting them, + allowing the caller to backup the data before deletion. + + Args: + tenant_id: The tenant identifier + before_date: Get executions created before this date + batch_size: Maximum number of executions to retrieve + + Returns: + A sequence of WorkflowNodeExecutionModel instances + """ + ... + + @abstractmethod + def delete_executions_by_ids( + self, + execution_ids: Sequence[str], + ) -> int: + """ + Delete workflow node executions by their IDs. + + This method deletes specific executions by their IDs, + typically used after backing up the data. + + Args: + execution_ids: List of execution IDs to delete + + Returns: + The number of executions deleted + """ + ... diff --git a/api/repositories/factory.py b/api/repositories/factory.py new file mode 100644 index 0000000000..fe0bf17441 --- /dev/null +++ b/api/repositories/factory.py @@ -0,0 +1,66 @@ +""" +DifyAPI Repository Factory for creating repository instances. + +This factory is specifically designed for DifyAPI repositories that handle +service-layer operations with dependency injection patterns. +""" + +import logging + +from sqlalchemy.orm import sessionmaker + +from configs import dify_config +from core.repositories import DifyCoreRepositoryFactory, RepositoryImportError +from repositories.api_workflow_node_execution_repository import DifyAPIWorkflowNodeExecutionRepository + +logger = logging.getLogger(__name__) + + +class DifyAPIRepositoryFactory(DifyCoreRepositoryFactory): + """ + Factory for creating DifyAPI repository instances based on configuration. + + This factory handles the creation of repositories that are specifically designed + for service-layer operations and use dependency injection with sessionmaker + for better testability and separation of concerns. 
+ """ + + @classmethod + def create_api_workflow_node_execution_repository( + cls, session_maker: sessionmaker + ) -> DifyAPIWorkflowNodeExecutionRepository: + """ + Create a DifyAPIWorkflowNodeExecutionRepository instance based on configuration. + + This repository is designed for service-layer operations and uses dependency injection + with a sessionmaker for better testability and separation of concerns. It provides + database access patterns specifically needed by service classes, handling queries + that involve database-specific fields and multi-tenancy concerns. + + Args: + session_maker: SQLAlchemy sessionmaker to inject for database session management. + + Returns: + Configured DifyAPIWorkflowNodeExecutionRepository instance + + Raises: + RepositoryImportError: If the configured repository cannot be imported or instantiated + """ + class_path = dify_config.API_WORKFLOW_NODE_EXECUTION_REPOSITORY + logger.debug(f"Creating DifyAPIWorkflowNodeExecutionRepository from: {class_path}") + + try: + repository_class = cls._import_class(class_path) + cls._validate_repository_interface(repository_class, DifyAPIWorkflowNodeExecutionRepository) + # Service repository requires session_maker parameter + cls._validate_constructor_signature(repository_class, ["session_maker"]) + + return repository_class(session_maker=session_maker) # type: ignore[no-any-return] + except RepositoryImportError: + # Re-raise our custom errors as-is + raise + except Exception as e: + logger.exception("Failed to create DifyAPIWorkflowNodeExecutionRepository") + raise RepositoryImportError( + f"Failed to create DifyAPIWorkflowNodeExecutionRepository from '{class_path}': {e}" + ) from e diff --git a/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py b/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py new file mode 100644 index 0000000000..ccde8b8076 --- /dev/null +++ b/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py @@ -0,0 +1,286 @@ 
+""" +SQLAlchemy implementation of DifyAPIWorkflowNodeExecutionRepository. + +This module provides a concrete implementation of the service repository protocol +using SQLAlchemy 2.0 style queries for WorkflowNodeExecutionModel operations. +""" + +from collections.abc import Sequence +from datetime import datetime +from typing import Optional + +from sqlalchemy import delete, desc, select +from sqlalchemy.orm import Session, sessionmaker + +from models.workflow import WorkflowNodeExecutionModel +from repositories.api_workflow_node_execution_repository import DifyAPIWorkflowNodeExecutionRepository + + +class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRepository): + """ + SQLAlchemy implementation of DifyAPIWorkflowNodeExecutionRepository. + + This repository provides service-layer database operations for WorkflowNodeExecutionModel + using SQLAlchemy 2.0 style queries. It implements the DifyAPIWorkflowNodeExecutionRepository + protocol with the following features: + + - Multi-tenancy data isolation through tenant_id filtering + - Direct database model operations without domain conversion + - Batch processing for efficient large-scale operations + - Optimized query patterns for common access patterns + - Dependency injection for better testability and maintainability + - Session management and transaction handling with proper cleanup + - Maintenance operations for data lifecycle management + - Thread-safe database operations using session-per-request pattern + """ + + def __init__(self, session_maker: sessionmaker[Session]): + """ + Initialize the repository with a sessionmaker. + + Args: + session_maker: SQLAlchemy sessionmaker for creating database sessions + """ + self._session_maker = session_maker + + def get_node_last_execution( + self, + tenant_id: str, + app_id: str, + workflow_id: str, + node_id: str, + ) -> Optional[WorkflowNodeExecutionModel]: + """ + Get the most recent execution for a specific node.
+ + This method replicates the query pattern from WorkflowService.get_node_last_run() + using SQLAlchemy 2.0 style syntax. + + Args: + tenant_id: The tenant identifier + app_id: The application identifier + workflow_id: The workflow identifier + node_id: The node identifier + + Returns: + The most recent WorkflowNodeExecutionModel for the node, or None if not found + """ + stmt = ( + select(WorkflowNodeExecutionModel) + .where( + WorkflowNodeExecutionModel.tenant_id == tenant_id, + WorkflowNodeExecutionModel.app_id == app_id, + WorkflowNodeExecutionModel.workflow_id == workflow_id, + WorkflowNodeExecutionModel.node_id == node_id, + ) + .order_by(desc(WorkflowNodeExecutionModel.created_at)) + .limit(1) + ) + + with self._session_maker() as session: + return session.scalar(stmt) + + def get_executions_by_workflow_run( + self, + tenant_id: str, + app_id: str, + workflow_run_id: str, + ) -> Sequence[WorkflowNodeExecutionModel]: + """ + Get all node executions for a specific workflow run. + + This method replicates the query pattern from WorkflowRunService.get_workflow_run_node_executions() + using SQLAlchemy 2.0 style syntax. + + Args: + tenant_id: The tenant identifier + app_id: The application identifier + workflow_run_id: The workflow run identifier + + Returns: + A sequence of WorkflowNodeExecutionModel instances ordered by index (desc) + """ + stmt = ( + select(WorkflowNodeExecutionModel) + .where( + WorkflowNodeExecutionModel.tenant_id == tenant_id, + WorkflowNodeExecutionModel.app_id == app_id, + WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id, + ) + .order_by(desc(WorkflowNodeExecutionModel.index)) + ) + + with self._session_maker() as session: + return session.execute(stmt).scalars().all() + + def get_execution_by_id( + self, + execution_id: str, + tenant_id: Optional[str] = None, + ) -> Optional[WorkflowNodeExecutionModel]: + """ + Get a workflow node execution by its ID. 
+ + This method replicates the query pattern from WorkflowDraftVariableService + and WorkflowService.single_step_run_workflow_node() using SQLAlchemy 2.0 style syntax. + + Args: + execution_id: The execution identifier + tenant_id: Optional tenant identifier for additional filtering + + Returns: + The WorkflowNodeExecutionModel if found, or None if not found + """ + stmt = select(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.id == execution_id) + + # Add tenant filtering if provided + if tenant_id is not None: + stmt = stmt.where(WorkflowNodeExecutionModel.tenant_id == tenant_id) + + with self._session_maker() as session: + return session.scalar(stmt) + + def delete_expired_executions( + self, + tenant_id: str, + before_date: datetime, + batch_size: int = 1000, + ) -> int: + """ + Delete workflow node executions that are older than the specified date. + + Args: + tenant_id: The tenant identifier + before_date: Delete executions created before this date + batch_size: Maximum number of executions to delete in one batch + + Returns: + The number of executions deleted + """ + total_deleted = 0 + + while True: + with self._session_maker() as session: + # Find executions to delete in batches + stmt = ( + select(WorkflowNodeExecutionModel.id) + .where( + WorkflowNodeExecutionModel.tenant_id == tenant_id, + WorkflowNodeExecutionModel.created_at < before_date, + ) + .limit(batch_size) + ) + + execution_ids = session.execute(stmt).scalars().all() + if not execution_ids: + break + + # Delete the batch + delete_stmt = delete(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.id.in_(execution_ids)) + result = session.execute(delete_stmt) + session.commit() + total_deleted += result.rowcount + + # If we deleted fewer than the batch size, we're done + if len(execution_ids) < batch_size: + break + + return total_deleted + + def delete_executions_by_app( + self, + tenant_id: str, + app_id: str, + batch_size: int = 1000, + ) -> int: + """ + Delete all 
workflow node executions for a specific app. + + Args: + tenant_id: The tenant identifier + app_id: The application identifier + batch_size: Maximum number of executions to delete in one batch + + Returns: + The total number of executions deleted + """ + total_deleted = 0 + + while True: + with self._session_maker() as session: + # Find executions to delete in batches + stmt = ( + select(WorkflowNodeExecutionModel.id) + .where( + WorkflowNodeExecutionModel.tenant_id == tenant_id, + WorkflowNodeExecutionModel.app_id == app_id, + ) + .limit(batch_size) + ) + + execution_ids = session.execute(stmt).scalars().all() + if not execution_ids: + break + + # Delete the batch + delete_stmt = delete(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.id.in_(execution_ids)) + result = session.execute(delete_stmt) + session.commit() + total_deleted += result.rowcount + + # If we deleted fewer than the batch size, we're done + if len(execution_ids) < batch_size: + break + + return total_deleted + + def get_expired_executions_batch( + self, + tenant_id: str, + before_date: datetime, + batch_size: int = 1000, + ) -> Sequence[WorkflowNodeExecutionModel]: + """ + Get a batch of expired workflow node executions for backup purposes. + + Args: + tenant_id: The tenant identifier + before_date: Get executions created before this date + batch_size: Maximum number of executions to retrieve + + Returns: + A sequence of WorkflowNodeExecutionModel instances + """ + stmt = ( + select(WorkflowNodeExecutionModel) + .where( + WorkflowNodeExecutionModel.tenant_id == tenant_id, + WorkflowNodeExecutionModel.created_at < before_date, + ) + .limit(batch_size) + ) + + with self._session_maker() as session: + return session.execute(stmt).scalars().all() + + def delete_executions_by_ids( + self, + execution_ids: Sequence[str], + ) -> int: + """ + Delete workflow node executions by their IDs. 
+ + Args: + execution_ids: List of execution IDs to delete + + Returns: + The number of executions deleted + """ + if not execution_ids: + return 0 + + with self._session_maker() as session: + stmt = delete(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.id.in_(execution_ids)) + result = session.execute(stmt) + session.commit() + return result.rowcount diff --git a/api/services/clear_free_plan_tenant_expired_logs.py b/api/services/clear_free_plan_tenant_expired_logs.py index 1fd560d581..cb0115cb5a 100644 --- a/api/services/clear_free_plan_tenant_expired_logs.py +++ b/api/services/clear_free_plan_tenant_expired_logs.py @@ -6,7 +6,7 @@ from concurrent.futures import ThreadPoolExecutor import click from flask import Flask, current_app -from sqlalchemy.orm import Session +from sqlalchemy.orm import Session, sessionmaker from configs import dify_config from core.model_runtime.utils.encoders import jsonable_encoder @@ -14,7 +14,8 @@ from extensions.ext_database import db from extensions.ext_storage import storage from models.account import Tenant from models.model import App, Conversation, Message -from models.workflow import WorkflowNodeExecutionModel, WorkflowRun +from models.workflow import WorkflowRun +from repositories.factory import DifyAPIRepositoryFactory from services.billing_service import BillingService logger = logging.getLogger(__name__) @@ -105,48 +106,52 @@ class ClearFreePlanTenantExpiredLogs: ) ) - while True: - with Session(db.engine).no_autoflush as session: - workflow_node_executions = ( - session.query(WorkflowNodeExecutionModel) - .filter( - WorkflowNodeExecutionModel.tenant_id == tenant_id, - WorkflowNodeExecutionModel.created_at - < datetime.datetime.now() - datetime.timedelta(days=days), - ) - .limit(batch) - .all() - ) + # Process expired workflow node executions with backup + session_maker = sessionmaker(bind=db.engine, expire_on_commit=False) + node_execution_repo = 
DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository(session_maker) + before_date = datetime.datetime.now() - datetime.timedelta(days=days) + total_deleted = 0 - if len(workflow_node_executions) == 0: - break + while True: + # Get a batch of expired executions for backup + workflow_node_executions = node_execution_repo.get_expired_executions_batch( + tenant_id=tenant_id, + before_date=before_date, + batch_size=batch, + ) - # save workflow node executions - storage.save( - f"free_plan_tenant_expired_logs/" - f"{tenant_id}/workflow_node_executions/{datetime.datetime.now().strftime('%Y-%m-%d')}" - f"-{time.time()}.json", - json.dumps( - jsonable_encoder(workflow_node_executions), - ).encode("utf-8"), - ) + if len(workflow_node_executions) == 0: + break + + # Save workflow node executions to storage + storage.save( + f"free_plan_tenant_expired_logs/" + f"{tenant_id}/workflow_node_executions/{datetime.datetime.now().strftime('%Y-%m-%d')}" + f"-{time.time()}.json", + json.dumps( + jsonable_encoder(workflow_node_executions), + ).encode("utf-8"), + ) - workflow_node_execution_ids = [ - workflow_node_execution.id for workflow_node_execution in workflow_node_executions - ] + # Extract IDs for deletion + workflow_node_execution_ids = [ + workflow_node_execution.id for workflow_node_execution in workflow_node_executions + ] - # delete workflow node executions - session.query(WorkflowNodeExecutionModel).filter( - WorkflowNodeExecutionModel.id.in_(workflow_node_execution_ids), - ).delete(synchronize_session=False) - session.commit() + # Delete the backed up executions + deleted_count = node_execution_repo.delete_executions_by_ids(workflow_node_execution_ids) + total_deleted += deleted_count - click.echo( - click.style( - f"[{datetime.datetime.now()}] Processed {len(workflow_node_execution_ids)}" - f" workflow node executions for tenant {tenant_id}" - ) + click.echo( + click.style( + f"[{datetime.datetime.now()}] Processed {len(workflow_node_execution_ids)}" + 
f" workflow node executions for tenant {tenant_id}" ) + ) + + # If we got fewer than the batch size, we're done + if len(workflow_node_executions) < batch: + break while True: with Session(db.engine).no_autoflush as session: diff --git a/api/services/workflow_draft_variable_service.py b/api/services/workflow_draft_variable_service.py index 44fd72b5e4..0cb8c5574b 100644 --- a/api/services/workflow_draft_variable_service.py +++ b/api/services/workflow_draft_variable_service.py @@ -5,9 +5,9 @@ from collections.abc import Mapping, Sequence from enum import StrEnum from typing import Any, ClassVar -from sqlalchemy import Engine, orm, select +from sqlalchemy import Engine, orm from sqlalchemy.dialects.postgresql import insert -from sqlalchemy.orm import Session +from sqlalchemy.orm import Session, sessionmaker from sqlalchemy.sql.expression import and_, or_ from core.app.entities.app_invoke_entities import InvokeFrom @@ -21,11 +21,13 @@ from core.workflow.enums import SystemVariableKey from core.workflow.nodes import NodeType from core.workflow.nodes.variable_assigner.common.helpers import get_updated_variables from core.workflow.variable_loader import VariableLoader +from extensions.ext_database import db from factories.file_factory import StorageKeyLoader from factories.variable_factory import build_segment, segment_to_variable from models import App, Conversation from models.enums import DraftVariableType -from models.workflow import Workflow, WorkflowDraftVariable, WorkflowNodeExecutionModel, is_system_variable_editable +from models.workflow import Workflow, WorkflowDraftVariable, is_system_variable_editable +from repositories.factory import DifyAPIRepositoryFactory _logger = logging.getLogger(__name__) @@ -118,6 +120,10 @@ class WorkflowDraftVariableService: def __init__(self, session: Session) -> None: self._session = session + session_maker = sessionmaker(bind=db.engine, expire_on_commit=False) + self._node_execution_service_repo = 
DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository( + session_maker + ) def get_variable(self, variable_id: str) -> WorkflowDraftVariable | None: return self._session.query(WorkflowDraftVariable).filter(WorkflowDraftVariable.id == variable_id).first() @@ -248,8 +254,7 @@ class WorkflowDraftVariableService: _logger.warning("draft variable has no node_execution_id, id=%s, name=%s", variable.id, variable.name) return None - query = select(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.id == variable.node_execution_id) - node_exec = self._session.scalars(query).first() + node_exec = self._node_execution_service_repo.get_execution_by_id(variable.node_execution_id) if node_exec is None: _logger.warning( "Node exectution not found for draft variable, id=%s, name=%s, node_execution_id=%s", diff --git a/api/services/workflow_run_service.py b/api/services/workflow_run_service.py index 4acf1206b1..7c57c88317 100644 --- a/api/services/workflow_run_service.py +++ b/api/services/workflow_run_service.py @@ -2,7 +2,7 @@ import threading from collections.abc import Sequence from typing import Optional -from sqlalchemy import desc, select +from sqlalchemy.orm import sessionmaker import contexts from extensions.ext_database import db @@ -15,9 +15,17 @@ from models import ( WorkflowRun, WorkflowRunTriggeredFrom, ) +from repositories.factory import DifyAPIRepositoryFactory class WorkflowRunService: + def __init__(self): + """Initialize WorkflowRunService with repository dependencies.""" + session_maker = sessionmaker(bind=db.engine, expire_on_commit=False) + self._node_execution_service_repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository( + session_maker + ) + def get_paginate_advanced_chat_workflow_runs(self, app_model: App, args: dict) -> InfiniteScrollPagination: """ Get advanced chat app workflow run list @@ -138,17 +146,11 @@ class WorkflowRunService: # Get tenant_id from user tenant_id = user.tenant_id if 
isinstance(user, EndUser) else user.current_tenant_id + if tenant_id is None: + raise ValueError("User tenant_id cannot be None") - # Use SQLAlchemy 2.0 style query directly - stmt = ( - select(WorkflowNodeExecutionModel) - .where( - WorkflowNodeExecutionModel.tenant_id == tenant_id, - WorkflowNodeExecutionModel.app_id == app_model.id, - WorkflowNodeExecutionModel.workflow_run_id == run_id, - ) - .order_by(desc(WorkflowNodeExecutionModel.index)) + return self._node_execution_service_repo.get_executions_by_workflow_run( + tenant_id=tenant_id, + app_id=app_model.id, + workflow_run_id=run_id, ) - - workflow_node_executions = db.session.execute(stmt).scalars().all() - return workflow_node_executions diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index e38858f73e..8122505592 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -7,13 +7,13 @@ from typing import Any, Optional from uuid import uuid4 from sqlalchemy import select -from sqlalchemy.orm import Session +from sqlalchemy.orm import Session, sessionmaker from core.app.app_config.entities import VariableEntityType from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager from core.file import File -from core.repositories import RepositoryFactory +from core.repositories import DifyCoreRepositoryFactory from core.variables import Variable from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.variable_pool import VariablePool @@ -41,6 +41,7 @@ from models.workflow import ( WorkflowNodeExecutionTriggeredFrom, WorkflowType, ) +from repositories.factory import DifyAPIRepositoryFactory from services.errors.app import IsDraftWorkflowError, WorkflowHashNotEqualError from services.workflow.workflow_converter import WorkflowConverter @@ -57,21 +58,31 @@ class WorkflowService: Workflow Service """ - def 
get_node_last_run(self, app_model: App, workflow: Workflow, node_id: str) -> WorkflowNodeExecutionModel | None: - # TODO(QuantumGhost): This query is not fully covered by index. - criteria = ( - WorkflowNodeExecutionModel.tenant_id == app_model.tenant_id, - WorkflowNodeExecutionModel.app_id == app_model.id, - WorkflowNodeExecutionModel.workflow_id == workflow.id, - WorkflowNodeExecutionModel.node_id == node_id, + def __init__(self): + """Initialize WorkflowService with repository dependencies.""" + session_maker = sessionmaker(bind=db.engine, expire_on_commit=False) + self._node_execution_service_repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository( + session_maker ) - node_exec = ( - db.session.query(WorkflowNodeExecutionModel) - .filter(*criteria) - .order_by(WorkflowNodeExecutionModel.created_at.desc()) - .first() + + def get_node_last_run(self, app_model: App, workflow: Workflow, node_id: str) -> WorkflowNodeExecutionModel | None: + """ + Get the most recent execution for a specific node. 
+ + Args: + app_model: The application model + workflow: The workflow model + node_id: The node identifier + + Returns: + The most recent WorkflowNodeExecutionModel for the node, or None if not found + """ + return self._node_execution_service_repo.get_node_last_execution( + tenant_id=app_model.tenant_id, + app_id=app_model.id, + workflow_id=workflow.id, + node_id=node_id, ) - return node_exec def is_workflow_exist(self, app_model: App) -> bool: return ( @@ -396,7 +407,7 @@ class WorkflowService: node_execution.workflow_id = draft_workflow.id # Create repository and save the node execution - repository = RepositoryFactory.create_workflow_node_execution_repository( + repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository( session_factory=db.engine, user=account, app_id=app_model.id, @@ -404,8 +415,9 @@ class WorkflowService: ) repository.save(node_execution) - stmt = select(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.id == node_execution.id) - workflow_node_execution = db.session.execute(stmt).scalar_one() + workflow_node_execution = self._node_execution_service_repo.get_execution_by_id(node_execution.id) + if workflow_node_execution is None: + raise ValueError(f"WorkflowNodeExecution with id {node_execution.id} not found after saving") with Session(bind=db.engine) as session, session.begin(): draft_var_saver = DraftVariableSaver( @@ -418,6 +430,7 @@ class WorkflowService: ) draft_var_saver.save(process_data=node_execution.process_data, outputs=node_execution.outputs) session.commit() + return workflow_node_execution def run_free_workflow_node( @@ -429,7 +442,7 @@ class WorkflowService: # run draft workflow node start_at = time.perf_counter() - workflow_node_execution = self._handle_node_run_result( + node_execution = self._handle_node_run_result( invoke_node_fn=lambda: WorkflowEntry.run_free_node( node_id=node_id, node_data=node_data, @@ -441,7 +454,7 @@ class WorkflowService: node_id=node_id, ) - return 
workflow_node_execution + return node_execution def _handle_node_run_result( self, diff --git a/api/tasks/remove_app_and_related_data_task.py b/api/tasks/remove_app_and_related_data_task.py index 4a62cb74b4..9d781d9364 100644 --- a/api/tasks/remove_app_and_related_data_task.py +++ b/api/tasks/remove_app_and_related_data_task.py @@ -6,6 +6,7 @@ import click from celery import shared_task # type: ignore from sqlalchemy import delete from sqlalchemy.exc import SQLAlchemyError +from sqlalchemy.orm import sessionmaker from extensions.ext_database import db from models import ( @@ -31,7 +32,8 @@ from models import ( ) from models.tools import WorkflowToolProvider from models.web import PinnedConversation, SavedMessage -from models.workflow import ConversationVariable, Workflow, WorkflowAppLog, WorkflowNodeExecutionModel, WorkflowRun +from models.workflow import ConversationVariable, Workflow, WorkflowAppLog, WorkflowRun +from repositories.factory import DifyAPIRepositoryFactory @shared_task(queue="app_deletion", bind=True, max_retries=3) @@ -201,18 +203,18 @@ def _delete_app_workflow_runs(tenant_id: str, app_id: str): def _delete_app_workflow_node_executions(tenant_id: str, app_id: str): - def del_workflow_node_execution(workflow_node_execution_id: str): - db.session.query(WorkflowNodeExecutionModel).filter( - WorkflowNodeExecutionModel.id == workflow_node_execution_id - ).delete(synchronize_session=False) - - _delete_records( - """select id from workflow_node_executions where tenant_id=:tenant_id and app_id=:app_id limit 1000""", - {"tenant_id": tenant_id, "app_id": app_id}, - del_workflow_node_execution, - "workflow node execution", + """Delete all workflow node executions for an app using the service repository.""" + session_maker = sessionmaker(bind=db.engine, expire_on_commit=False) + node_execution_repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository(session_maker) + + deleted_count = node_execution_repo.delete_executions_by_app( + 
tenant_id=tenant_id, + app_id=app_id, + batch_size=1000, ) + logging.info(f"Deleted {deleted_count} workflow node executions for app {app_id}") + def _delete_app_workflow_app_logs(tenant_id: str, app_id: str): def del_workflow_app_log(workflow_app_log_id: str): diff --git a/api/tests/unit_tests/core/repositories/test_factory.py b/api/tests/unit_tests/core/repositories/test_factory.py index 1d52c5daf8..fce4a6fb6b 100644 --- a/api/tests/unit_tests/core/repositories/test_factory.py +++ b/api/tests/unit_tests/core/repositories/test_factory.py @@ -12,7 +12,7 @@ from pytest_mock import MockerFixture from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker -from core.repositories.factory import RepositoryFactory, RepositoryImportError +from core.repositories.factory import DifyCoreRepositoryFactory, RepositoryImportError from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository from models import Account, EndUser @@ -27,25 +27,25 @@ class TestRepositoryFactory: """Test successful class import.""" # Test importing a real class class_path = "unittest.mock.MagicMock" - result = RepositoryFactory._import_class(class_path) + result = DifyCoreRepositoryFactory._import_class(class_path) assert result is MagicMock def test_import_class_invalid_path(self): """Test import with invalid module path.""" with pytest.raises(RepositoryImportError) as exc_info: - RepositoryFactory._import_class("invalid.module.path") + DifyCoreRepositoryFactory._import_class("invalid.module.path") assert "Cannot import repository class" in str(exc_info.value) def test_import_class_invalid_class_name(self): """Test import with invalid class name.""" with pytest.raises(RepositoryImportError) as exc_info: - RepositoryFactory._import_class("unittest.mock.NonExistentClass") + DifyCoreRepositoryFactory._import_class("unittest.mock.NonExistentClass") 
assert "Cannot import repository class" in str(exc_info.value) def test_import_class_malformed_path(self): """Test import with malformed path (no dots).""" with pytest.raises(RepositoryImportError) as exc_info: - RepositoryFactory._import_class("invalidpath") + DifyCoreRepositoryFactory._import_class("invalidpath") assert "Cannot import repository class" in str(exc_info.value) def test_validate_repository_interface_success(self): @@ -68,7 +68,7 @@ class TestRepositoryFactory: pass # Should not raise an exception - RepositoryFactory._validate_repository_interface(MockRepository, MockInterface) + DifyCoreRepositoryFactory._validate_repository_interface(MockRepository, MockInterface) def test_validate_repository_interface_missing_methods(self): """Test interface validation with missing methods.""" @@ -89,7 +89,7 @@ class TestRepositoryFactory: pass with pytest.raises(RepositoryImportError) as exc_info: - RepositoryFactory._validate_repository_interface(IncompleteRepository, MockInterface) + DifyCoreRepositoryFactory._validate_repository_interface(IncompleteRepository, MockInterface) assert "does not implement required methods" in str(exc_info.value) assert "get_by_id" in str(exc_info.value) @@ -101,7 +101,7 @@ class TestRepositoryFactory: pass # Should not raise an exception - RepositoryFactory._validate_constructor_signature( + DifyCoreRepositoryFactory._validate_constructor_signature( MockRepository, ["session_factory", "user", "app_id", "triggered_from"] ) @@ -114,7 +114,7 @@ class TestRepositoryFactory: pass with pytest.raises(RepositoryImportError) as exc_info: - RepositoryFactory._validate_constructor_signature( + DifyCoreRepositoryFactory._validate_constructor_signature( IncompleteRepository, ["session_factory", "user", "app_id", "triggered_from"] ) assert "does not accept required parameters" in str(exc_info.value) @@ -131,7 +131,7 @@ class TestRepositoryFactory: pass with pytest.raises(RepositoryImportError) as exc_info: - 
RepositoryFactory._validate_constructor_signature(MockRepository, ["session_factory"]) + DifyCoreRepositoryFactory._validate_constructor_signature(MockRepository, ["session_factory"]) assert "Failed to validate constructor signature" in str(exc_info.value) @patch("core.repositories.factory.dify_config") @@ -153,11 +153,11 @@ class TestRepositoryFactory: # Mock the validation methods with ( - patch.object(RepositoryFactory, "_import_class", return_value=mock_repository_class), - patch.object(RepositoryFactory, "_validate_repository_interface"), - patch.object(RepositoryFactory, "_validate_constructor_signature"), + patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class), + patch.object(DifyCoreRepositoryFactory, "_validate_repository_interface"), + patch.object(DifyCoreRepositoryFactory, "_validate_constructor_signature"), ): - result = RepositoryFactory.create_workflow_execution_repository( + result = DifyCoreRepositoryFactory.create_workflow_execution_repository( session_factory=mock_session_factory, user=mock_user, app_id=app_id, @@ -183,7 +183,7 @@ class TestRepositoryFactory: mock_user = MagicMock(spec=Account) with pytest.raises(RepositoryImportError) as exc_info: - RepositoryFactory.create_workflow_execution_repository( + DifyCoreRepositoryFactory.create_workflow_execution_repository( session_factory=mock_session_factory, user=mock_user, app_id="test-app-id", @@ -203,15 +203,15 @@ class TestRepositoryFactory: # Mock import to succeed but validation to fail mock_repository_class = MagicMock() with ( - patch.object(RepositoryFactory, "_import_class", return_value=mock_repository_class), + patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class), patch.object( - RepositoryFactory, + DifyCoreRepositoryFactory, "_validate_repository_interface", side_effect=RepositoryImportError("Interface validation failed"), ), ): with pytest.raises(RepositoryImportError) as exc_info: - 
RepositoryFactory.create_workflow_execution_repository( + DifyCoreRepositoryFactory.create_workflow_execution_repository( session_factory=mock_session_factory, user=mock_user, app_id="test-app-id", @@ -231,12 +231,12 @@ class TestRepositoryFactory: # Mock import and validation to succeed but instantiation to fail mock_repository_class = MagicMock(side_effect=Exception("Instantiation failed")) with ( - patch.object(RepositoryFactory, "_import_class", return_value=mock_repository_class), - patch.object(RepositoryFactory, "_validate_repository_interface"), - patch.object(RepositoryFactory, "_validate_constructor_signature"), + patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class), + patch.object(DifyCoreRepositoryFactory, "_validate_repository_interface"), + patch.object(DifyCoreRepositoryFactory, "_validate_constructor_signature"), ): with pytest.raises(RepositoryImportError) as exc_info: - RepositoryFactory.create_workflow_execution_repository( + DifyCoreRepositoryFactory.create_workflow_execution_repository( session_factory=mock_session_factory, user=mock_user, app_id="test-app-id", @@ -263,11 +263,11 @@ class TestRepositoryFactory: # Mock the validation methods with ( - patch.object(RepositoryFactory, "_import_class", return_value=mock_repository_class), - patch.object(RepositoryFactory, "_validate_repository_interface"), - patch.object(RepositoryFactory, "_validate_constructor_signature"), + patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class), + patch.object(DifyCoreRepositoryFactory, "_validate_repository_interface"), + patch.object(DifyCoreRepositoryFactory, "_validate_constructor_signature"), ): - result = RepositoryFactory.create_workflow_node_execution_repository( + result = DifyCoreRepositoryFactory.create_workflow_node_execution_repository( session_factory=mock_session_factory, user=mock_user, app_id=app_id, @@ -293,7 +293,7 @@ class TestRepositoryFactory: mock_user = 
MagicMock(spec=EndUser) with pytest.raises(RepositoryImportError) as exc_info: - RepositoryFactory.create_workflow_node_execution_repository( + DifyCoreRepositoryFactory.create_workflow_node_execution_repository( session_factory=mock_session_factory, user=mock_user, app_id="test-app-id", @@ -325,11 +325,11 @@ class TestRepositoryFactory: # Mock the validation methods with ( - patch.object(RepositoryFactory, "_import_class", return_value=mock_repository_class), - patch.object(RepositoryFactory, "_validate_repository_interface"), - patch.object(RepositoryFactory, "_validate_constructor_signature"), + patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class), + patch.object(DifyCoreRepositoryFactory, "_validate_repository_interface"), + patch.object(DifyCoreRepositoryFactory, "_validate_constructor_signature"), ): - result = RepositoryFactory.create_workflow_execution_repository( + result = DifyCoreRepositoryFactory.create_workflow_execution_repository( session_factory=mock_engine, # Using Engine instead of sessionmaker user=mock_user, app_id="test-app-id", @@ -357,15 +357,15 @@ class TestRepositoryFactory: # Mock import to succeed but validation to fail mock_repository_class = MagicMock() with ( - patch.object(RepositoryFactory, "_import_class", return_value=mock_repository_class), + patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class), patch.object( - RepositoryFactory, + DifyCoreRepositoryFactory, "_validate_repository_interface", side_effect=RepositoryImportError("Interface validation failed"), ), ): with pytest.raises(RepositoryImportError) as exc_info: - RepositoryFactory.create_workflow_node_execution_repository( + DifyCoreRepositoryFactory.create_workflow_node_execution_repository( session_factory=mock_session_factory, user=mock_user, app_id="test-app-id", @@ -385,12 +385,12 @@ class TestRepositoryFactory: # Mock import and validation to succeed but instantiation to fail 
mock_repository_class = MagicMock(side_effect=Exception("Instantiation failed")) with ( - patch.object(RepositoryFactory, "_import_class", return_value=mock_repository_class), - patch.object(RepositoryFactory, "_validate_repository_interface"), - patch.object(RepositoryFactory, "_validate_constructor_signature"), + patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class), + patch.object(DifyCoreRepositoryFactory, "_validate_repository_interface"), + patch.object(DifyCoreRepositoryFactory, "_validate_constructor_signature"), ): with pytest.raises(RepositoryImportError) as exc_info: - RepositoryFactory.create_workflow_node_execution_repository( + DifyCoreRepositoryFactory.create_workflow_node_execution_repository( session_factory=mock_session_factory, user=mock_user, app_id="test-app-id", @@ -424,7 +424,7 @@ class TestRepositoryFactory: pass # Should not raise an exception (private methods are ignored) - RepositoryFactory._validate_repository_interface(MockRepository, MockInterface) + DifyCoreRepositoryFactory._validate_repository_interface(MockRepository, MockInterface) def test_validate_constructor_signature_with_extra_params(self): """Test constructor validation with extra parameters (should pass).""" @@ -434,7 +434,7 @@ class TestRepositoryFactory: pass # Should not raise an exception (extra parameters are allowed) - RepositoryFactory._validate_constructor_signature( + DifyCoreRepositoryFactory._validate_constructor_signature( MockRepository, ["session_factory", "user", "app_id", "triggered_from"] ) @@ -447,7 +447,7 @@ class TestRepositoryFactory: # Current implementation doesn't handle **kwargs, so this should raise an exception with pytest.raises(RepositoryImportError) as exc_info: - RepositoryFactory._validate_constructor_signature( + DifyCoreRepositoryFactory._validate_constructor_signature( MockRepository, ["session_factory", "user", "app_id", "triggered_from"] ) assert "does not accept required parameters" in 
str(exc_info.value) diff --git a/api/tests/unit_tests/services/workflow/test_workflow_node_execution_service_repository.py b/api/tests/unit_tests/services/workflow/test_workflow_node_execution_service_repository.py new file mode 100644 index 0000000000..96f9139804 --- /dev/null +++ b/api/tests/unit_tests/services/workflow/test_workflow_node_execution_service_repository.py @@ -0,0 +1,278 @@ +from datetime import datetime +from unittest.mock import MagicMock +from uuid import uuid4 + +import pytest +from sqlalchemy.orm import Session + +from models.workflow import WorkflowNodeExecutionModel +from repositories.sqlalchemy_api_workflow_node_execution_repository import ( + DifyAPISQLAlchemyWorkflowNodeExecutionRepository, +) + + +class TestSQLAlchemyWorkflowNodeExecutionServiceRepository: + @pytest.fixture + def repository(self): + mock_session_maker = MagicMock() + return DifyAPISQLAlchemyWorkflowNodeExecutionRepository(session_maker=mock_session_maker) + + @pytest.fixture + def mock_execution(self): + execution = MagicMock(spec=WorkflowNodeExecutionModel) + execution.id = str(uuid4()) + execution.tenant_id = "tenant-123" + execution.app_id = "app-456" + execution.workflow_id = "workflow-789" + execution.workflow_run_id = "run-101" + execution.node_id = "node-202" + execution.index = 1 + execution.created_at = "2023-01-01T00:00:00Z" + return execution + + def test_get_node_last_execution_found(self, repository, mock_execution): + """Test getting the last execution for a node when it exists.""" + # Arrange + mock_session = MagicMock(spec=Session) + repository._session_maker.return_value.__enter__.return_value = mock_session + mock_session.scalar.return_value = mock_execution + + # Act + result = repository.get_node_last_execution( + tenant_id="tenant-123", + app_id="app-456", + workflow_id="workflow-789", + node_id="node-202", + ) + + # Assert + assert result == mock_execution + mock_session.scalar.assert_called_once() + # Verify the query was constructed correctly + 
call_args = mock_session.scalar.call_args[0][0] + assert hasattr(call_args, "compile") # It's a SQLAlchemy statement + + def test_get_node_last_execution_not_found(self, repository): + """Test getting the last execution for a node when it doesn't exist.""" + # Arrange + mock_session = MagicMock(spec=Session) + repository._session_maker.return_value.__enter__.return_value = mock_session + mock_session.scalar.return_value = None + + # Act + result = repository.get_node_last_execution( + tenant_id="tenant-123", + app_id="app-456", + workflow_id="workflow-789", + node_id="node-202", + ) + + # Assert + assert result is None + mock_session.scalar.assert_called_once() + + def test_get_executions_by_workflow_run(self, repository, mock_execution): + """Test getting all executions for a workflow run.""" + # Arrange + mock_session = MagicMock(spec=Session) + repository._session_maker.return_value.__enter__.return_value = mock_session + executions = [mock_execution] + mock_session.execute.return_value.scalars.return_value.all.return_value = executions + + # Act + result = repository.get_executions_by_workflow_run( + tenant_id="tenant-123", + app_id="app-456", + workflow_run_id="run-101", + ) + + # Assert + assert result == executions + mock_session.execute.assert_called_once() + # Verify the query was constructed correctly + call_args = mock_session.execute.call_args[0][0] + assert hasattr(call_args, "compile") # It's a SQLAlchemy statement + + def test_get_executions_by_workflow_run_empty(self, repository): + """Test getting executions for a workflow run when none exist.""" + # Arrange + mock_session = MagicMock(spec=Session) + repository._session_maker.return_value.__enter__.return_value = mock_session + mock_session.execute.return_value.scalars.return_value.all.return_value = [] + + # Act + result = repository.get_executions_by_workflow_run( + tenant_id="tenant-123", + app_id="app-456", + workflow_run_id="run-101", + ) + + # Assert + assert result == [] + 
mock_session.execute.assert_called_once() + + def test_get_execution_by_id_found(self, repository, mock_execution): + """Test getting execution by ID when it exists.""" + # Arrange + mock_session = MagicMock(spec=Session) + repository._session_maker.return_value.__enter__.return_value = mock_session + mock_session.scalar.return_value = mock_execution + + # Act + result = repository.get_execution_by_id(mock_execution.id) + + # Assert + assert result == mock_execution + mock_session.scalar.assert_called_once() + + def test_get_execution_by_id_not_found(self, repository): + """Test getting execution by ID when it doesn't exist.""" + # Arrange + mock_session = MagicMock(spec=Session) + repository._session_maker.return_value.__enter__.return_value = mock_session + mock_session.scalar.return_value = None + + # Act + result = repository.get_execution_by_id("non-existent-id") + + # Assert + assert result is None + mock_session.scalar.assert_called_once() + + def test_repository_implements_protocol(self, repository): + """Test that the repository implements the required protocol methods.""" + # Verify all protocol methods are implemented + assert hasattr(repository, "get_node_last_execution") + assert hasattr(repository, "get_executions_by_workflow_run") + assert hasattr(repository, "get_execution_by_id") + + # Verify methods are callable + assert callable(repository.get_node_last_execution) + assert callable(repository.get_executions_by_workflow_run) + assert callable(repository.get_execution_by_id) + assert callable(repository.delete_expired_executions) + assert callable(repository.delete_executions_by_app) + assert callable(repository.get_expired_executions_batch) + assert callable(repository.delete_executions_by_ids) + + def test_delete_expired_executions(self, repository): + """Test deleting expired executions.""" + # Arrange + mock_session = MagicMock(spec=Session) + repository._session_maker.return_value.__enter__.return_value = mock_session + + # Mock the select 
query to return some IDs first time, then empty to stop loop + execution_ids = ["id1", "id2"] # Less than batch_size to trigger break + mock_session.execute.return_value.scalars.return_value.all.return_value = execution_ids + + # Mock the delete query + mock_query = MagicMock() + mock_session.query.return_value = mock_query + mock_query.filter.return_value.delete.return_value = 2 + + before_date = datetime(2023, 1, 1) + + # Act + result = repository.delete_expired_executions( + tenant_id="tenant-123", + before_date=before_date, + batch_size=1000, + ) + + # Assert + assert result == 2 + mock_session.execute.assert_called_once() # One select call + mock_session.query.assert_called_once() + mock_session.commit.assert_called_once() + + def test_delete_executions_by_app(self, repository): + """Test deleting executions by app.""" + # Arrange + mock_session = MagicMock(spec=Session) + repository._session_maker.return_value.__enter__.return_value = mock_session + + # Mock the select query to return some IDs first time, then empty to stop loop + execution_ids = ["id1", "id2"] + mock_session.execute.return_value.scalars.return_value.all.return_value = execution_ids + + # Mock the delete query + mock_query = MagicMock() + mock_session.query.return_value = mock_query + mock_query.filter.return_value.delete.return_value = 2 + + # Act + result = repository.delete_executions_by_app( + tenant_id="tenant-123", + app_id="app-456", + batch_size=1000, + ) + + # Assert + assert result == 2 + mock_session.execute.assert_called_once() # One select call + mock_session.query.assert_called_once() + mock_session.commit.assert_called_once() + + def test_get_expired_executions_batch(self, repository): + """Test getting expired executions batch for backup.""" + # Arrange + mock_session = MagicMock(spec=Session) + repository._session_maker.return_value.__enter__.return_value = mock_session + + # Create mock execution objects + mock_execution1 = MagicMock() + mock_execution1.id = "exec-1" + 
mock_execution2 = MagicMock() + mock_execution2.id = "exec-2" + + mock_session.execute.return_value.scalars.return_value.all.return_value = [mock_execution1, mock_execution2] + + before_date = datetime(2023, 1, 1) + + # Act + result = repository.get_expired_executions_batch( + tenant_id="tenant-123", + before_date=before_date, + batch_size=1000, + ) + + # Assert + assert len(result) == 2 + assert result[0].id == "exec-1" + assert result[1].id == "exec-2" + mock_session.execute.assert_called_once() + + def test_delete_executions_by_ids(self, repository): + """Test deleting executions by IDs.""" + # Arrange + mock_session = MagicMock(spec=Session) + repository._session_maker.return_value.__enter__.return_value = mock_session + + # Mock the delete query + mock_query = MagicMock() + mock_session.query.return_value = mock_query + mock_query.filter.return_value.delete.return_value = 3 + + execution_ids = ["id1", "id2", "id3"] + + # Act + result = repository.delete_executions_by_ids(execution_ids) + + # Assert + assert result == 3 + mock_session.query.assert_called_once() + mock_session.commit.assert_called_once() + + def test_delete_executions_by_ids_empty_list(self, repository): + """Test deleting executions with empty ID list.""" + # Arrange + mock_session = MagicMock(spec=Session) + repository._session_maker.return_value.__enter__.return_value = mock_session + + # Act + result = repository.delete_executions_by_ids([]) + + # Assert + assert result == 0 + mock_session.query.assert_not_called() + mock_session.commit.assert_not_called()