diff --git a/api/configs/feature/__init__.py b/api/configs/feature/__init__.py index f1d529355d..5e5ab974b7 100644 --- a/api/configs/feature/__init__.py +++ b/api/configs/feature/__init__.py @@ -552,12 +552,16 @@ class RepositoryConfig(BaseSettings): """ CORE_WORKFLOW_EXECUTION_REPOSITORY: str = Field( - description="Repository implementation for WorkflowExecution. Specify as a module path", + description="Repository implementation for WorkflowExecution. Options: " + "'core.repositories.sqlalchemy_workflow_execution_repository.SQLAlchemyWorkflowExecutionRepository' (default), " + "'core.repositories.celery_workflow_execution_repository.CeleryWorkflowExecutionRepository'", default="core.repositories.sqlalchemy_workflow_execution_repository.SQLAlchemyWorkflowExecutionRepository", ) CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY: str = Field( - description="Repository implementation for WorkflowNodeExecution. Specify as a module path", + description="Repository implementation for WorkflowNodeExecution. Options: " + "'core.repositories.sqlalchemy_workflow_node_execution_repository.SQLAlchemyWorkflowNodeExecutionRepository' (default), " + "'core.repositories.celery_workflow_node_execution_repository.CeleryWorkflowNodeExecutionRepository'", default="core.repositories.sqlalchemy_workflow_node_execution_repository.SQLAlchemyWorkflowNodeExecutionRepository", ) @@ -572,6 +576,12 @@ class RepositoryConfig(BaseSettings): default="repositories.sqlalchemy_api_workflow_run_repository.DifyAPISQLAlchemyWorkflowRunRepository", ) + # Celery repository configuration + CELERY_REPOSITORY_ASYNC_TIMEOUT: int = Field( + description="Timeout in seconds for Celery repository async operations", + default=30, + ) + class AuthConfig(BaseSettings): """ diff --git a/api/core/repositories/__init__.py b/api/core/repositories/__init__.py index 052ba1c2cb..d83823d7b9 100644 --- a/api/core/repositories/__init__.py +++ b/api/core/repositories/__init__.py @@ -5,10 +5,14 @@ This package contains concrete implementations of the repository interfaces defined in the core.workflow.repository package. """ +from core.repositories.celery_workflow_execution_repository import CeleryWorkflowExecutionRepository +from core.repositories.celery_workflow_node_execution_repository import CeleryWorkflowNodeExecutionRepository from core.repositories.factory import DifyCoreRepositoryFactory, RepositoryImportError from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository __all__ = [ + "CeleryWorkflowExecutionRepository", + "CeleryWorkflowNodeExecutionRepository", "DifyCoreRepositoryFactory", "RepositoryImportError", "SQLAlchemyWorkflowNodeExecutionRepository", diff --git a/api/core/repositories/celery_workflow_execution_repository.py b/api/core/repositories/celery_workflow_execution_repository.py new file mode 100644 index 0000000000..070a181102 --- /dev/null +++ b/api/core/repositories/celery_workflow_execution_repository.py @@ -0,0 +1,180 @@ +""" +Celery-based implementation of the WorkflowExecutionRepository. + +This implementation uses Celery tasks for asynchronous storage operations, +providing improved performance by offloading database operations to background workers. 
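+
+A minimal usage sketch (hedged: the wiring below mirrors what
+DifyCoreRepositoryFactory does; session_factory, user, app_id, and execution
+are assumed to come from the caller's context):
+
+    repo = CeleryWorkflowExecutionRepository(
+        session_factory=session_factory,
+        user=user,
+        app_id=app_id,
+        triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
+    )
+    repo.save(execution)           # returns immediately; a worker persists it
+    repo.wait_for_pending_saves()  # block until queued saves have landed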
+""" + +import logging +from typing import Optional, Union + +from celery.result import AsyncResult +from sqlalchemy.engine import Engine +from sqlalchemy.orm import sessionmaker + +from core.workflow.entities.workflow_execution import WorkflowExecution +from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository +from libs.helper import extract_tenant_id +from models import Account, CreatorUserRole, EndUser +from models.enums import WorkflowRunTriggeredFrom +from tasks.workflow_execution_tasks import ( + save_workflow_execution_task, +) + +logger = logging.getLogger(__name__) + + +class CeleryWorkflowExecutionRepository(WorkflowExecutionRepository): + """ + Celery-based implementation of the WorkflowExecutionRepository interface. + + This implementation provides asynchronous storage capabilities by using Celery tasks + to handle database operations in background workers. This improves performance by + reducing the blocking time for workflow execution storage operations. + + Key features: + - Asynchronous save operations using Celery tasks + - Fallback to synchronous operations for read operations when immediate results are needed + - Support for multi-tenancy through tenant/app filtering + - Automatic retry and error handling through Celery + - Configurable timeouts for async operations + """ + + def __init__( + self, + session_factory: sessionmaker | Engine, + user: Union[Account, EndUser], + app_id: Optional[str], + triggered_from: Optional[WorkflowRunTriggeredFrom], + async_timeout: int = 30, + ): + """ + Initialize the repository with Celery task configuration and context information. + + Args: + session_factory: SQLAlchemy sessionmaker or engine for fallback operations + user: Account or EndUser object containing tenant_id, user ID, and role information + app_id: App ID for filtering by application (can be None) + triggered_from: Source of the execution trigger (DEBUGGING or APP_RUN) + async_timeout: Timeout in seconds for async operations (default: 30) + """ + # Store session factory for fallback operations + if isinstance(session_factory, Engine): + self._session_factory = sessionmaker(bind=session_factory, expire_on_commit=False) + elif isinstance(session_factory, sessionmaker): + self._session_factory = session_factory + else: + raise ValueError( + f"Invalid session_factory type {type(session_factory).__name__}; expected sessionmaker or Engine" + ) + + # Extract tenant_id from user + tenant_id = extract_tenant_id(user) + if not tenant_id: + raise ValueError("User must have a tenant_id or current_tenant_id") + self._tenant_id = tenant_id + + # Store app context + self._app_id = app_id + + # Extract user context + self._triggered_from = triggered_from + self._creator_user_id = user.id + + # Determine user role based on user type + self._creator_user_role = CreatorUserRole.ACCOUNT if isinstance(user, Account) else CreatorUserRole.END_USER + + # Async operation timeout + self._async_timeout = async_timeout + + # Cache for pending async operations + self._pending_saves: dict[str, AsyncResult] = {} + + logger.info( + f"Initialized CeleryWorkflowExecutionRepository for tenant {self._tenant_id}, " + f"app {self._app_id}, triggered_from {self._triggered_from}" + ) + + def save(self, execution: WorkflowExecution) -> None: + """ + Save or update a WorkflowExecution instance asynchronously using Celery. + + This method queues the save operation as a Celery task and returns immediately, + providing improved performance for high-throughput scenarios. 
+ + Args: + execution: The WorkflowExecution instance to save or update + """ + try: + # Serialize execution for Celery task + execution_data = execution.model_dump() + + # Queue the save operation as a Celery task + task_result = save_workflow_execution_task.delay( + execution_data=execution_data, + tenant_id=self._tenant_id, + app_id=self._app_id or "", + triggered_from=self._triggered_from.value if self._triggered_from else "", + creator_user_id=self._creator_user_id, + creator_user_role=self._creator_user_role.value, + ) + + # Store the task result for potential status checking + self._pending_saves[execution.id_] = task_result + + logger.debug(f"Queued async save for workflow execution: {execution.id_}") + + except Exception as e: + logger.exception(f"Failed to queue save operation for execution {execution.id_}: {e}") + # In case of Celery failure, we could implement a fallback to synchronous save + # For now, we'll re-raise the exception + raise + + def wait_for_pending_saves(self, timeout: Optional[int] = None) -> None: + """ + Wait for all pending save operations to complete. + + This method is useful for ensuring data consistency when immediate + persistence is required (e.g., during testing or critical operations). + + Args: + timeout: Maximum time to wait for all operations (uses instance timeout if None) + """ + wait_timeout = timeout or self._async_timeout + + for execution_id, task_result in list(self._pending_saves.items()): + try: + if not task_result.ready(): + logger.debug(f"Waiting for save operation to complete: {execution_id}") + task_result.get(timeout=wait_timeout) + # Remove completed task + del self._pending_saves[execution_id] + except Exception as e: + logger.exception(f"Failed to wait for save operation {execution_id}: {e}") + + def get_pending_save_count(self) -> int: + """ + Get the number of pending save operations. + + Returns: + Number of save operations still in progress + """ + # Clean up completed tasks + completed_ids = [] + for execution_id, task_result in self._pending_saves.items(): + if task_result.ready(): + completed_ids.append(execution_id) + + for execution_id in completed_ids: + del self._pending_saves[execution_id] + + return len(self._pending_saves) + + def clear_pending_saves(self) -> None: + """ + Clear all pending save operations without waiting for completion. + + This method is useful for cleanup operations or when canceling workflows. + """ + self._pending_saves.clear() + logger.debug("Cleared all pending save operations") diff --git a/api/core/repositories/celery_workflow_node_execution_repository.py b/api/core/repositories/celery_workflow_node_execution_repository.py new file mode 100644 index 0000000000..8c0ce64fd8 --- /dev/null +++ b/api/core/repositories/celery_workflow_node_execution_repository.py @@ -0,0 +1,275 @@ +""" +Celery-based implementation of the WorkflowNodeExecutionRepository. + +This implementation uses Celery tasks for asynchronous storage operations, +providing improved performance by offloading database operations to background workers. 
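+
+Note that reads are routed through Celery as well: ``get_by_workflow_run``
+dispatches a query task to the ``workflow_storage`` queue and blocks on its
+result, so a worker consuming that queue is required even for read paths.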
+""" + +import logging +from collections.abc import Sequence +from typing import Optional, Union + +from celery.result import AsyncResult +from sqlalchemy.engine import Engine +from sqlalchemy.orm import sessionmaker + +from core.workflow.entities.workflow_node_execution import WorkflowNodeExecution +from core.workflow.repositories.workflow_node_execution_repository import ( + OrderConfig, + WorkflowNodeExecutionRepository, +) +from libs.helper import extract_tenant_id +from models import Account, CreatorUserRole, EndUser +from models.workflow import WorkflowNodeExecutionTriggeredFrom +from tasks.workflow_node_execution_tasks import ( + get_workflow_node_executions_by_workflow_run_task, + save_workflow_node_execution_task, +) + +logger = logging.getLogger(__name__) + + +class CeleryWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository): + """ + Celery-based implementation of the WorkflowNodeExecutionRepository interface. + + This implementation provides asynchronous storage capabilities by using Celery tasks + to handle database operations in background workers. This improves performance by + reducing the blocking time for workflow node execution storage operations. + + Key features: + - Asynchronous save operations using Celery tasks + - Fallback to synchronous operations for read operations when immediate results are needed + - Support for multi-tenancy through tenant/app filtering + - Automatic retry and error handling through Celery + - Configurable timeouts for async operations + - Batch operations for improved efficiency + """ + + def __init__( + self, + session_factory: sessionmaker | Engine, + user: Union[Account, EndUser], + app_id: Optional[str], + triggered_from: Optional[WorkflowNodeExecutionTriggeredFrom], + async_timeout: int = 30, + ): + """ + Initialize the repository with Celery task configuration and context information. 
+ + Args: + session_factory: SQLAlchemy sessionmaker or engine for fallback operations + user: Account or EndUser object containing tenant_id, user ID, and role information + app_id: App ID for filtering by application (can be None) + triggered_from: Source of the execution trigger (SINGLE_STEP or WORKFLOW_RUN) + async_timeout: Timeout in seconds for async operations (default: 30) + """ + # Store session factory for fallback operations + if isinstance(session_factory, Engine): + self._session_factory = sessionmaker(bind=session_factory, expire_on_commit=False) + elif isinstance(session_factory, sessionmaker): + self._session_factory = session_factory + else: + raise ValueError( + f"Invalid session_factory type {type(session_factory).__name__}; expected sessionmaker or Engine" + ) + + # Extract tenant_id from user + tenant_id = extract_tenant_id(user) + if not tenant_id: + raise ValueError("User must have a tenant_id or current_tenant_id") + self._tenant_id = tenant_id + + # Store app context + self._app_id = app_id + + # Extract user context + self._triggered_from = triggered_from + self._creator_user_id = user.id + + # Determine user role based on user type + self._creator_user_role = CreatorUserRole.ACCOUNT if isinstance(user, Account) else CreatorUserRole.END_USER + + # Async operation timeout + self._async_timeout = async_timeout + + # Cache for pending async operations + self._pending_saves: dict[str, AsyncResult] = {} + + # Cache for mapping execution IDs to workflow_execution_ids for efficient workflow-specific waiting + self._workflow_execution_mapping: dict[str, str] = {} + + logger.info( + f"Initialized CeleryWorkflowNodeExecutionRepository for tenant {self._tenant_id}, " + f"app {self._app_id}, triggered_from {self._triggered_from}" + ) + + def save(self, execution: WorkflowNodeExecution) -> None: + """ + Save or update a WorkflowNodeExecution instance asynchronously using Celery. + + This method queues the save operation as a Celery task and returns immediately, + providing improved performance for high-throughput scenarios. + + Args: + execution: The WorkflowNodeExecution instance to save or update + """ + try: + # Serialize execution for Celery task + execution_data = execution.model_dump() + + # Queue the save operation as a Celery task + task_result = save_workflow_node_execution_task.delay( + execution_data=execution_data, + tenant_id=self._tenant_id, + app_id=self._app_id or "", + triggered_from=self._triggered_from.value if self._triggered_from else "", + creator_user_id=self._creator_user_id, + creator_user_role=self._creator_user_role.value, + ) + + # Store the task result for potential status checking + self._pending_saves[execution.id] = task_result + + # Cache the workflow_execution_id mapping for efficient workflow-specific waiting + self._workflow_execution_mapping[execution.id] = execution.workflow_execution_id + + logger.debug(f"Queued async save for workflow node execution: {execution.id}") + + except Exception as e: + logger.exception(f"Failed to queue save operation for node execution {execution.id}: {e}") + # In case of Celery failure, we could implement a fallback to synchronous save + # For now, we'll re-raise the exception + raise + + def get_by_workflow_run( + self, + workflow_run_id: str, + order_config: Optional[OrderConfig] = None, + ) -> Sequence[WorkflowNodeExecution]: + """ + Retrieve all WorkflowNodeExecution instances for a specific workflow run. 
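+
+        Pending saves for this run are flushed first (via
+        ``_wait_for_pending_saves_by_workflow_run``); the query itself is then
+        executed by a Celery worker and awaited synchronously.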
+
+        Args:
+            workflow_run_id: The workflow run ID
+            order_config: Optional configuration for ordering results
+
+        Returns:
+            A sequence of WorkflowNodeExecution instances
+        """
+        try:
+            # Wait for any pending saves that might affect this workflow run
+            self._wait_for_pending_saves_by_workflow_run(workflow_run_id)
+
+            # Serialize order config for Celery task
+            if order_config:
+                order_config_data = {"order_by": order_config.order_by, "order_direction": order_config.order_direction}
+            else:
+                order_config_data = None
+
+            # Queue the get operation as a Celery task
+            task_result = get_workflow_node_executions_by_workflow_run_task.delay(
+                workflow_run_id=workflow_run_id,
+                tenant_id=self._tenant_id,
+                app_id=self._app_id or "",
+                order_config=order_config_data,
+            )
+
+            # Wait for the result (synchronous for read operations)
+            executions_data = task_result.get(timeout=self._async_timeout)
+
+            result = []
+            for execution_data in executions_data:
+                execution = WorkflowNodeExecution.model_validate(execution_data)
+                result.append(execution)
+
+            return result
+
+        except Exception as e:
+            logger.exception(f"Failed to get workflow node executions for run {workflow_run_id}: {e}")
+            # Could implement fallback to direct database access here
+            return []
+
+    def _wait_for_pending_saves_by_workflow_run(self, workflow_run_id: str) -> None:
+        """
+        Wait for any pending save operations that might affect the given workflow run.
+
+        Uses the cached workflow_execution_id mapping so that only the tasks
+        belonging to the given workflow run are awaited, which keeps reads efficient.
+
+        Args:
+            workflow_run_id: The workflow run ID to check
+        """
+        # Find execution IDs that belong to this workflow run
+        relevant_execution_ids = [
+            execution_id for execution_id, cached_workflow_id in self._workflow_execution_mapping.items()
+            if cached_workflow_id == workflow_run_id and execution_id in self._pending_saves
+        ]
+
+        logger.debug(f"Found {len(relevant_execution_ids)} pending saves for workflow run {workflow_run_id}")
+
+        for execution_id in relevant_execution_ids:
+            task_result = self._pending_saves.get(execution_id)
+            if task_result and not task_result.ready():
+                try:
+                    logger.debug(f"Waiting for pending save to complete before read: {execution_id}")
+                    task_result.get(timeout=self._async_timeout)
+                except Exception as e:
+                    logger.exception(f"Failed to wait for pending save {execution_id}: {e}")
+
+            # Clean up completed tasks from both caches
+            if task_result and task_result.ready():
+                self._pending_saves.pop(execution_id, None)
+                self._workflow_execution_mapping.pop(execution_id, None)
+
+    def wait_for_pending_saves(self, timeout: Optional[int] = None) -> None:
+        """
+        Wait for all pending save operations to complete.
+
+        This method is useful for ensuring data consistency when immediate
+        persistence is required (e.g., during testing or critical operations).
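+
+        Note: the timeout applies per pending task rather than to the batch as
+        a whole, so total wall-clock time can exceed ``timeout`` when several
+        saves are still in flight.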
+ + Args: + timeout: Maximum time to wait for all operations (uses instance timeout if None) + """ + wait_timeout = timeout or self._async_timeout + + for execution_id, task_result in list(self._pending_saves.items()): + try: + if not task_result.ready(): + logger.debug(f"Waiting for save operation to complete: {execution_id}") + task_result.get(timeout=wait_timeout) + # Remove completed task from both caches + del self._pending_saves[execution_id] + self._workflow_execution_mapping.pop(execution_id, None) + except Exception as e: + logger.exception(f"Failed to wait for save operation {execution_id}: {e}") + + def get_pending_save_count(self) -> int: + """ + Get the number of pending save operations. + + Returns: + Number of save operations still in progress + """ + # Clean up completed tasks + completed_ids = [] + for execution_id, task_result in self._pending_saves.items(): + if task_result.ready(): + completed_ids.append(execution_id) + + for execution_id in completed_ids: + del self._pending_saves[execution_id] + self._workflow_execution_mapping.pop(execution_id, None) + + return len(self._pending_saves) + + def clear_pending_saves(self) -> None: + """ + Clear all pending save operations without waiting for completion. + + This method is useful for cleanup operations or when canceling workflows. + """ + self._pending_saves.clear() + self._workflow_execution_mapping.clear() + logger.debug("Cleared all pending save operations and workflow execution mappings") diff --git a/api/core/repositories/factory.py b/api/core/repositories/factory.py index 4118aa61c7..0892878c1a 100644 --- a/api/core/repositories/factory.py +++ b/api/core/repositories/factory.py @@ -158,16 +158,30 @@ class DifyCoreRepositoryFactory: try: repository_class = cls._import_class(class_path) cls._validate_repository_interface(repository_class, WorkflowExecutionRepository) - cls._validate_constructor_signature( - repository_class, ["session_factory", "user", "app_id", "triggered_from"] - ) - return repository_class( # type: ignore[no-any-return] - session_factory=session_factory, - user=user, - app_id=app_id, - triggered_from=triggered_from, - ) + # Check if this is a Celery repository that needs async_timeout + is_celery_repo = "celery" in class_path.lower() + if is_celery_repo: + cls._validate_constructor_signature( + repository_class, ["session_factory", "user", "app_id", "triggered_from", "async_timeout"] + ) + return repository_class( # type: ignore[no-any-return] + session_factory=session_factory, + user=user, + app_id=app_id, + triggered_from=triggered_from, + async_timeout=dify_config.CELERY_REPOSITORY_ASYNC_TIMEOUT, + ) + else: + cls._validate_constructor_signature( + repository_class, ["session_factory", "user", "app_id", "triggered_from"] + ) + return repository_class( # type: ignore[no-any-return] + session_factory=session_factory, + user=user, + app_id=app_id, + triggered_from=triggered_from, + ) except RepositoryImportError: # Re-raise our custom errors as-is raise @@ -204,16 +218,30 @@ class DifyCoreRepositoryFactory: try: repository_class = cls._import_class(class_path) cls._validate_repository_interface(repository_class, WorkflowNodeExecutionRepository) - cls._validate_constructor_signature( - repository_class, ["session_factory", "user", "app_id", "triggered_from"] - ) - return repository_class( # type: ignore[no-any-return] - session_factory=session_factory, - user=user, - app_id=app_id, - triggered_from=triggered_from, - ) + # Check if this is a Celery repository that needs async_timeout + 
is_celery_repo = "celery" in class_path.lower() + if is_celery_repo: + cls._validate_constructor_signature( + repository_class, ["session_factory", "user", "app_id", "triggered_from", "async_timeout"] + ) + return repository_class( # type: ignore[no-any-return] + session_factory=session_factory, + user=user, + app_id=app_id, + triggered_from=triggered_from, + async_timeout=dify_config.CELERY_REPOSITORY_ASYNC_TIMEOUT, + ) + else: + cls._validate_constructor_signature( + repository_class, ["session_factory", "user", "app_id", "triggered_from"] + ) + return repository_class( # type: ignore[no-any-return] + session_factory=session_factory, + user=user, + app_id=app_id, + triggered_from=triggered_from, + ) except RepositoryImportError: # Re-raise our custom errors as-is raise diff --git a/api/docker/entrypoint.sh b/api/docker/entrypoint.sh index 18d4f4885d..b68ebdc360 100755 --- a/api/docker/entrypoint.sh +++ b/api/docker/entrypoint.sh @@ -22,7 +22,7 @@ if [[ "${MODE}" == "worker" ]]; then exec celery -A app.celery worker -P ${CELERY_WORKER_CLASS:-gevent} $CONCURRENCY_OPTION \ --max-tasks-per-child ${MAX_TASK_PRE_CHILD:-50} --loglevel ${LOG_LEVEL:-INFO} \ - -Q ${CELERY_QUEUES:-dataset,mail,ops_trace,app_deletion} + -Q ${CELERY_QUEUES:-dataset,mail,ops_trace,app_deletion,workflow_storage} elif [[ "${MODE}" == "beat" ]]; then exec celery -A app.celery beat --loglevel ${LOG_LEVEL:-INFO} diff --git a/api/extensions/ext_celery.py b/api/extensions/ext_celery.py index 6279b1ad36..77a39d617f 100644 --- a/api/extensions/ext_celery.py +++ b/api/extensions/ext_celery.py @@ -72,6 +72,8 @@ def init_app(app: DifyApp) -> Celery: "schedule.clean_messages", "schedule.mail_clean_document_notify_task", "schedule.queue_monitor_task", + "tasks.workflow_execution_tasks", + "tasks.workflow_node_execution_tasks", ] day = dify_config.CELERY_BEAT_SCHEDULER_TIME beat_schedule = { diff --git a/api/tasks/workflow_execution_tasks.py b/api/tasks/workflow_execution_tasks.py new file mode 100644 index 0000000000..048dad7831 --- /dev/null +++ b/api/tasks/workflow_execution_tasks.py @@ -0,0 +1,157 @@ +""" +Celery tasks for asynchronous workflow execution storage operations. + +These tasks provide asynchronous storage capabilities for workflow execution data, +improving performance by offloading storage operations to background workers. +""" + +import json +import logging + +from celery import shared_task +from sqlalchemy import select +from sqlalchemy.orm import sessionmaker + +from core.workflow.entities.workflow_execution import WorkflowExecution, WorkflowExecutionStatus, WorkflowType +from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter +from extensions.ext_database import db +from models import CreatorUserRole, WorkflowRun +from models.enums import WorkflowRunTriggeredFrom + +logger = logging.getLogger(__name__) + + +@shared_task(queue="workflow_storage", bind=True, max_retries=3, default_retry_delay=60) +def save_workflow_execution_task( + self, + execution_data: dict, + tenant_id: str, + app_id: str, + triggered_from: str, + creator_user_id: str, + creator_user_role: str, +) -> bool: + """ + Asynchronously save or update a workflow execution to the database. 
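+
+    The task is an idempotent upsert: an existing WorkflowRun row with the same
+    id is updated in place, otherwise a new row is inserted. On failure it is
+    retried up to three times with exponential backoff (60s, 120s, 240s).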
+ + Args: + execution_data: Serialized WorkflowExecution data + tenant_id: Tenant ID for multi-tenancy + app_id: Application ID + triggered_from: Source of the execution trigger + creator_user_id: ID of the user who created the execution + creator_user_role: Role of the user who created the execution + + Returns: + True if successful, False otherwise + """ + try: + # Create a new session for this task + session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) + + with session_factory() as session: + # Deserialize execution data + execution = WorkflowExecution.model_validate(execution_data) + + # Check if workflow run already exists + existing_run = session.scalar(select(WorkflowRun).where(WorkflowRun.id == execution.id_)) + + if existing_run: + # Update existing workflow run + _update_workflow_run_from_execution(existing_run, execution) + logger.debug(f"Updated existing workflow run: {execution.id_}") + else: + # Create new workflow run + workflow_run = _create_workflow_run_from_execution( + execution=execution, + tenant_id=tenant_id, + app_id=app_id, + triggered_from=WorkflowRunTriggeredFrom(triggered_from), + creator_user_id=creator_user_id, + creator_user_role=CreatorUserRole(creator_user_role), + ) + session.add(workflow_run) + logger.debug(f"Created new workflow run: {execution.id_}") + + session.commit() + return True + + except Exception as e: + logger.exception(f"Failed to save workflow execution {execution_data.get('id_', 'unknown')}: {e}") + # Retry the task with exponential backoff + raise self.retry(exc=e, countdown=60 * (2**self.request.retries)) + + +def _create_workflow_run_from_execution( + execution: WorkflowExecution, + tenant_id: str, + app_id: str, + triggered_from: WorkflowRunTriggeredFrom, + creator_user_id: str, + creator_user_role: CreatorUserRole, +) -> WorkflowRun: + """ + Create a WorkflowRun database model from a WorkflowExecution domain entity. + """ + workflow_run = WorkflowRun() + workflow_run.id = execution.id_ + workflow_run.tenant_id = tenant_id + workflow_run.app_id = app_id + workflow_run.workflow_id = execution.workflow_id + workflow_run.type = execution.workflow_type.value + workflow_run.triggered_from = triggered_from.value + workflow_run.version = execution.workflow_version + json_converter = WorkflowRuntimeTypeConverter() + workflow_run.graph = json.dumps(json_converter.to_json_encodable(execution.graph)) + workflow_run.inputs = json.dumps(json_converter.to_json_encodable(execution.inputs)) + workflow_run.status = execution.status.value + workflow_run.outputs = ( + json.dumps(json_converter.to_json_encodable(execution.outputs)) if execution.outputs else "{}" + ) + workflow_run.error = execution.error_message + workflow_run.elapsed_time = execution.elapsed_time + workflow_run.total_tokens = execution.total_tokens + workflow_run.total_steps = execution.total_steps + workflow_run.created_by_role = creator_user_role.value + workflow_run.created_by = creator_user_id + workflow_run.created_at = execution.started_at + workflow_run.finished_at = execution.finished_at + + return workflow_run + + +def _update_workflow_run_from_execution(workflow_run: WorkflowRun, execution: WorkflowExecution) -> None: + """ + Update a WorkflowRun database model from a WorkflowExecution domain entity. 
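+
+    Only mutable run state (status, outputs, error, timing, and counters) is
+    written; identity, graph, inputs, and creator fields are left untouched.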
+ """ + json_converter = WorkflowRuntimeTypeConverter() + workflow_run.status = execution.status.value + workflow_run.outputs = ( + json.dumps(json_converter.to_json_encodable(execution.outputs)) if execution.outputs else "{}" + ) + workflow_run.error = execution.error_message + workflow_run.elapsed_time = execution.elapsed_time + workflow_run.total_tokens = execution.total_tokens + workflow_run.total_steps = execution.total_steps + workflow_run.finished_at = execution.finished_at + + +def _create_execution_from_workflow_run(workflow_run: WorkflowRun) -> WorkflowExecution: + """ + Create a WorkflowExecution domain entity from a WorkflowRun database model. + """ + return WorkflowExecution( + id_=workflow_run.id, + workflow_id=workflow_run.workflow_id, + workflow_type=WorkflowType(workflow_run.type), + workflow_version=workflow_run.version, + graph=json.loads(workflow_run.graph or "{}"), + inputs=json.loads(workflow_run.inputs or "{}"), + outputs=json.loads(workflow_run.outputs or "{}"), + status=WorkflowExecutionStatus(workflow_run.status), + error_message=workflow_run.error or "", + total_tokens=workflow_run.total_tokens, + total_steps=workflow_run.total_steps, + started_at=workflow_run.created_at, + finished_at=workflow_run.finished_at, + ) diff --git a/api/tasks/workflow_node_execution_tasks.py b/api/tasks/workflow_node_execution_tasks.py new file mode 100644 index 0000000000..6cf1039e4c --- /dev/null +++ b/api/tasks/workflow_node_execution_tasks.py @@ -0,0 +1,277 @@ +""" +Celery tasks for asynchronous workflow node execution storage operations. + +These tasks provide asynchronous storage capabilities for workflow node execution data, +improving performance by offloading storage operations to background workers. +""" + +import json +import logging +from typing import Optional + +from celery import shared_task +from sqlalchemy import select +from sqlalchemy.orm import sessionmaker + +from core.workflow.entities.workflow_node_execution import ( + WorkflowNodeExecution, + WorkflowNodeExecutionMetadataKey, + WorkflowNodeExecutionStatus, +) +from core.workflow.nodes.enums import NodeType +from core.workflow.repositories.workflow_node_execution_repository import OrderConfig +from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter +from extensions.ext_database import db +from models import CreatorUserRole, WorkflowNodeExecutionModel +from models.workflow import WorkflowNodeExecutionTriggeredFrom + +logger = logging.getLogger(__name__) + + +@shared_task(queue="workflow_storage", bind=True, max_retries=3, default_retry_delay=60) +def save_workflow_node_execution_task( + self, + execution_data: dict, + tenant_id: str, + app_id: str, + triggered_from: str, + creator_user_id: str, + creator_user_role: str, +) -> bool: + """ + Asynchronously save or update a workflow node execution to the database. 
+ + Args: + execution_data: Serialized WorkflowNodeExecution data + tenant_id: Tenant ID for multi-tenancy + app_id: Application ID + triggered_from: Source of the execution trigger + creator_user_id: ID of the user who created the execution + creator_user_role: Role of the user who created the execution + + Returns: + True if successful, False otherwise + """ + try: + # Create a new session for this task + session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) + + with session_factory() as session: + # Deserialize execution data + execution = WorkflowNodeExecution.model_validate(execution_data) + + # Check if node execution already exists + existing_execution = session.scalar( + select(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.id == execution.id) + ) + + if existing_execution: + # Update existing node execution + _update_node_execution_from_domain(existing_execution, execution) + logger.debug(f"Updated existing workflow node execution: {execution.id}") + else: + # Create new node execution + node_execution = _create_node_execution_from_domain( + execution=execution, + tenant_id=tenant_id, + app_id=app_id, + triggered_from=WorkflowNodeExecutionTriggeredFrom(triggered_from), + creator_user_id=creator_user_id, + creator_user_role=CreatorUserRole(creator_user_role), + ) + session.add(node_execution) + logger.debug(f"Created new workflow node execution: {execution.id}") + + session.commit() + return True + + except Exception as e: + logger.exception(f"Failed to save workflow node execution {execution_data.get('id', 'unknown')}: {e}") + # Retry the task with exponential backoff + raise self.retry(exc=e, countdown=60 * (2**self.request.retries)) + + +@shared_task(queue="workflow_storage", bind=True, max_retries=3, default_retry_delay=60) +def get_workflow_node_executions_by_workflow_run_task( + self, + workflow_run_id: str, + tenant_id: str, + app_id: str, + order_config: Optional[dict] = None, +) -> list[dict]: + """ + Asynchronously retrieve all workflow node executions for a specific workflow run. 
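+
+    Ordering silently skips unknown field names, and any direction other than
+    "desc" falls back to ascending.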
+ + Args: + workflow_run_id: The workflow run ID + tenant_id: Tenant ID for multi-tenancy + app_id: Application ID + order_config: Optional ordering configuration + + Returns: + List of serialized WorkflowNodeExecution data + """ + try: + session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) + + with session_factory() as session: + # Build base query + query = select(WorkflowNodeExecutionModel).where( + WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id, + WorkflowNodeExecutionModel.tenant_id == tenant_id, + WorkflowNodeExecutionModel.app_id == app_id, + ) + + # Apply ordering if specified + if order_config: + order_obj = OrderConfig( + order_by=order_config["order_by"], order_direction=order_config.get("order_direction") + ) + for field_name in order_obj.order_by: + field = getattr(WorkflowNodeExecutionModel, field_name, None) + if field is not None: + if order_obj.order_direction == "desc": + query = query.order_by(field.desc()) + else: + query = query.order_by(field.asc()) + + node_executions = session.scalars(query).all() + + result = [] + for node_execution in node_executions: + execution = _create_domain_from_node_execution(node_execution) + result.append(execution.model_dump()) + + return result + + except Exception as e: + logger.exception(f"Failed to get workflow node executions for run {workflow_run_id}: {e}") + # Retry the task with exponential backoff + raise self.retry(exc=e, countdown=60 * (2**self.request.retries)) + + +def _create_node_execution_from_domain( + execution: WorkflowNodeExecution, + tenant_id: str, + app_id: str, + triggered_from: WorkflowNodeExecutionTriggeredFrom, + creator_user_id: str, + creator_user_role: CreatorUserRole, +) -> WorkflowNodeExecutionModel: + """ + Create a WorkflowNodeExecutionModel database model from a WorkflowNodeExecution domain entity. 
+ """ + node_execution = WorkflowNodeExecutionModel() + node_execution.id = execution.id + node_execution.tenant_id = tenant_id + node_execution.app_id = app_id + node_execution.workflow_id = execution.workflow_id + node_execution.triggered_from = triggered_from.value + node_execution.workflow_run_id = execution.workflow_execution_id + node_execution.index = execution.index + node_execution.predecessor_node_id = execution.predecessor_node_id + node_execution.node_id = execution.node_id + node_execution.node_type = execution.node_type.value + node_execution.title = execution.title + node_execution.node_execution_id = execution.node_execution_id + + # Serialize complex data as JSON + json_converter = WorkflowRuntimeTypeConverter() + node_execution.inputs = json.dumps(json_converter.to_json_encodable(execution.inputs)) if execution.inputs else "{}" + node_execution.process_data = ( + json.dumps(json_converter.to_json_encodable(execution.process_data)) if execution.process_data else "{}" + ) + node_execution.outputs = ( + json.dumps(json_converter.to_json_encodable(execution.outputs)) if execution.outputs else "{}" + ) + # Convert metadata enum keys to strings for JSON serialization + if execution.metadata: + metadata_for_json = { + key.value if hasattr(key, "value") else str(key): value for key, value in execution.metadata.items() + } + node_execution.execution_metadata = json.dumps(json_converter.to_json_encodable(metadata_for_json)) + else: + node_execution.execution_metadata = "{}" + + node_execution.status = execution.status.value + node_execution.error = execution.error + node_execution.elapsed_time = execution.elapsed_time + node_execution.created_by_role = creator_user_role.value + node_execution.created_by = creator_user_id + node_execution.created_at = execution.created_at + node_execution.finished_at = execution.finished_at + + return node_execution + + +def _update_node_execution_from_domain( + node_execution: WorkflowNodeExecutionModel, execution: WorkflowNodeExecution +) -> None: + """ + Update a WorkflowNodeExecutionModel database model from a WorkflowNodeExecution domain entity. + """ + # Update serialized data + json_converter = WorkflowRuntimeTypeConverter() + node_execution.inputs = json.dumps(json_converter.to_json_encodable(execution.inputs)) if execution.inputs else "{}" + node_execution.process_data = ( + json.dumps(json_converter.to_json_encodable(execution.process_data)) if execution.process_data else "{}" + ) + node_execution.outputs = ( + json.dumps(json_converter.to_json_encodable(execution.outputs)) if execution.outputs else "{}" + ) + # Convert metadata enum keys to strings for JSON serialization + if execution.metadata: + metadata_for_json = { + key.value if hasattr(key, "value") else str(key): value for key, value in execution.metadata.items() + } + node_execution.execution_metadata = json.dumps(json_converter.to_json_encodable(metadata_for_json)) + else: + node_execution.execution_metadata = "{}" + + # Update other fields + node_execution.status = execution.status.value + node_execution.error = execution.error + node_execution.elapsed_time = execution.elapsed_time + node_execution.finished_at = execution.finished_at + + +def _create_domain_from_node_execution(node_execution: WorkflowNodeExecutionModel) -> WorkflowNodeExecution: + """ + Create a WorkflowNodeExecution domain entity from a WorkflowNodeExecutionModel database model. 
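+
+    Metadata keys that are not valid WorkflowNodeExecutionMetadataKey values
+    are dropped rather than raising, so round-tripping is lossy for unknown
+    keys.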
+ """ + # Deserialize JSON data + inputs = json.loads(node_execution.inputs or "{}") + process_data = json.loads(node_execution.process_data or "{}") + outputs = json.loads(node_execution.outputs or "{}") + metadata = json.loads(node_execution.execution_metadata or "{}") + + # Convert metadata keys to enum values + typed_metadata = {} + for key, value in metadata.items(): + try: + enum_key = WorkflowNodeExecutionMetadataKey(key) + typed_metadata[enum_key] = value + except ValueError: + # Skip unknown metadata keys + continue + + return WorkflowNodeExecution( + id=node_execution.id, + node_execution_id=node_execution.node_execution_id, + workflow_id=node_execution.workflow_id, + workflow_execution_id=node_execution.workflow_run_id, + index=node_execution.index, + predecessor_node_id=node_execution.predecessor_node_id, + node_id=node_execution.node_id, + node_type=NodeType(node_execution.node_type), + title=node_execution.title, + inputs=inputs if inputs else None, + process_data=process_data if process_data else None, + outputs=outputs if outputs else None, + status=WorkflowNodeExecutionStatus(node_execution.status), + error=node_execution.error, + elapsed_time=node_execution.elapsed_time, + metadata=typed_metadata if typed_metadata else None, + created_at=node_execution.created_at, + finished_at=node_execution.finished_at, + ) diff --git a/api/tests/unit_tests/core/repositories/test_celery_workflow_execution_repository.py b/api/tests/unit_tests/core/repositories/test_celery_workflow_execution_repository.py new file mode 100644 index 0000000000..a63d584419 --- /dev/null +++ b/api/tests/unit_tests/core/repositories/test_celery_workflow_execution_repository.py @@ -0,0 +1,237 @@ +""" +Unit tests for CeleryWorkflowExecutionRepository. + +These tests verify the Celery-based asynchronous storage functionality +for workflow execution data. 
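+
+The Celery tasks themselves are patched in these tests, so no broker or worker
+is needed; an in-memory SQLite engine stands in for the session factory.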
+""" + +from datetime import UTC, datetime +from unittest.mock import Mock, patch +from uuid import uuid4 + +import pytest +from celery.result import AsyncResult + +from core.repositories.celery_workflow_execution_repository import CeleryWorkflowExecutionRepository +from core.workflow.entities.workflow_execution import WorkflowExecution, WorkflowType +from models import Account, EndUser +from models.enums import WorkflowRunTriggeredFrom + + +@pytest.fixture +def mock_session_factory(): + """Mock SQLAlchemy session factory.""" + from sqlalchemy import create_engine + from sqlalchemy.orm import sessionmaker + + # Create a real sessionmaker with in-memory SQLite for testing + engine = create_engine("sqlite:///:memory:") + return sessionmaker(bind=engine) + + +@pytest.fixture +def mock_account(): + """Mock Account user.""" + account = Mock(spec=Account) + account.id = str(uuid4()) + account.current_tenant_id = str(uuid4()) + return account + + +@pytest.fixture +def mock_end_user(): + """Mock EndUser.""" + user = Mock(spec=EndUser) + user.id = str(uuid4()) + user.tenant_id = str(uuid4()) + return user + + +@pytest.fixture +def sample_workflow_execution(): + """Sample WorkflowExecution for testing.""" + return WorkflowExecution.new( + id_=str(uuid4()), + workflow_id=str(uuid4()), + workflow_type=WorkflowType.WORKFLOW, + workflow_version="1.0", + graph={"nodes": [], "edges": []}, + inputs={"input1": "value1"}, + started_at=datetime.now(UTC).replace(tzinfo=None), + ) + + +class TestCeleryWorkflowExecutionRepository: + """Test cases for CeleryWorkflowExecutionRepository.""" + + def test_init_with_sessionmaker(self, mock_session_factory, mock_account): + """Test repository initialization with sessionmaker.""" + app_id = "test-app-id" + triggered_from = WorkflowRunTriggeredFrom.APP_RUN + + repo = CeleryWorkflowExecutionRepository( + session_factory=mock_session_factory, + user=mock_account, + app_id=app_id, + triggered_from=triggered_from, + ) + + assert repo._tenant_id == mock_account.current_tenant_id + assert repo._app_id == app_id + assert repo._triggered_from == triggered_from + assert repo._creator_user_id == mock_account.id + assert repo._async_timeout == 30 # default timeout + + def test_init_with_custom_timeout(self, mock_session_factory, mock_account): + """Test repository initialization with custom timeout.""" + custom_timeout = 60 + + repo = CeleryWorkflowExecutionRepository( + session_factory=mock_session_factory, + user=mock_account, + app_id="test-app", + triggered_from=WorkflowRunTriggeredFrom.DEBUGGING, + async_timeout=custom_timeout, + ) + + assert repo._async_timeout == custom_timeout + + def test_init_with_end_user(self, mock_session_factory, mock_end_user): + """Test repository initialization with EndUser.""" + repo = CeleryWorkflowExecutionRepository( + session_factory=mock_session_factory, + user=mock_end_user, + app_id="test-app", + triggered_from=WorkflowRunTriggeredFrom.APP_RUN, + ) + + assert repo._tenant_id == mock_end_user.tenant_id + + def test_init_without_tenant_id_raises_error(self, mock_session_factory): + """Test that initialization fails without tenant_id.""" + user = Mock() + user.current_tenant_id = None + user.tenant_id = None + + with pytest.raises(ValueError, match="User must have a tenant_id"): + CeleryWorkflowExecutionRepository( + session_factory=mock_session_factory, + user=user, + app_id="test-app", + triggered_from=WorkflowRunTriggeredFrom.APP_RUN, + ) + + @patch("core.repositories.celery_workflow_execution_repository.save_workflow_execution_task") + def 
test_save_queues_celery_task(self, mock_task, mock_session_factory, mock_account, sample_workflow_execution): + """Test that save operation queues a Celery task.""" + mock_result = Mock(spec=AsyncResult) + mock_task.delay.return_value = mock_result + + repo = CeleryWorkflowExecutionRepository( + session_factory=mock_session_factory, + user=mock_account, + app_id="test-app", + triggered_from=WorkflowRunTriggeredFrom.APP_RUN, + ) + + repo.save(sample_workflow_execution) + + # Verify Celery task was queued with correct parameters + mock_task.delay.assert_called_once() + call_args = mock_task.delay.call_args[1] + + assert call_args["execution_data"] == sample_workflow_execution.model_dump() + assert call_args["tenant_id"] == mock_account.current_tenant_id + assert call_args["app_id"] == "test-app" + assert call_args["triggered_from"] == WorkflowRunTriggeredFrom.APP_RUN.value + assert call_args["creator_user_id"] == mock_account.id + + # Verify task result is stored for tracking + assert sample_workflow_execution.id_ in repo._pending_saves + assert repo._pending_saves[sample_workflow_execution.id_] == mock_result + + @patch("core.repositories.celery_workflow_execution_repository.save_workflow_execution_task") + def test_save_handles_celery_failure( + self, mock_task, mock_session_factory, mock_account, sample_workflow_execution + ): + """Test that save operation handles Celery task failures.""" + mock_task.delay.side_effect = Exception("Celery is down") + + repo = CeleryWorkflowExecutionRepository( + session_factory=mock_session_factory, + user=mock_account, + app_id="test-app", + triggered_from=WorkflowRunTriggeredFrom.APP_RUN, + ) + + with pytest.raises(Exception, match="Celery is down"): + repo.save(sample_workflow_execution) + + + def test_wait_for_pending_saves(self, mock_session_factory, mock_account, sample_workflow_execution): + """Test waiting for all pending save operations.""" + repo = CeleryWorkflowExecutionRepository( + session_factory=mock_session_factory, + user=mock_account, + app_id="test-app", + triggered_from=WorkflowRunTriggeredFrom.APP_RUN, + ) + + # Add some mock pending saves + mock_result1 = Mock(spec=AsyncResult) + mock_result1.ready.return_value = False + mock_result2 = Mock(spec=AsyncResult) + mock_result2.ready.return_value = True + + repo._pending_saves["exec1"] = mock_result1 + repo._pending_saves["exec2"] = mock_result2 + + repo.wait_for_pending_saves(timeout=10) + + # Verify that non-ready task was waited for + mock_result1.get.assert_called_once_with(timeout=10) + + # Verify pending saves were cleared + assert len(repo._pending_saves) == 0 + + def test_get_pending_save_count(self, mock_session_factory, mock_account): + """Test getting the count of pending save operations.""" + repo = CeleryWorkflowExecutionRepository( + session_factory=mock_session_factory, + user=mock_account, + app_id="test-app", + triggered_from=WorkflowRunTriggeredFrom.APP_RUN, + ) + + # Add some mock pending saves + mock_result1 = Mock(spec=AsyncResult) + mock_result1.ready.return_value = False + mock_result2 = Mock(spec=AsyncResult) + mock_result2.ready.return_value = True + + repo._pending_saves["exec1"] = mock_result1 + repo._pending_saves["exec2"] = mock_result2 + + count = repo.get_pending_save_count() + + # Should clean up completed tasks and return count of remaining + assert count == 1 + assert "exec1" in repo._pending_saves + assert "exec2" not in repo._pending_saves + + def test_clear_pending_saves(self, mock_session_factory, mock_account): + """Test clearing all pending save 
operations.""" + repo = CeleryWorkflowExecutionRepository( + session_factory=mock_session_factory, + user=mock_account, + app_id="test-app", + triggered_from=WorkflowRunTriggeredFrom.APP_RUN, + ) + + # Add some mock pending saves + repo._pending_saves["exec1"] = Mock(spec=AsyncResult) + repo._pending_saves["exec2"] = Mock(spec=AsyncResult) + + repo.clear_pending_saves() + + assert len(repo._pending_saves) == 0 diff --git a/api/tests/unit_tests/core/repositories/test_celery_workflow_node_execution_repository.py b/api/tests/unit_tests/core/repositories/test_celery_workflow_node_execution_repository.py new file mode 100644 index 0000000000..57ff558ca5 --- /dev/null +++ b/api/tests/unit_tests/core/repositories/test_celery_workflow_node_execution_repository.py @@ -0,0 +1,308 @@ +""" +Unit tests for CeleryWorkflowNodeExecutionRepository. + +These tests verify the Celery-based asynchronous storage functionality +for workflow node execution data. +""" + +from datetime import UTC, datetime +from unittest.mock import Mock, patch +from uuid import uuid4 + +import pytest +from celery.result import AsyncResult + +from core.repositories.celery_workflow_node_execution_repository import CeleryWorkflowNodeExecutionRepository +from core.workflow.entities.workflow_node_execution import ( + WorkflowNodeExecution, + WorkflowNodeExecutionStatus, +) +from core.workflow.nodes.enums import NodeType +from core.workflow.repositories.workflow_node_execution_repository import OrderConfig +from models import Account, EndUser +from models.workflow import WorkflowNodeExecutionTriggeredFrom + + +@pytest.fixture +def mock_session_factory(): + """Mock SQLAlchemy session factory.""" + from sqlalchemy import create_engine + from sqlalchemy.orm import sessionmaker + + # Create a real sessionmaker with in-memory SQLite for testing + engine = create_engine("sqlite:///:memory:") + return sessionmaker(bind=engine) + + +@pytest.fixture +def mock_account(): + """Mock Account user.""" + account = Mock(spec=Account) + account.id = str(uuid4()) + account.current_tenant_id = str(uuid4()) + return account + + +@pytest.fixture +def mock_end_user(): + """Mock EndUser.""" + user = Mock(spec=EndUser) + user.id = str(uuid4()) + user.tenant_id = str(uuid4()) + return user + + +@pytest.fixture +def sample_workflow_node_execution(): + """Sample WorkflowNodeExecution for testing.""" + return WorkflowNodeExecution( + id=str(uuid4()), + node_execution_id=str(uuid4()), + workflow_id=str(uuid4()), + workflow_execution_id=str(uuid4()), + index=1, + node_id="test_node", + node_type=NodeType.START, + title="Test Node", + inputs={"input1": "value1"}, + status=WorkflowNodeExecutionStatus.RUNNING, + created_at=datetime.now(UTC).replace(tzinfo=None), + ) + + +class TestCeleryWorkflowNodeExecutionRepository: + """Test cases for CeleryWorkflowNodeExecutionRepository.""" + + def test_init_with_sessionmaker(self, mock_session_factory, mock_account): + """Test repository initialization with sessionmaker.""" + app_id = "test-app-id" + triggered_from = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN + + repo = CeleryWorkflowNodeExecutionRepository( + session_factory=mock_session_factory, + user=mock_account, + app_id=app_id, + triggered_from=triggered_from, + ) + + assert repo._tenant_id == mock_account.current_tenant_id + assert repo._app_id == app_id + assert repo._triggered_from == triggered_from + assert repo._creator_user_id == mock_account.id + assert repo._async_timeout == 30 # default timeout + + def test_init_with_custom_timeout(self, 
mock_session_factory, mock_account): + """Test repository initialization with custom timeout.""" + custom_timeout = 60 + + repo = CeleryWorkflowNodeExecutionRepository( + session_factory=mock_session_factory, + user=mock_account, + app_id="test-app", + triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP, + async_timeout=custom_timeout, + ) + + assert repo._async_timeout == custom_timeout + + def test_init_with_end_user(self, mock_session_factory, mock_end_user): + """Test repository initialization with EndUser.""" + repo = CeleryWorkflowNodeExecutionRepository( + session_factory=mock_session_factory, + user=mock_end_user, + app_id="test-app", + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + ) + + assert repo._tenant_id == mock_end_user.tenant_id + + def test_init_without_tenant_id_raises_error(self, mock_session_factory): + """Test that initialization fails without tenant_id.""" + user = Mock() + user.current_tenant_id = None + user.tenant_id = None + + with pytest.raises(ValueError, match="User must have a tenant_id"): + CeleryWorkflowNodeExecutionRepository( + session_factory=mock_session_factory, + user=user, + app_id="test-app", + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + ) + + @patch("core.repositories.celery_workflow_node_execution_repository.save_workflow_node_execution_task") + def test_save_queues_celery_task( + self, mock_task, mock_session_factory, mock_account, sample_workflow_node_execution + ): + """Test that save operation queues a Celery task.""" + mock_result = Mock(spec=AsyncResult) + mock_task.delay.return_value = mock_result + + repo = CeleryWorkflowNodeExecutionRepository( + session_factory=mock_session_factory, + user=mock_account, + app_id="test-app", + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + ) + + repo.save(sample_workflow_node_execution) + + # Verify Celery task was queued with correct parameters + mock_task.delay.assert_called_once() + call_args = mock_task.delay.call_args[1] + + assert call_args["execution_data"] == sample_workflow_node_execution.model_dump() + assert call_args["tenant_id"] == mock_account.current_tenant_id + assert call_args["app_id"] == "test-app" + assert call_args["triggered_from"] == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value + assert call_args["creator_user_id"] == mock_account.id + + # Verify task result is stored for tracking + assert sample_workflow_node_execution.id in repo._pending_saves + assert repo._pending_saves[sample_workflow_node_execution.id] == mock_result + + @patch("core.repositories.celery_workflow_node_execution_repository.save_workflow_node_execution_task") + def test_save_handles_celery_failure( + self, mock_task, mock_session_factory, mock_account, sample_workflow_node_execution + ): + """Test that save operation handles Celery task failures.""" + mock_task.delay.side_effect = Exception("Celery is down") + + repo = CeleryWorkflowNodeExecutionRepository( + session_factory=mock_session_factory, + user=mock_account, + app_id="test-app", + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + ) + + with pytest.raises(Exception, match="Celery is down"): + repo.save(sample_workflow_node_execution) + + + @patch( + "core.repositories.celery_workflow_node_execution_repository.get_workflow_node_executions_by_workflow_run_task" + ) + def test_get_by_workflow_run(self, mock_task, mock_session_factory, mock_account, sample_workflow_node_execution): + """Test that get_by_workflow_run retrieves all executions for a workflow run.""" + 
executions_data = [sample_workflow_node_execution.model_dump()] + mock_result = Mock(spec=AsyncResult) + mock_result.get.return_value = executions_data + mock_task.delay.return_value = mock_result + + repo = CeleryWorkflowNodeExecutionRepository( + session_factory=mock_session_factory, + user=mock_account, + app_id="test-app", + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + ) + + workflow_run_id = sample_workflow_node_execution.workflow_execution_id + order_config = OrderConfig(order_by=["index"], order_direction="asc") + + result = repo.get_by_workflow_run(workflow_run_id, order_config) + + # Verify Celery task was queued with correct parameters + mock_task.delay.assert_called_once_with( + workflow_run_id=workflow_run_id, + tenant_id=mock_account.current_tenant_id, + app_id="test-app", + order_config={"order_by": order_config.order_by, "order_direction": order_config.order_direction}, + ) + + # Verify results were properly deserialized + assert len(result) == 1 + assert result[0].id == sample_workflow_node_execution.id + + @patch( + "core.repositories.celery_workflow_node_execution_repository.get_workflow_node_executions_by_workflow_run_task" + ) + def test_get_by_workflow_run_without_order_config(self, mock_task, mock_session_factory, mock_account): + """Test get_by_workflow_run without order configuration.""" + mock_result = Mock(spec=AsyncResult) + mock_result.get.return_value = [] + mock_task.delay.return_value = mock_result + + repo = CeleryWorkflowNodeExecutionRepository( + session_factory=mock_session_factory, + user=mock_account, + app_id="test-app", + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + ) + + result = repo.get_by_workflow_run("workflow-run-id") + + # Verify order_config was passed as None + call_args = mock_task.delay.call_args[1] + assert call_args["order_config"] is None + + assert len(result) == 0 + + + def test_wait_for_pending_saves(self, mock_session_factory, mock_account): + """Test waiting for all pending save operations.""" + repo = CeleryWorkflowNodeExecutionRepository( + session_factory=mock_session_factory, + user=mock_account, + app_id="test-app", + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + ) + + # Add some mock pending saves + mock_result1 = Mock(spec=AsyncResult) + mock_result1.ready.return_value = False + mock_result2 = Mock(spec=AsyncResult) + mock_result2.ready.return_value = True + + repo._pending_saves["exec1"] = mock_result1 + repo._pending_saves["exec2"] = mock_result2 + + repo.wait_for_pending_saves(timeout=10) + + # Verify that non-ready task was waited for + mock_result1.get.assert_called_once_with(timeout=10) + + # Verify pending saves were cleared + assert len(repo._pending_saves) == 0 + + def test_get_pending_save_count(self, mock_session_factory, mock_account): + """Test getting the count of pending save operations.""" + repo = CeleryWorkflowNodeExecutionRepository( + session_factory=mock_session_factory, + user=mock_account, + app_id="test-app", + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + ) + + # Add some mock pending saves + mock_result1 = Mock(spec=AsyncResult) + mock_result1.ready.return_value = False + mock_result2 = Mock(spec=AsyncResult) + mock_result2.ready.return_value = True + + repo._pending_saves["exec1"] = mock_result1 + repo._pending_saves["exec2"] = mock_result2 + + count = repo.get_pending_save_count() + + # Should clean up completed tasks and return count of remaining + assert count == 1 + assert "exec1" in repo._pending_saves + assert 
"exec2" not in repo._pending_saves + + def test_clear_pending_saves(self, mock_session_factory, mock_account): + """Test clearing all pending save operations.""" + repo = CeleryWorkflowNodeExecutionRepository( + session_factory=mock_session_factory, + user=mock_account, + app_id="test-app", + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + ) + + # Add some mock pending saves + repo._pending_saves["exec1"] = Mock(spec=AsyncResult) + repo._pending_saves["exec2"] = Mock(spec=AsyncResult) + + repo.clear_pending_saves() + + assert len(repo._pending_saves) == 0 + diff --git a/dev/start-worker b/dev/start-worker index 7007b265e0..972ea5fea5 100755 --- a/dev/start-worker +++ b/dev/start-worker @@ -8,4 +8,4 @@ cd "$SCRIPT_DIR/.." uv --directory api run \ celery -A app.celery worker \ - -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion + -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion,workflow_storage diff --git a/docker/.env.example b/docker/.env.example index 6149f63165..262209a2ea 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -811,16 +811,26 @@ WORKFLOW_NODE_EXECUTION_STORAGE=rdbms # Repository configuration # Core workflow execution repository implementation +# Options: +# - core.repositories.sqlalchemy_workflow_execution_repository.SQLAlchemyWorkflowExecutionRepository (default) +# - core.repositories.celery_workflow_execution_repository.CeleryWorkflowExecutionRepository CORE_WORKFLOW_EXECUTION_REPOSITORY=core.repositories.sqlalchemy_workflow_execution_repository.SQLAlchemyWorkflowExecutionRepository # Core workflow node execution repository implementation +# Options: +# - core.repositories.sqlalchemy_workflow_node_execution_repository.SQLAlchemyWorkflowNodeExecutionRepository (default) +# - core.repositories.celery_workflow_node_execution_repository.CeleryWorkflowNodeExecutionRepository CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY=core.repositories.sqlalchemy_workflow_node_execution_repository.SQLAlchemyWorkflowNodeExecutionRepository +# API workflow run repository implementation +API_WORKFLOW_RUN_REPOSITORY=repositories.sqlalchemy_api_workflow_run_repository.DifyAPISQLAlchemyWorkflowRunRepository + # API workflow node execution repository implementation API_WORKFLOW_NODE_EXECUTION_REPOSITORY=repositories.sqlalchemy_api_workflow_node_execution_repository.DifyAPISQLAlchemyWorkflowNodeExecutionRepository -# API workflow run repository implementation -API_WORKFLOW_RUN_REPOSITORY=repositories.sqlalchemy_api_workflow_run_repository.DifyAPISQLAlchemyWorkflowRunRepository +# Celery repository configuration +# Timeout in seconds for Celery repository async operations +CELERY_REPOSITORY_ASYNC_TIMEOUT=30 # HTTP request node in workflow configuration HTTP_REQUEST_NODE_MAX_BINARY_SIZE=10485760 diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index 1271d6d464..ebd056af2c 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -364,6 +364,7 @@ x-shared-env: &shared-api-worker-env CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY: ${CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY:-core.repositories.sqlalchemy_workflow_node_execution_repository.SQLAlchemyWorkflowNodeExecutionRepository} API_WORKFLOW_NODE_EXECUTION_REPOSITORY: ${API_WORKFLOW_NODE_EXECUTION_REPOSITORY:-repositories.sqlalchemy_api_workflow_node_execution_repository.DifyAPISQLAlchemyWorkflowNodeExecutionRepository} API_WORKFLOW_RUN_REPOSITORY: 
${API_WORKFLOW_RUN_REPOSITORY:-repositories.sqlalchemy_api_workflow_run_repository.DifyAPISQLAlchemyWorkflowRunRepository} + CELERY_REPOSITORY_ASYNC_TIMEOUT: ${CELERY_REPOSITORY_ASYNC_TIMEOUT:-30} HTTP_REQUEST_NODE_MAX_BINARY_SIZE: ${HTTP_REQUEST_NODE_MAX_BINARY_SIZE:-10485760} HTTP_REQUEST_NODE_MAX_TEXT_SIZE: ${HTTP_REQUEST_NODE_MAX_TEXT_SIZE:-1048576} HTTP_REQUEST_NODE_SSL_VERIFY: ${HTTP_REQUEST_NODE_SSL_VERIFY:-True}