diff --git a/api/configs/feature/__init__.py b/api/configs/feature/__init__.py index 86d1879462..38ba3d8994 100644 --- a/api/configs/feature/__init__.py +++ b/api/configs/feature/__init__.py @@ -551,16 +551,22 @@ class RepositoryConfig(BaseSettings): Configuration for repository implementations """ - WORKFLOW_EXECUTION_REPOSITORY: str = Field( + CORE_WORKFLOW_EXECUTION_REPOSITORY: str = Field( description="Repository implementation for WorkflowExecution. Specify as a module path", default="core.repositories.sqlalchemy_workflow_execution_repository.SQLAlchemyWorkflowExecutionRepository", ) - WORKFLOW_NODE_EXECUTION_REPOSITORY: str = Field( + CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY: str = Field( description="Repository implementation for WorkflowNodeExecution. Specify as a module path", default="core.repositories.sqlalchemy_workflow_node_execution_repository.SQLAlchemyWorkflowNodeExecutionRepository", ) + API_WORKFLOW_NODE_EXECUTION_REPOSITORY: str = Field( + description="Service-layer repository implementation for WorkflowNodeExecutionModel operations. " + "Specify as a module path", + default="repositories.sqlalchemy_api_workflow_node_execution_repository.DifyAPISQLAlchemyWorkflowNodeExecutionRepository", + ) + class AuthConfig(BaseSettings): """ diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py index 649a7172d4..4b8f5ebe27 100644 --- a/api/core/app/apps/advanced_chat/app_generator.py +++ b/api/core/app/apps/advanced_chat/app_generator.py @@ -25,7 +25,7 @@ from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotA from core.model_runtime.errors.invoke import InvokeAuthorizationError from core.ops.ops_trace_manager import TraceQueueManager from core.prompt.utils.get_thread_messages_length import get_thread_messages_length -from core.repositories import RepositoryFactory +from core.repositories import DifyCoreRepositoryFactory from core.workflow.repositories.draft_variable_repository import ( DraftVariableSaverFactory, ) @@ -182,14 +182,14 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): workflow_triggered_from = WorkflowRunTriggeredFrom.DEBUGGING else: workflow_triggered_from = WorkflowRunTriggeredFrom.APP_RUN - workflow_execution_repository = RepositoryFactory.create_workflow_execution_repository( + workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository( session_factory=session_factory, user=user, app_id=application_generate_entity.app_config.app_id, triggered_from=workflow_triggered_from, ) # Create workflow node execution repository - workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository( + workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository( session_factory=session_factory, user=user, app_id=application_generate_entity.app_config.app_id, @@ -259,14 +259,14 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): # Create session factory session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) # Create workflow execution(aka workflow run) repository - workflow_execution_repository = RepositoryFactory.create_workflow_execution_repository( + workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository( session_factory=session_factory, user=user, app_id=application_generate_entity.app_config.app_id, triggered_from=WorkflowRunTriggeredFrom.DEBUGGING, ) # Create workflow node execution repository - workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository( + workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository( session_factory=session_factory, user=user, app_id=application_generate_entity.app_config.app_id, @@ -342,14 +342,14 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): # Create session factory session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) # Create workflow execution(aka workflow run) repository - workflow_execution_repository = RepositoryFactory.create_workflow_execution_repository( + workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository( session_factory=session_factory, user=user, app_id=application_generate_entity.app_config.app_id, triggered_from=WorkflowRunTriggeredFrom.DEBUGGING, ) # Create workflow node execution repository - workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository( + workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository( session_factory=session_factory, user=user, app_id=application_generate_entity.app_config.app_id, diff --git a/api/core/app/apps/workflow/app_generator.py b/api/core/app/apps/workflow/app_generator.py index f1203dfa4a..2f9632e97d 100644 --- a/api/core/app/apps/workflow/app_generator.py +++ b/api/core/app/apps/workflow/app_generator.py @@ -23,7 +23,7 @@ from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerat from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse from core.model_runtime.errors.invoke import InvokeAuthorizationError from core.ops.ops_trace_manager import TraceQueueManager -from core.repositories import RepositoryFactory +from core.repositories import DifyCoreRepositoryFactory from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository @@ -155,14 +155,14 @@ class WorkflowAppGenerator(BaseAppGenerator): workflow_triggered_from = WorkflowRunTriggeredFrom.DEBUGGING else: workflow_triggered_from = WorkflowRunTriggeredFrom.APP_RUN - workflow_execution_repository = RepositoryFactory.create_workflow_execution_repository( + workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository( session_factory=session_factory, user=user, app_id=application_generate_entity.app_config.app_id, triggered_from=workflow_triggered_from, ) # Create workflow node execution repository - workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository( + workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository( session_factory=session_factory, user=user, app_id=application_generate_entity.app_config.app_id, @@ -305,14 +305,14 @@ class WorkflowAppGenerator(BaseAppGenerator): # Create session factory session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) # Create workflow execution(aka workflow run) repository - workflow_execution_repository = RepositoryFactory.create_workflow_execution_repository( + workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository( session_factory=session_factory, user=user, app_id=application_generate_entity.app_config.app_id, triggered_from=WorkflowRunTriggeredFrom.DEBUGGING, ) # Create workflow node execution repository - workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository( + workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository( session_factory=session_factory, user=user, app_id=application_generate_entity.app_config.app_id, @@ -387,14 +387,14 @@ class WorkflowAppGenerator(BaseAppGenerator): # Create session factory session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) # Create workflow execution(aka workflow run) repository - workflow_execution_repository = RepositoryFactory.create_workflow_execution_repository( + workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository( session_factory=session_factory, user=user, app_id=application_generate_entity.app_config.app_id, triggered_from=WorkflowRunTriggeredFrom.DEBUGGING, ) # Create workflow node execution repository - workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository( + workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository( session_factory=session_factory, user=user, app_id=application_generate_entity.app_config.app_id, diff --git a/api/core/ops/langfuse_trace/langfuse_trace.py b/api/core/ops/langfuse_trace/langfuse_trace.py index d0b228f4ba..4a7e66d27c 100644 --- a/api/core/ops/langfuse_trace/langfuse_trace.py +++ b/api/core/ops/langfuse_trace/langfuse_trace.py @@ -28,7 +28,7 @@ from core.ops.langfuse_trace.entities.langfuse_trace_entity import ( UnitEnum, ) from core.ops.utils import filter_none_values -from core.repositories import RepositoryFactory +from core.repositories import DifyCoreRepositoryFactory from core.workflow.nodes.enums import NodeType from extensions.ext_database import db from models import EndUser, WorkflowNodeExecutionTriggeredFrom @@ -123,7 +123,7 @@ class LangFuseDataTrace(BaseTraceInstance): service_account = self.get_service_account_with_tenant(app_id) - workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository( + workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository( session_factory=session_factory, user=service_account, app_id=app_id, diff --git a/api/core/ops/langsmith_trace/langsmith_trace.py b/api/core/ops/langsmith_trace/langsmith_trace.py index f3f08d74b8..8a559c4929 100644 --- a/api/core/ops/langsmith_trace/langsmith_trace.py +++ b/api/core/ops/langsmith_trace/langsmith_trace.py @@ -27,7 +27,7 @@ from core.ops.langsmith_trace.entities.langsmith_trace_entity import ( LangSmithRunUpdateModel, ) from core.ops.utils import filter_none_values, generate_dotted_order -from core.repositories import RepositoryFactory +from core.repositories import DifyCoreRepositoryFactory from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey from core.workflow.nodes.enums import NodeType from extensions.ext_database import db @@ -145,7 +145,7 @@ class LangSmithDataTrace(BaseTraceInstance): service_account = self.get_service_account_with_tenant(app_id) - workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository( + workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository( session_factory=session_factory, user=service_account, app_id=app_id, diff --git a/api/core/ops/opik_trace/opik_trace.py b/api/core/ops/opik_trace/opik_trace.py index c0c6764b9a..be4997a5bf 100644 --- a/api/core/ops/opik_trace/opik_trace.py +++ b/api/core/ops/opik_trace/opik_trace.py @@ -21,7 +21,7 @@ from core.ops.entities.trace_entity import ( TraceTaskName, WorkflowTraceInfo, ) -from core.repositories import RepositoryFactory +from core.repositories import DifyCoreRepositoryFactory from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey from core.workflow.nodes.enums import NodeType from extensions.ext_database import db @@ -160,7 +160,7 @@ class OpikDataTrace(BaseTraceInstance): service_account = self.get_service_account_with_tenant(app_id) - workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository( + workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository( session_factory=session_factory, user=service_account, app_id=app_id, diff --git a/api/core/ops/weave_trace/weave_trace.py b/api/core/ops/weave_trace/weave_trace.py index 95cb0dd621..445c6a8741 100644 --- a/api/core/ops/weave_trace/weave_trace.py +++ b/api/core/ops/weave_trace/weave_trace.py @@ -22,7 +22,7 @@ from core.ops.entities.trace_entity import ( WorkflowTraceInfo, ) from core.ops.weave_trace.entities.weave_trace_entity import WeaveTraceModel -from core.repositories import RepositoryFactory +from core.repositories import DifyCoreRepositoryFactory from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey from core.workflow.nodes.enums import NodeType from extensions.ext_database import db @@ -144,7 +144,7 @@ class WeaveDataTrace(BaseTraceInstance): service_account = self.get_service_account_with_tenant(app_id) - workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository( + workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository( session_factory=session_factory, user=service_account, app_id=app_id, diff --git a/api/core/repositories/__init__.py b/api/core/repositories/__init__.py index bb5b3224ff..052ba1c2cb 100644 --- a/api/core/repositories/__init__.py +++ b/api/core/repositories/__init__.py @@ -5,11 +5,11 @@ This package contains concrete implementations of the repository interfaces defined in the core.workflow.repository package. """ -from core.repositories.factory import RepositoryFactory, RepositoryImportError +from core.repositories.factory import DifyCoreRepositoryFactory, RepositoryImportError from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository __all__ = [ - "RepositoryFactory", + "DifyCoreRepositoryFactory", "RepositoryImportError", "SQLAlchemyWorkflowNodeExecutionRepository", ] diff --git a/api/core/repositories/factory.py b/api/core/repositories/factory.py index 646b587244..a3eaa0ee5c 100644 --- a/api/core/repositories/factory.py +++ b/api/core/repositories/factory.py @@ -28,7 +28,7 @@ class RepositoryImportError(Exception): pass -class RepositoryFactory: +class DifyCoreRepositoryFactory: """ Factory for creating repository instances based on configuration. @@ -143,7 +143,7 @@ class RepositoryFactory: Raises: RepositoryImportError: If the configured repository cannot be created """ - class_path = dify_config.WORKFLOW_EXECUTION_REPOSITORY + class_path = dify_config.CORE_WORKFLOW_EXECUTION_REPOSITORY logger.debug(f"Creating WorkflowExecutionRepository from: {class_path}") try: @@ -189,7 +189,7 @@ class RepositoryFactory: Raises: RepositoryImportError: If the configured repository cannot be created """ - class_path = dify_config.WORKFLOW_NODE_EXECUTION_REPOSITORY + class_path = dify_config.CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY logger.debug(f"Creating WorkflowNodeExecutionRepository from: {class_path}") try: diff --git a/api/repositories/api_workflow_node_execution_repository.py b/api/repositories/api_workflow_node_execution_repository.py index 00a2d1f87d..9e36e0ef7e 100644 --- a/api/repositories/api_workflow_node_execution_repository.py +++ b/api/repositories/api_workflow_node_execution_repository.py @@ -9,15 +9,15 @@ The service repository handles operations that require access to database-specif tenant_id, app_id, triggered_from, etc., which are not part of the core domain model. """ +from abc import abstractmethod from collections.abc import Sequence from datetime import datetime from typing import Optional, Protocol -from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository from models.workflow import WorkflowNodeExecutionModel -class DifyAPIWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository, Protocol): +class DifyAPIWorkflowNodeExecutionRepository(Protocol): """ Protocol for service-layer operations on WorkflowNodeExecutionModel. @@ -38,6 +38,7 @@ class DifyAPIWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository, Pr - Supports cleanup and maintenance operations """ + @abstractmethod def get_node_last_execution( self, tenant_id: str, @@ -62,6 +63,7 @@ class DifyAPIWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository, Pr """ ... + @abstractmethod def get_executions_by_workflow_run( self, tenant_id: str, @@ -84,6 +86,7 @@ class DifyAPIWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository, Pr """ ... + @abstractmethod def get_execution_by_id( self, execution_id: str, @@ -95,10 +98,6 @@ class DifyAPIWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository, Pr This method retrieves a specific execution by its unique identifier. Tenant filtering is optional for cases where the execution ID is globally unique. - When `tenant_id` is None, it's the caller's responsibility to ensure proper data isolation between tenants. - If the `execution_id` comes from untrusted sources (e.g., retrieved from an API request), the caller should - set `tenant_id` to prevent horizontal privilege escalation. - Args: execution_id: The execution identifier tenant_id: Optional tenant identifier for additional filtering @@ -108,6 +107,7 @@ class DifyAPIWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository, Pr """ ... + @abstractmethod def delete_expired_executions( self, tenant_id: str, @@ -130,6 +130,7 @@ class DifyAPIWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository, Pr """ ... + @abstractmethod def delete_executions_by_app( self, tenant_id: str, @@ -152,6 +153,7 @@ class DifyAPIWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository, Pr """ ... + @abstractmethod def get_expired_executions_batch( self, tenant_id: str, @@ -174,6 +176,7 @@ class DifyAPIWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository, Pr """ ... + @abstractmethod def delete_executions_by_ids( self, execution_ids: Sequence[str], @@ -184,10 +187,6 @@ class DifyAPIWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository, Pr This method deletes specific executions by their IDs, typically used after backing up the data. - This method does not perform tenant isolation checks. The caller is responsible for ensuring proper - data isolation between tenants. When execution IDs come from untrusted sources (e.g., API requests), - additional tenant validation should be implemented to prevent unauthorized access. - Args: execution_ids: List of execution IDs to delete diff --git a/api/repositories/factory.py b/api/repositories/factory.py index 0a0adbf2c2..fe0bf17441 100644 --- a/api/repositories/factory.py +++ b/api/repositories/factory.py @@ -12,7 +12,6 @@ from sqlalchemy.orm import sessionmaker from configs import dify_config from core.repositories import DifyCoreRepositoryFactory, RepositoryImportError from repositories.api_workflow_node_execution_repository import DifyAPIWorkflowNodeExecutionRepository -from repositories.api_workflow_run_repository import APIWorkflowRunRepository logger = logging.getLogger(__name__) @@ -65,39 +64,3 @@ class DifyAPIRepositoryFactory(DifyCoreRepositoryFactory): raise RepositoryImportError( f"Failed to create DifyAPIWorkflowNodeExecutionRepository from '{class_path}': {e}" ) from e - - @classmethod - def create_api_workflow_run_repository(cls, session_maker: sessionmaker) -> APIWorkflowRunRepository: - """ - Create an APIWorkflowRunRepository instance based on configuration. - - This repository is designed for service-layer WorkflowRun operations and uses dependency - injection with a sessionmaker for better testability and separation of concerns. It provides - database access patterns specifically needed by service classes for workflow run management, - including pagination, filtering, and bulk operations. - - Args: - session_maker: SQLAlchemy sessionmaker to inject for database session management. - - Returns: - Configured APIWorkflowRunRepository instance - - Raises: - RepositoryImportError: If the configured repository cannot be imported or instantiated - """ - class_path = dify_config.API_WORKFLOW_RUN_REPOSITORY - logger.debug(f"Creating APIWorkflowRunRepository from: {class_path}") - - try: - repository_class = cls._import_class(class_path) - cls._validate_repository_interface(repository_class, APIWorkflowRunRepository) - # Service repository requires session_maker parameter - cls._validate_constructor_signature(repository_class, ["session_maker"]) - - return repository_class(session_maker=session_maker) # type: ignore[no-any-return] - except RepositoryImportError: - # Re-raise our custom errors as-is - raise - except Exception as e: - logger.exception("Failed to create APIWorkflowRunRepository") - raise RepositoryImportError(f"Failed to create APIWorkflowRunRepository from '{class_path}': {e}") from e diff --git a/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py b/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py index e6a23ddf9f..ccde8b8076 100644 --- a/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py +++ b/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py @@ -124,10 +124,6 @@ class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecut This method replicates the query pattern from WorkflowDraftVariableService and WorkflowService.single_step_run_workflow_node() using SQLAlchemy 2.0 style syntax. - When `tenant_id` is None, it's the caller's responsibility to ensure proper data isolation between tenants. - If the `execution_id` comes from untrusted sources (e.g., retrieved from an API request), the caller should - set `tenant_id` to prevent horizontal privilege escalation. - Args: execution_id: The execution identifier tenant_id: Optional tenant identifier for additional filtering diff --git a/api/services/clear_free_plan_tenant_expired_logs.py b/api/services/clear_free_plan_tenant_expired_logs.py index ddd16b2e0c..cb088304c6 100644 --- a/api/services/clear_free_plan_tenant_expired_logs.py +++ b/api/services/clear_free_plan_tenant_expired_logs.py @@ -14,6 +14,7 @@ from extensions.ext_database import db from extensions.ext_storage import storage from models.account import Tenant from models.model import App, Conversation, Message +from models.workflow import WorkflowRun from repositories.factory import DifyAPIRepositoryFactory from services.billing_service import BillingService @@ -111,6 +112,47 @@ class ClearFreePlanTenantExpiredLogs: before_date = datetime.datetime.now() - datetime.timedelta(days=days) total_deleted = 0 + while True: + # Get a batch of expired executions for backup + workflow_node_executions = node_execution_repo.get_expired_executions_batch( + tenant_id=tenant_id, + before_date=before_date, + batch_size=batch, + ) + + if len(workflow_node_executions) == 0: + break + + # Save workflow node executions to storage + storage.save( + f"free_plan_tenant_expired_logs/" + f"{tenant_id}/workflow_node_executions/{datetime.datetime.now().strftime('%Y-%m-%d')}" + f"-{time.time()}.json", + json.dumps( + jsonable_encoder(workflow_node_executions), + ).encode("utf-8"), + ) + + # Extract IDs for deletion + workflow_node_execution_ids = [ + workflow_node_execution.id for workflow_node_execution in workflow_node_executions + ] + + # Delete the backed up executions + deleted_count = node_execution_repo.delete_executions_by_ids(workflow_node_execution_ids) + total_deleted += deleted_count + + click.echo( + click.style( + f"[{datetime.datetime.now()}] Processed {len(workflow_node_execution_ids)}" + f" workflow node executions for tenant {tenant_id}" + ) + ) + + # If we got fewer than the batch size, we're done + if len(workflow_node_executions) < batch: + break + while True: # Get a batch of expired executions for backup workflow_node_executions = node_execution_repo.get_expired_executions_batch( diff --git a/api/services/workflow_draft_variable_service.py b/api/services/workflow_draft_variable_service.py index f306e1f062..bec69f399d 100644 --- a/api/services/workflow_draft_variable_service.py +++ b/api/services/workflow_draft_variable_service.py @@ -21,6 +21,7 @@ from core.workflow.enums import SystemVariableKey from core.workflow.nodes import NodeType from core.workflow.nodes.variable_assigner.common.helpers import get_updated_variables from core.workflow.variable_loader import VariableLoader +from extensions.ext_database import db from factories.file_factory import StorageKeyLoader from factories.variable_factory import build_segment, segment_to_variable from models import App, Conversation @@ -129,11 +130,8 @@ class WorkflowDraftVariableService: AssertionError: If the provided session is not bound to an `Engine` object. """ self._session = session - engine = session.get_bind() - # Ensure the session is bound to a engine. - assert isinstance(engine, Engine) - session_maker = sessionmaker(bind=engine, expire_on_commit=False) - self._api_node_execution_repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository( + session_maker = sessionmaker(bind=db.engine, expire_on_commit=False) + self._node_execution_service_repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository( session_maker ) @@ -266,7 +264,7 @@ class WorkflowDraftVariableService: _logger.warning("draft variable has no node_execution_id, id=%s, name=%s", variable.id, variable.name) return None - node_exec = self._api_node_execution_repo.get_execution_by_id(variable.node_execution_id) + node_exec = self._node_execution_service_repo.get_execution_by_id(variable.node_execution_id) if node_exec is None: _logger.warning( "Node exectution not found for draft variable, id=%s, name=%s, node_execution_id=%s", diff --git a/api/services/workflow_run_service.py b/api/services/workflow_run_service.py index 428aef4007..3ec3dc193a 100644 --- a/api/services/workflow_run_service.py +++ b/api/services/workflow_run_service.py @@ -2,7 +2,7 @@ import threading from collections.abc import Sequence from typing import Optional -from sqlalchemy import desc, select +from sqlalchemy.orm import sessionmaker import contexts from extensions.ext_database import db @@ -15,6 +15,7 @@ from models import ( WorkflowRun, WorkflowRunTriggeredFrom, ) +from repositories.factory import DifyAPIRepositoryFactory class WorkflowRunService: @@ -24,7 +25,6 @@ class WorkflowRunService: self._node_execution_service_repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository( session_maker ) - self._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker) def get_paginate_advanced_chat_workflow_runs(self, app_model: App, args: dict) -> InfiniteScrollPagination: """ @@ -111,17 +111,11 @@ class WorkflowRunService: # Get tenant_id from user tenant_id = user.tenant_id if isinstance(user, EndUser) else user.current_tenant_id + if tenant_id is None: + raise ValueError("User tenant_id cannot be None") - # Use SQLAlchemy 2.0 style query directly - stmt = ( - select(WorkflowNodeExecutionModel) - .where( - WorkflowNodeExecutionModel.tenant_id == tenant_id, - WorkflowNodeExecutionModel.app_id == app_model.id, - WorkflowNodeExecutionModel.workflow_run_id == run_id, - ) - .order_by(desc(WorkflowNodeExecutionModel.index)) + return self._node_execution_service_repo.get_executions_by_workflow_run( + tenant_id=tenant_id, + app_id=app_model.id, + workflow_run_id=run_id, ) - - workflow_node_executions = db.session.execute(stmt).scalars().all() - return workflow_node_executions diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 393bf646e6..464e42ff98 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -13,7 +13,7 @@ from core.app.app_config.entities import VariableEntityType from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager from core.file import File -from core.repositories import RepositoryFactory +from core.repositories import DifyCoreRepositoryFactory from core.variables import Variable from core.variables.variables import VariableUnion from core.workflow.entities.node_entities import NodeRunResult @@ -59,10 +59,9 @@ class WorkflowService: Workflow Service """ - def __init__(self, session_maker: sessionmaker | None = None): + def __init__(self): """Initialize WorkflowService with repository dependencies.""" - if session_maker is None: - session_maker = sessionmaker(bind=db.engine, expire_on_commit=False) + session_maker = sessionmaker(bind=db.engine, expire_on_commit=False) self._node_execution_service_repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository( session_maker ) @@ -409,7 +408,7 @@ class WorkflowService: node_execution.workflow_id = draft_workflow.id # Create repository and save the node execution - repository = RepositoryFactory.create_workflow_node_execution_repository( + repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository( session_factory=db.engine, user=account, app_id=app_model.id, @@ -417,8 +416,9 @@ class WorkflowService: ) repository.save(node_execution) - stmt = select(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.id == node_execution.id) - workflow_node_execution = db.session.execute(stmt).scalar_one() + workflow_node_execution = self._node_execution_service_repo.get_execution_by_id(node_execution.id) + if workflow_node_execution is None: + raise ValueError(f"WorkflowNodeExecution with id {node_execution.id} not found after saving") with Session(bind=db.engine) as session, session.begin(): draft_var_saver = DraftVariableSaver( diff --git a/api/tasks/remove_app_and_related_data_task.py b/api/tasks/remove_app_and_related_data_task.py index 179adcbd6e..7f8c6c26a0 100644 --- a/api/tasks/remove_app_and_related_data_task.py +++ b/api/tasks/remove_app_and_related_data_task.py @@ -32,7 +32,7 @@ from models import ( ) from models.tools import WorkflowToolProvider from models.web import PinnedConversation, SavedMessage -from models.workflow import ConversationVariable, Workflow, WorkflowAppLog +from models.workflow import ConversationVariable, Workflow, WorkflowAppLog, WorkflowRun from repositories.factory import DifyAPIRepositoryFactory @@ -206,7 +206,7 @@ def _delete_app_workflow_runs(tenant_id: str, app_id: str): def _delete_app_workflow_node_executions(tenant_id: str, app_id: str): """Delete all workflow node executions for an app using the service repository.""" - session_maker = sessionmaker(bind=db.engine) + session_maker = sessionmaker(bind=db.engine, expire_on_commit=False) node_execution_repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository(session_maker) deleted_count = node_execution_repo.delete_executions_by_app( diff --git a/api/tests/unit_tests/core/repositories/test_factory.py b/api/tests/unit_tests/core/repositories/test_factory.py index 1d52c5daf8..fce4a6fb6b 100644 --- a/api/tests/unit_tests/core/repositories/test_factory.py +++ b/api/tests/unit_tests/core/repositories/test_factory.py @@ -12,7 +12,7 @@ from pytest_mock import MockerFixture from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker -from core.repositories.factory import RepositoryFactory, RepositoryImportError +from core.repositories.factory import DifyCoreRepositoryFactory, RepositoryImportError from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository from models import Account, EndUser @@ -27,25 +27,25 @@ class TestRepositoryFactory: """Test successful class import.""" # Test importing a real class class_path = "unittest.mock.MagicMock" - result = RepositoryFactory._import_class(class_path) + result = DifyCoreRepositoryFactory._import_class(class_path) assert result is MagicMock def test_import_class_invalid_path(self): """Test import with invalid module path.""" with pytest.raises(RepositoryImportError) as exc_info: - RepositoryFactory._import_class("invalid.module.path") + DifyCoreRepositoryFactory._import_class("invalid.module.path") assert "Cannot import repository class" in str(exc_info.value) def test_import_class_invalid_class_name(self): """Test import with invalid class name.""" with pytest.raises(RepositoryImportError) as exc_info: - RepositoryFactory._import_class("unittest.mock.NonExistentClass") + DifyCoreRepositoryFactory._import_class("unittest.mock.NonExistentClass") assert "Cannot import repository class" in str(exc_info.value) def test_import_class_malformed_path(self): """Test import with malformed path (no dots).""" with pytest.raises(RepositoryImportError) as exc_info: - RepositoryFactory._import_class("invalidpath") + DifyCoreRepositoryFactory._import_class("invalidpath") assert "Cannot import repository class" in str(exc_info.value) def test_validate_repository_interface_success(self): @@ -68,7 +68,7 @@ class TestRepositoryFactory: pass # Should not raise an exception - RepositoryFactory._validate_repository_interface(MockRepository, MockInterface) + DifyCoreRepositoryFactory._validate_repository_interface(MockRepository, MockInterface) def test_validate_repository_interface_missing_methods(self): """Test interface validation with missing methods.""" @@ -89,7 +89,7 @@ class TestRepositoryFactory: pass with pytest.raises(RepositoryImportError) as exc_info: - RepositoryFactory._validate_repository_interface(IncompleteRepository, MockInterface) + DifyCoreRepositoryFactory._validate_repository_interface(IncompleteRepository, MockInterface) assert "does not implement required methods" in str(exc_info.value) assert "get_by_id" in str(exc_info.value) @@ -101,7 +101,7 @@ class TestRepositoryFactory: pass # Should not raise an exception - RepositoryFactory._validate_constructor_signature( + DifyCoreRepositoryFactory._validate_constructor_signature( MockRepository, ["session_factory", "user", "app_id", "triggered_from"] ) @@ -114,7 +114,7 @@ class TestRepositoryFactory: pass with pytest.raises(RepositoryImportError) as exc_info: - RepositoryFactory._validate_constructor_signature( + DifyCoreRepositoryFactory._validate_constructor_signature( IncompleteRepository, ["session_factory", "user", "app_id", "triggered_from"] ) assert "does not accept required parameters" in str(exc_info.value) @@ -131,7 +131,7 @@ class TestRepositoryFactory: pass with pytest.raises(RepositoryImportError) as exc_info: - RepositoryFactory._validate_constructor_signature(MockRepository, ["session_factory"]) + DifyCoreRepositoryFactory._validate_constructor_signature(MockRepository, ["session_factory"]) assert "Failed to validate constructor signature" in str(exc_info.value) @patch("core.repositories.factory.dify_config") @@ -153,11 +153,11 @@ class TestRepositoryFactory: # Mock the validation methods with ( - patch.object(RepositoryFactory, "_import_class", return_value=mock_repository_class), - patch.object(RepositoryFactory, "_validate_repository_interface"), - patch.object(RepositoryFactory, "_validate_constructor_signature"), + patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class), + patch.object(DifyCoreRepositoryFactory, "_validate_repository_interface"), + patch.object(DifyCoreRepositoryFactory, "_validate_constructor_signature"), ): - result = RepositoryFactory.create_workflow_execution_repository( + result = DifyCoreRepositoryFactory.create_workflow_execution_repository( session_factory=mock_session_factory, user=mock_user, app_id=app_id, @@ -183,7 +183,7 @@ class TestRepositoryFactory: mock_user = MagicMock(spec=Account) with pytest.raises(RepositoryImportError) as exc_info: - RepositoryFactory.create_workflow_execution_repository( + DifyCoreRepositoryFactory.create_workflow_execution_repository( session_factory=mock_session_factory, user=mock_user, app_id="test-app-id", @@ -203,15 +203,15 @@ class TestRepositoryFactory: # Mock import to succeed but validation to fail mock_repository_class = MagicMock() with ( - patch.object(RepositoryFactory, "_import_class", return_value=mock_repository_class), + patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class), patch.object( - RepositoryFactory, + DifyCoreRepositoryFactory, "_validate_repository_interface", side_effect=RepositoryImportError("Interface validation failed"), ), ): with pytest.raises(RepositoryImportError) as exc_info: - RepositoryFactory.create_workflow_execution_repository( + DifyCoreRepositoryFactory.create_workflow_execution_repository( session_factory=mock_session_factory, user=mock_user, app_id="test-app-id", @@ -231,12 +231,12 @@ class TestRepositoryFactory: # Mock import and validation to succeed but instantiation to fail mock_repository_class = MagicMock(side_effect=Exception("Instantiation failed")) with ( - patch.object(RepositoryFactory, "_import_class", return_value=mock_repository_class), - patch.object(RepositoryFactory, "_validate_repository_interface"), - patch.object(RepositoryFactory, "_validate_constructor_signature"), + patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class), + patch.object(DifyCoreRepositoryFactory, "_validate_repository_interface"), + patch.object(DifyCoreRepositoryFactory, "_validate_constructor_signature"), ): with pytest.raises(RepositoryImportError) as exc_info: - RepositoryFactory.create_workflow_execution_repository( + DifyCoreRepositoryFactory.create_workflow_execution_repository( session_factory=mock_session_factory, user=mock_user, app_id="test-app-id", @@ -263,11 +263,11 @@ class TestRepositoryFactory: # Mock the validation methods with ( - patch.object(RepositoryFactory, "_import_class", return_value=mock_repository_class), - patch.object(RepositoryFactory, "_validate_repository_interface"), - patch.object(RepositoryFactory, "_validate_constructor_signature"), + patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class), + patch.object(DifyCoreRepositoryFactory, "_validate_repository_interface"), + patch.object(DifyCoreRepositoryFactory, "_validate_constructor_signature"), ): - result = RepositoryFactory.create_workflow_node_execution_repository( + result = DifyCoreRepositoryFactory.create_workflow_node_execution_repository( session_factory=mock_session_factory, user=mock_user, app_id=app_id, @@ -293,7 +293,7 @@ class TestRepositoryFactory: mock_user = MagicMock(spec=EndUser) with pytest.raises(RepositoryImportError) as exc_info: - RepositoryFactory.create_workflow_node_execution_repository( + DifyCoreRepositoryFactory.create_workflow_node_execution_repository( session_factory=mock_session_factory, user=mock_user, app_id="test-app-id", @@ -325,11 +325,11 @@ class TestRepositoryFactory: # Mock the validation methods with ( - patch.object(RepositoryFactory, "_import_class", return_value=mock_repository_class), - patch.object(RepositoryFactory, "_validate_repository_interface"), - patch.object(RepositoryFactory, "_validate_constructor_signature"), + patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class), + patch.object(DifyCoreRepositoryFactory, "_validate_repository_interface"), + patch.object(DifyCoreRepositoryFactory, "_validate_constructor_signature"), ): - result = RepositoryFactory.create_workflow_execution_repository( + result = DifyCoreRepositoryFactory.create_workflow_execution_repository( session_factory=mock_engine, # Using Engine instead of sessionmaker user=mock_user, app_id="test-app-id", @@ -357,15 +357,15 @@ class TestRepositoryFactory: # Mock import to succeed but validation to fail mock_repository_class = MagicMock() with ( - patch.object(RepositoryFactory, "_import_class", return_value=mock_repository_class), + patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class), patch.object( - RepositoryFactory, + DifyCoreRepositoryFactory, "_validate_repository_interface", side_effect=RepositoryImportError("Interface validation failed"), ), ): with pytest.raises(RepositoryImportError) as exc_info: - RepositoryFactory.create_workflow_node_execution_repository( + DifyCoreRepositoryFactory.create_workflow_node_execution_repository( session_factory=mock_session_factory, user=mock_user, app_id="test-app-id", @@ -385,12 +385,12 @@ class TestRepositoryFactory: # Mock import and validation to succeed but instantiation to fail mock_repository_class = MagicMock(side_effect=Exception("Instantiation failed")) with ( - patch.object(RepositoryFactory, "_import_class", return_value=mock_repository_class), - patch.object(RepositoryFactory, "_validate_repository_interface"), - patch.object(RepositoryFactory, "_validate_constructor_signature"), + patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class), + patch.object(DifyCoreRepositoryFactory, "_validate_repository_interface"), + patch.object(DifyCoreRepositoryFactory, "_validate_constructor_signature"), ): with pytest.raises(RepositoryImportError) as exc_info: - RepositoryFactory.create_workflow_node_execution_repository( + DifyCoreRepositoryFactory.create_workflow_node_execution_repository( session_factory=mock_session_factory, user=mock_user, app_id="test-app-id", @@ -424,7 +424,7 @@ class TestRepositoryFactory: pass # Should not raise an exception (private methods are ignored) - RepositoryFactory._validate_repository_interface(MockRepository, MockInterface) + DifyCoreRepositoryFactory._validate_repository_interface(MockRepository, MockInterface) def test_validate_constructor_signature_with_extra_params(self): """Test constructor validation with extra parameters (should pass).""" @@ -434,7 +434,7 @@ class TestRepositoryFactory: pass # Should not raise an exception (extra parameters are allowed) - RepositoryFactory._validate_constructor_signature( + DifyCoreRepositoryFactory._validate_constructor_signature( MockRepository, ["session_factory", "user", "app_id", "triggered_from"] ) @@ -447,7 +447,7 @@ class TestRepositoryFactory: # Current implementation doesn't handle **kwargs, so this should raise an exception with pytest.raises(RepositoryImportError) as exc_info: - RepositoryFactory._validate_constructor_signature( + DifyCoreRepositoryFactory._validate_constructor_signature( MockRepository, ["session_factory", "user", "app_id", "triggered_from"] ) assert "does not accept required parameters" in str(exc_info.value) diff --git a/api/tests/unit_tests/services/workflow/test_workflow_node_execution_service_repository.py b/api/tests/unit_tests/services/workflow/test_workflow_node_execution_service_repository.py index 32d2f8b7e0..96f9139804 100644 --- a/api/tests/unit_tests/services/workflow/test_workflow_node_execution_service_repository.py +++ b/api/tests/unit_tests/services/workflow/test_workflow_node_execution_service_repository.py @@ -163,18 +163,12 @@ class TestSQLAlchemyWorkflowNodeExecutionServiceRepository: # Mock the select query to return some IDs first time, then empty to stop loop execution_ids = ["id1", "id2"] # Less than batch_size to trigger break + mock_session.execute.return_value.scalars.return_value.all.return_value = execution_ids - # Mock execute method to handle both select and delete statements - def mock_execute(stmt): - mock_result = MagicMock() - # For select statements, return execution IDs - if hasattr(stmt, "limit"): # This is our select statement - mock_result.scalars.return_value.all.return_value = execution_ids - else: # This is our delete statement - mock_result.rowcount = 2 - return mock_result - - mock_session.execute.side_effect = mock_execute + # Mock the delete query + mock_query = MagicMock() + mock_session.query.return_value = mock_query + mock_query.filter.return_value.delete.return_value = 2 before_date = datetime(2023, 1, 1) @@ -187,7 +181,8 @@ class TestSQLAlchemyWorkflowNodeExecutionServiceRepository: # Assert assert result == 2 - assert mock_session.execute.call_count == 2 # One select call, one delete call + mock_session.execute.assert_called_once() # One select call + mock_session.query.assert_called_once() mock_session.commit.assert_called_once() def test_delete_executions_by_app(self, repository): @@ -198,18 +193,12 @@ class TestSQLAlchemyWorkflowNodeExecutionServiceRepository: # Mock the select query to return some IDs first time, then empty to stop loop execution_ids = ["id1", "id2"] + mock_session.execute.return_value.scalars.return_value.all.return_value = execution_ids - # Mock execute method to handle both select and delete statements - def mock_execute(stmt): - mock_result = MagicMock() - # For select statements, return execution IDs - if hasattr(stmt, "limit"): # This is our select statement - mock_result.scalars.return_value.all.return_value = execution_ids - else: # This is our delete statement - mock_result.rowcount = 2 - return mock_result - - mock_session.execute.side_effect = mock_execute + # Mock the delete query + mock_query = MagicMock() + mock_session.query.return_value = mock_query + mock_query.filter.return_value.delete.return_value = 2 # Act result = repository.delete_executions_by_app( @@ -220,7 +209,8 @@ class TestSQLAlchemyWorkflowNodeExecutionServiceRepository: # Assert assert result == 2 - assert mock_session.execute.call_count == 2 # One select call, one delete call + mock_session.execute.assert_called_once() # One select call + mock_session.query.assert_called_once() mock_session.commit.assert_called_once() def test_get_expired_executions_batch(self, repository): @@ -258,10 +248,10 @@ class TestSQLAlchemyWorkflowNodeExecutionServiceRepository: mock_session = MagicMock(spec=Session) repository._session_maker.return_value.__enter__.return_value = mock_session - # Mock the delete query result - mock_result = MagicMock() - mock_result.rowcount = 3 - mock_session.execute.return_value = mock_result + # Mock the delete query + mock_query = MagicMock() + mock_session.query.return_value = mock_query + mock_query.filter.return_value.delete.return_value = 3 execution_ids = ["id1", "id2", "id3"] @@ -270,7 +260,7 @@ class TestSQLAlchemyWorkflowNodeExecutionServiceRepository: # Assert assert result == 3 - mock_session.execute.assert_called_once() + mock_session.query.assert_called_once() mock_session.commit.assert_called_once() def test_delete_executions_by_ids_empty_list(self, repository):