add celery based exec repos

pull/20050/merge^2
liangxin 7 months ago
parent 74981a65c6
commit 2f7dc7a58a

@@ -552,12 +552,16 @@ class RepositoryConfig(BaseSettings):
"""
CORE_WORKFLOW_EXECUTION_REPOSITORY: str = Field(
description="Repository implementation for WorkflowExecution. Specify as a module path",
description="Repository implementation for WorkflowExecution. Options: "
"'core.repositories.sqlalchemy_workflow_execution_repository.SQLAlchemyWorkflowExecutionRepository' (default), "
"'core.repositories.celery_workflow_execution_repository.CeleryWorkflowExecutionRepository'",
default="core.repositories.sqlalchemy_workflow_execution_repository.SQLAlchemyWorkflowExecutionRepository",
)
CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY: str = Field(
description="Repository implementation for WorkflowNodeExecution. Specify as a module path",
description="Repository implementation for WorkflowNodeExecution. Options: "
"'core.repositories.sqlalchemy_workflow_node_execution_repository.SQLAlchemyWorkflowNodeExecutionRepository' (default), "
"'core.repositories.celery_workflow_node_execution_repository.CeleryWorkflowNodeExecutionRepository'",
default="core.repositories.sqlalchemy_workflow_node_execution_repository.SQLAlchemyWorkflowNodeExecutionRepository",
)
@@ -572,6 +576,12 @@ class RepositoryConfig(BaseSettings):
default="repositories.sqlalchemy_api_workflow_run_repository.DifyAPISQLAlchemyWorkflowRunRepository",
)
+# Celery repository configuration
+CELERY_REPOSITORY_ASYNC_TIMEOUT: int = Field(
+description="Timeout in seconds for Celery repository async operations",
+default=30,
+)
class AuthConfig(BaseSettings):
"""

@@ -5,10 +5,14 @@ This package contains concrete implementations of the repository interfaces
defined in the core.workflow.repository package.
"""
+from core.repositories.celery_workflow_execution_repository import CeleryWorkflowExecutionRepository
+from core.repositories.celery_workflow_node_execution_repository import CeleryWorkflowNodeExecutionRepository
from core.repositories.factory import DifyCoreRepositoryFactory, RepositoryImportError
from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository
__all__ = [
"CeleryWorkflowExecutionRepository",
"CeleryWorkflowNodeExecutionRepository",
"DifyCoreRepositoryFactory",
"RepositoryImportError",
"SQLAlchemyWorkflowNodeExecutionRepository",

@@ -0,0 +1,180 @@
"""
Celery-based implementation of the WorkflowExecutionRepository.
This implementation uses Celery tasks for asynchronous storage operations,
providing improved performance by offloading database operations to background workers.
"""
import logging
from typing import Optional, Union
from celery.result import AsyncResult
from sqlalchemy.engine import Engine
from sqlalchemy.orm import sessionmaker
from core.workflow.entities.workflow_execution import WorkflowExecution
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from libs.helper import extract_tenant_id
from models import Account, CreatorUserRole, EndUser
from models.enums import WorkflowRunTriggeredFrom
from tasks.workflow_execution_tasks import (
save_workflow_execution_task,
)
logger = logging.getLogger(__name__)
class CeleryWorkflowExecutionRepository(WorkflowExecutionRepository):
"""
Celery-based implementation of the WorkflowExecutionRepository interface.
This implementation provides asynchronous storage capabilities by using Celery tasks
to handle database operations in background workers. This improves performance by
reducing the blocking time for workflow execution storage operations.
Key features:
- Asynchronous save operations using Celery tasks
- Fallback to synchronous operations for read operations when immediate results are needed
- Support for multi-tenancy through tenant/app filtering
- Automatic retry and error handling through Celery
- Configurable timeouts for async operations
"""
def __init__(
self,
session_factory: sessionmaker | Engine,
user: Union[Account, EndUser],
app_id: Optional[str],
triggered_from: Optional[WorkflowRunTriggeredFrom],
async_timeout: int = 30,
):
"""
Initialize the repository with Celery task configuration and context information.
Args:
session_factory: SQLAlchemy sessionmaker or engine for fallback operations
user: Account or EndUser object containing tenant_id, user ID, and role information
app_id: App ID for filtering by application (can be None)
triggered_from: Source of the execution trigger (DEBUGGING or APP_RUN)
async_timeout: Timeout in seconds for async operations (default: 30)
"""
# Store session factory for fallback operations
if isinstance(session_factory, Engine):
self._session_factory = sessionmaker(bind=session_factory, expire_on_commit=False)
elif isinstance(session_factory, sessionmaker):
self._session_factory = session_factory
else:
raise ValueError(
f"Invalid session_factory type {type(session_factory).__name__}; expected sessionmaker or Engine"
)
# Extract tenant_id from user
tenant_id = extract_tenant_id(user)
if not tenant_id:
raise ValueError("User must have a tenant_id or current_tenant_id")
self._tenant_id = tenant_id
# Store app context
self._app_id = app_id
# Extract user context
self._triggered_from = triggered_from
self._creator_user_id = user.id
# Determine user role based on user type
self._creator_user_role = CreatorUserRole.ACCOUNT if isinstance(user, Account) else CreatorUserRole.END_USER
# Async operation timeout
self._async_timeout = async_timeout
# Cache for pending async operations
self._pending_saves: dict[str, AsyncResult] = {}
logger.info(
f"Initialized CeleryWorkflowExecutionRepository for tenant {self._tenant_id}, "
f"app {self._app_id}, triggered_from {self._triggered_from}"
)
def save(self, execution: WorkflowExecution) -> None:
"""
Save or update a WorkflowExecution instance asynchronously using Celery.
This method queues the save operation as a Celery task and returns immediately,
providing improved performance for high-throughput scenarios.
Args:
execution: The WorkflowExecution instance to save or update
"""
try:
# Serialize execution for Celery task
execution_data = execution.model_dump()
# Queue the save operation as a Celery task
task_result = save_workflow_execution_task.delay(
execution_data=execution_data,
tenant_id=self._tenant_id,
app_id=self._app_id or "",
triggered_from=self._triggered_from.value if self._triggered_from else "",
creator_user_id=self._creator_user_id,
creator_user_role=self._creator_user_role.value,
)
# Store the task result for potential status checking
self._pending_saves[execution.id_] = task_result
logger.debug(f"Queued async save for workflow execution: {execution.id_}")
except Exception as e:
logger.exception(f"Failed to queue save operation for execution {execution.id_}: {e}")
# In case of Celery failure, we could implement a fallback to synchronous save
# For now, we'll re-raise the exception
raise
def wait_for_pending_saves(self, timeout: Optional[int] = None) -> None:
"""
Wait for all pending save operations to complete.
This method is useful for ensuring data consistency when immediate
persistence is required (e.g., during testing or critical operations).
Args:
timeout: Maximum time to wait for all operations (uses instance timeout if None)
"""
wait_timeout = timeout or self._async_timeout
for execution_id, task_result in list(self._pending_saves.items()):
try:
if not task_result.ready():
logger.debug(f"Waiting for save operation to complete: {execution_id}")
task_result.get(timeout=wait_timeout)
# Remove completed task
del self._pending_saves[execution_id]
except Exception as e:
logger.exception(f"Failed to wait for save operation {execution_id}: {e}")
def get_pending_save_count(self) -> int:
"""
Get the number of pending save operations.
Returns:
Number of save operations still in progress
"""
# Clean up completed tasks
completed_ids = []
for execution_id, task_result in self._pending_saves.items():
if task_result.ready():
completed_ids.append(execution_id)
for execution_id in completed_ids:
del self._pending_saves[execution_id]
return len(self._pending_saves)
def clear_pending_saves(self) -> None:
"""
Clear all pending save operations without waiting for completion.
This method is useful for cleanup operations or when canceling workflows.
"""
self._pending_saves.clear()
logger.debug("Cleared all pending save operations")

@@ -0,0 +1,275 @@
"""
Celery-based implementation of the WorkflowNodeExecutionRepository.
This implementation uses Celery tasks for asynchronous storage operations,
providing improved performance by offloading database operations to background workers.
"""
import logging
from collections.abc import Sequence
from typing import Optional, Union
from celery.result import AsyncResult
from sqlalchemy.engine import Engine
from sqlalchemy.orm import sessionmaker
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecution
from core.workflow.repositories.workflow_node_execution_repository import (
OrderConfig,
WorkflowNodeExecutionRepository,
)
from libs.helper import extract_tenant_id
from models import Account, CreatorUserRole, EndUser
from models.workflow import WorkflowNodeExecutionTriggeredFrom
from tasks.workflow_node_execution_tasks import (
get_workflow_node_executions_by_workflow_run_task,
save_workflow_node_execution_task,
)
logger = logging.getLogger(__name__)
class CeleryWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository):
"""
Celery-based implementation of the WorkflowNodeExecutionRepository interface.
This implementation provides asynchronous storage capabilities by using Celery tasks
to handle database operations in background workers. This improves performance by
reducing the blocking time for workflow node execution storage operations.
Key features:
- Asynchronous save operations using Celery tasks
- Fallback to synchronous operations for read operations when immediate results are needed
- Support for multi-tenancy through tenant/app filtering
- Automatic retry and error handling through Celery
- Configurable timeouts for async operations
- Batch operations for improved efficiency
"""
def __init__(
self,
session_factory: sessionmaker | Engine,
user: Union[Account, EndUser],
app_id: Optional[str],
triggered_from: Optional[WorkflowNodeExecutionTriggeredFrom],
async_timeout: int = 30,
):
"""
Initialize the repository with Celery task configuration and context information.
Args:
session_factory: SQLAlchemy sessionmaker or engine for fallback operations
user: Account or EndUser object containing tenant_id, user ID, and role information
app_id: App ID for filtering by application (can be None)
triggered_from: Source of the execution trigger (SINGLE_STEP or WORKFLOW_RUN)
async_timeout: Timeout in seconds for async operations (default: 30)
"""
# Store session factory for fallback operations
if isinstance(session_factory, Engine):
self._session_factory = sessionmaker(bind=session_factory, expire_on_commit=False)
elif isinstance(session_factory, sessionmaker):
self._session_factory = session_factory
else:
raise ValueError(
f"Invalid session_factory type {type(session_factory).__name__}; expected sessionmaker or Engine"
)
# Extract tenant_id from user
tenant_id = extract_tenant_id(user)
if not tenant_id:
raise ValueError("User must have a tenant_id or current_tenant_id")
self._tenant_id = tenant_id
# Store app context
self._app_id = app_id
# Extract user context
self._triggered_from = triggered_from
self._creator_user_id = user.id
# Determine user role based on user type
self._creator_user_role = CreatorUserRole.ACCOUNT if isinstance(user, Account) else CreatorUserRole.END_USER
# Async operation timeout
self._async_timeout = async_timeout
# Cache for pending async operations
self._pending_saves: dict[str, AsyncResult] = {}
# Cache for mapping execution IDs to workflow_execution_ids for efficient workflow-specific waiting
self._workflow_execution_mapping: dict[str, str] = {}
logger.info(
f"Initialized CeleryWorkflowNodeExecutionRepository for tenant {self._tenant_id}, "
f"app {self._app_id}, triggered_from {self._triggered_from}"
)
def save(self, execution: WorkflowNodeExecution) -> None:
"""
Save or update a WorkflowNodeExecution instance asynchronously using Celery.
This method queues the save operation as a Celery task and returns immediately,
providing improved performance for high-throughput scenarios.
Args:
execution: The WorkflowNodeExecution instance to save or update
"""
try:
# Serialize execution for Celery task
execution_data = execution.model_dump()
# Queue the save operation as a Celery task
task_result = save_workflow_node_execution_task.delay(
execution_data=execution_data,
tenant_id=self._tenant_id,
app_id=self._app_id or "",
triggered_from=self._triggered_from.value if self._triggered_from else "",
creator_user_id=self._creator_user_id,
creator_user_role=self._creator_user_role.value,
)
# Store the task result for potential status checking
self._pending_saves[execution.id] = task_result
# Cache the workflow_execution_id mapping for efficient workflow-specific waiting
self._workflow_execution_mapping[execution.id] = execution.workflow_execution_id
logger.debug(f"Queued async save for workflow node execution: {execution.id}")
except Exception as e:
logger.exception(f"Failed to queue save operation for node execution {execution.id}: {e}")
# In case of Celery failure, we could implement a fallback to synchronous save
# For now, we'll re-raise the exception
raise
def get_by_workflow_run(
self,
workflow_run_id: str,
order_config: Optional[OrderConfig] = None,
) -> Sequence[WorkflowNodeExecution]:
"""
Retrieve all WorkflowNodeExecution instances for a specific workflow run.
Args:
workflow_run_id: The workflow run ID
order_config: Optional configuration for ordering results
Returns:
A sequence of WorkflowNodeExecution instances
"""
try:
# Wait for any pending saves that might affect this workflow run
self._wait_for_pending_saves_by_workflow_run(workflow_run_id)
# Serialize order config for Celery task
if order_config:
order_config_data = {"order_by": order_config.order_by, "order_direction": order_config.order_direction}
else:
order_config_data = None
# Queue the get operation as a Celery task
task_result = get_workflow_node_executions_by_workflow_run_task.delay(
workflow_run_id=workflow_run_id,
tenant_id=self._tenant_id,
app_id=self._app_id or "",
order_config=order_config_data,
)
# Wait for the result (synchronous for read operations)
executions_data = task_result.get(timeout=self._async_timeout)
result = []
for execution_data in executions_data:
execution = WorkflowNodeExecution.model_validate(execution_data)
result.append(execution)
return result
except Exception as e:
logger.exception(f"Failed to get workflow node executions for run {workflow_run_id}: {e}")
# Could implement fallback to direct database access here
return []
def _wait_for_pending_saves_by_workflow_run(self, workflow_run_id: str) -> None:
"""
Wait for any pending save operations that might affect the given workflow run.
This method uses the cached workflow_execution_id mapping to wait only for
tasks that belong to the specific workflow run, improving efficiency.
Args:
workflow_run_id: The workflow run ID to check
"""
# Find execution IDs that belong to this workflow run
relevant_execution_ids = [
execution_id for execution_id, cached_workflow_id in self._workflow_execution_mapping.items()
if cached_workflow_id == workflow_run_id and execution_id in self._pending_saves
]
logger.debug(f"Found {len(relevant_execution_ids)} pending saves for workflow run {workflow_run_id}")
for execution_id in relevant_execution_ids:
task_result = self._pending_saves.get(execution_id)
if task_result and not task_result.ready():
try:
logger.debug(f"Waiting for pending save to complete before read: {execution_id}")
task_result.get(timeout=self._async_timeout)
except Exception as e:
logger.exception(f"Failed to wait for pending save {execution_id}: {e}")
# Clean up completed tasks from both caches
if task_result and task_result.ready():
self._pending_saves.pop(execution_id, None)
self._workflow_execution_mapping.pop(execution_id, None)
def wait_for_pending_saves(self, timeout: Optional[int] = None) -> None:
"""
Wait for all pending save operations to complete.
This method is useful for ensuring data consistency when immediate
persistence is required (e.g., during testing or critical operations).
Args:
timeout: Maximum time to wait for all operations (uses instance timeout if None)
"""
wait_timeout = timeout or self._async_timeout
for execution_id, task_result in list(self._pending_saves.items()):
try:
if not task_result.ready():
logger.debug(f"Waiting for save operation to complete: {execution_id}")
task_result.get(timeout=wait_timeout)
# Remove completed task from both caches
del self._pending_saves[execution_id]
self._workflow_execution_mapping.pop(execution_id, None)
except Exception as e:
logger.exception(f"Failed to wait for save operation {execution_id}: {e}")
def get_pending_save_count(self) -> int:
"""
Get the number of pending save operations.
Returns:
Number of save operations still in progress
"""
# Clean up completed tasks
completed_ids = []
for execution_id, task_result in self._pending_saves.items():
if task_result.ready():
completed_ids.append(execution_id)
for execution_id in completed_ids:
del self._pending_saves[execution_id]
self._workflow_execution_mapping.pop(execution_id, None)
return len(self._pending_saves)
def clear_pending_saves(self) -> None:
"""
Clear all pending save operations without waiting for completion.
This method is useful for cleanup operations or when canceling workflows.
"""
self._pending_saves.clear()
self._workflow_execution_mapping.clear()
logger.debug("Cleared all pending save operations and workflow execution mappings")

@@ -158,16 +158,30 @@ class DifyCoreRepositoryFactory:
try:
repository_class = cls._import_class(class_path)
cls._validate_repository_interface(repository_class, WorkflowExecutionRepository)
-cls._validate_constructor_signature(
-repository_class, ["session_factory", "user", "app_id", "triggered_from"]
-)
-return repository_class( # type: ignore[no-any-return]
-session_factory=session_factory,
-user=user,
-app_id=app_id,
-triggered_from=triggered_from,
-)
+# Check if this is a Celery repository that needs async_timeout
+is_celery_repo = "celery" in class_path.lower()
+if is_celery_repo:
+cls._validate_constructor_signature(
+repository_class, ["session_factory", "user", "app_id", "triggered_from", "async_timeout"]
+)
+return repository_class( # type: ignore[no-any-return]
+session_factory=session_factory,
+user=user,
+app_id=app_id,
+triggered_from=triggered_from,
+async_timeout=dify_config.CELERY_REPOSITORY_ASYNC_TIMEOUT,
+)
+else:
+cls._validate_constructor_signature(
+repository_class, ["session_factory", "user", "app_id", "triggered_from"]
+)
+return repository_class( # type: ignore[no-any-return]
+session_factory=session_factory,
+user=user,
+app_id=app_id,
+triggered_from=triggered_from,
+)
except RepositoryImportError:
# Re-raise our custom errors as-is
raise
@@ -204,16 +218,30 @@ class DifyCoreRepositoryFactory:
try:
repository_class = cls._import_class(class_path)
cls._validate_repository_interface(repository_class, WorkflowNodeExecutionRepository)
-cls._validate_constructor_signature(
-repository_class, ["session_factory", "user", "app_id", "triggered_from"]
-)
-return repository_class( # type: ignore[no-any-return]
-session_factory=session_factory,
-user=user,
-app_id=app_id,
-triggered_from=triggered_from,
-)
+# Check if this is a Celery repository that needs async_timeout
+is_celery_repo = "celery" in class_path.lower()
+if is_celery_repo:
+cls._validate_constructor_signature(
+repository_class, ["session_factory", "user", "app_id", "triggered_from", "async_timeout"]
+)
+return repository_class( # type: ignore[no-any-return]
+session_factory=session_factory,
+user=user,
+app_id=app_id,
+triggered_from=triggered_from,
+async_timeout=dify_config.CELERY_REPOSITORY_ASYNC_TIMEOUT,
+)
+else:
+cls._validate_constructor_signature(
+repository_class, ["session_factory", "user", "app_id", "triggered_from"]
+)
+return repository_class( # type: ignore[no-any-return]
+session_factory=session_factory,
+user=user,
+app_id=app_id,
+triggered_from=triggered_from,
+)
except RepositoryImportError:
# Re-raise our custom errors as-is
raise
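
A note on the detection above: substring matching means any repository whose module path happens to contain "celery" would also be handed async_timeout. A stricter, signature-based check could look like the following hypothetical sketch (not what the PR implements):

import inspect

def accepts_async_timeout(repository_class: type) -> bool:
    # True only when the constructor explicitly declares an async_timeout parameter
    return "async_timeout" in inspect.signature(repository_class.__init__).parameters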

@@ -22,7 +22,7 @@ if [[ "${MODE}" == "worker" ]]; then
exec celery -A app.celery worker -P ${CELERY_WORKER_CLASS:-gevent} $CONCURRENCY_OPTION \
--max-tasks-per-child ${MAX_TASK_PRE_CHILD:-50} --loglevel ${LOG_LEVEL:-INFO} \
--Q ${CELERY_QUEUES:-dataset,mail,ops_trace,app_deletion}
+-Q ${CELERY_QUEUES:-dataset,mail,ops_trace,app_deletion,workflow_storage}
elif [[ "${MODE}" == "beat" ]]; then
exec celery -A app.celery beat --loglevel ${LOG_LEVEL:-INFO}

@@ -72,6 +72,8 @@ def init_app(app: DifyApp) -> Celery:
"schedule.clean_messages",
"schedule.mail_clean_document_notify_task",
"schedule.queue_monitor_task",
"tasks.workflow_execution_tasks",
"tasks.workflow_node_execution_tasks",
]
day = dify_config.CELERY_BEAT_SCHEDULER_TIME
beat_schedule = {

@@ -0,0 +1,157 @@
"""
Celery tasks for asynchronous workflow execution storage operations.
These tasks provide asynchronous storage capabilities for workflow execution data,
improving performance by offloading storage operations to background workers.
"""
import json
import logging
from celery import shared_task
from sqlalchemy import select
from sqlalchemy.orm import sessionmaker
from core.workflow.entities.workflow_execution import WorkflowExecution, WorkflowExecutionStatus, WorkflowType
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
from extensions.ext_database import db
from models import CreatorUserRole, WorkflowRun
from models.enums import WorkflowRunTriggeredFrom
logger = logging.getLogger(__name__)
@shared_task(queue="workflow_storage", bind=True, max_retries=3, default_retry_delay=60)
def save_workflow_execution_task(
self,
execution_data: dict,
tenant_id: str,
app_id: str,
triggered_from: str,
creator_user_id: str,
creator_user_role: str,
) -> bool:
"""
Asynchronously save or update a workflow execution to the database.
Args:
execution_data: Serialized WorkflowExecution data
tenant_id: Tenant ID for multi-tenancy
app_id: Application ID
triggered_from: Source of the execution trigger
creator_user_id: ID of the user who created the execution
creator_user_role: Role of the user who created the execution
Returns:
True if successful, False otherwise
"""
try:
# Create a new session for this task
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
with session_factory() as session:
# Deserialize execution data
execution = WorkflowExecution.model_validate(execution_data)
# Check if workflow run already exists
existing_run = session.scalar(select(WorkflowRun).where(WorkflowRun.id == execution.id_))
if existing_run:
# Update existing workflow run
_update_workflow_run_from_execution(existing_run, execution)
logger.debug(f"Updated existing workflow run: {execution.id_}")
else:
# Create new workflow run
workflow_run = _create_workflow_run_from_execution(
execution=execution,
tenant_id=tenant_id,
app_id=app_id,
triggered_from=WorkflowRunTriggeredFrom(triggered_from),
creator_user_id=creator_user_id,
creator_user_role=CreatorUserRole(creator_user_role),
)
session.add(workflow_run)
logger.debug(f"Created new workflow run: {execution.id_}")
session.commit()
return True
except Exception as e:
logger.exception(f"Failed to save workflow execution {execution_data.get('id_', 'unknown')}: {e}")
# Retry the task with exponential backoff
raise self.retry(exc=e, countdown=60 * (2**self.request.retries))
def _create_workflow_run_from_execution(
execution: WorkflowExecution,
tenant_id: str,
app_id: str,
triggered_from: WorkflowRunTriggeredFrom,
creator_user_id: str,
creator_user_role: CreatorUserRole,
) -> WorkflowRun:
"""
Create a WorkflowRun database model from a WorkflowExecution domain entity.
"""
workflow_run = WorkflowRun()
workflow_run.id = execution.id_
workflow_run.tenant_id = tenant_id
workflow_run.app_id = app_id
workflow_run.workflow_id = execution.workflow_id
workflow_run.type = execution.workflow_type.value
workflow_run.triggered_from = triggered_from.value
workflow_run.version = execution.workflow_version
json_converter = WorkflowRuntimeTypeConverter()
workflow_run.graph = json.dumps(json_converter.to_json_encodable(execution.graph))
workflow_run.inputs = json.dumps(json_converter.to_json_encodable(execution.inputs))
workflow_run.status = execution.status.value
workflow_run.outputs = (
json.dumps(json_converter.to_json_encodable(execution.outputs)) if execution.outputs else "{}"
)
workflow_run.error = execution.error_message
workflow_run.elapsed_time = execution.elapsed_time
workflow_run.total_tokens = execution.total_tokens
workflow_run.total_steps = execution.total_steps
workflow_run.created_by_role = creator_user_role.value
workflow_run.created_by = creator_user_id
workflow_run.created_at = execution.started_at
workflow_run.finished_at = execution.finished_at
return workflow_run
def _update_workflow_run_from_execution(workflow_run: WorkflowRun, execution: WorkflowExecution) -> None:
"""
Update a WorkflowRun database model from a WorkflowExecution domain entity.
"""
json_converter = WorkflowRuntimeTypeConverter()
workflow_run.status = execution.status.value
workflow_run.outputs = (
json.dumps(json_converter.to_json_encodable(execution.outputs)) if execution.outputs else "{}"
)
workflow_run.error = execution.error_message
workflow_run.elapsed_time = execution.elapsed_time
workflow_run.total_tokens = execution.total_tokens
workflow_run.total_steps = execution.total_steps
workflow_run.finished_at = execution.finished_at
def _create_execution_from_workflow_run(workflow_run: WorkflowRun) -> WorkflowExecution:
"""
Create a WorkflowExecution domain entity from a WorkflowRun database model.
"""
return WorkflowExecution(
id_=workflow_run.id,
workflow_id=workflow_run.workflow_id,
workflow_type=WorkflowType(workflow_run.type),
workflow_version=workflow_run.version,
graph=json.loads(workflow_run.graph or "{}"),
inputs=json.loads(workflow_run.inputs or "{}"),
outputs=json.loads(workflow_run.outputs or "{}"),
status=WorkflowExecutionStatus(workflow_run.status),
error_message=workflow_run.error or "",
total_tokens=workflow_run.total_tokens,
total_steps=workflow_run.total_steps,
started_at=workflow_run.created_at,
finished_at=workflow_run.finished_at,
)
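
Two practical notes on save_workflow_execution_task. With max_retries=3 and countdown=60 * (2**retries), a persistently failing save is retried roughly 60, 120 and 240 seconds after the successive failures before Celery surfaces the error. And while the repository normally enqueues it, the task can also be exercised directly; a sketch with placeholder IDs, assuming a running broker, a result backend and a worker on the workflow_storage queue:

from models import CreatorUserRole
from models.enums import WorkflowRunTriggeredFrom
from tasks.workflow_execution_tasks import save_workflow_execution_task

result = save_workflow_execution_task.delay(
    execution_data=execution.model_dump(),  # placeholder WorkflowExecution
    tenant_id="tenant-id",
    app_id="app-id",
    triggered_from=WorkflowRunTriggeredFrom.APP_RUN.value,
    creator_user_id="user-id",
    creator_user_role=CreatorUserRole.ACCOUNT.value,
)
print(result.status)  # e.g. PENDING, STARTED, SUCCESS
assert result.get(timeout=30) is True  # the task returns True once the row is committed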

@@ -0,0 +1,277 @@
"""
Celery tasks for asynchronous workflow node execution storage operations.
These tasks provide asynchronous storage capabilities for workflow node execution data,
improving performance by offloading storage operations to background workers.
"""
import json
import logging
from typing import Optional
from celery import shared_task
from sqlalchemy import select
from sqlalchemy.orm import sessionmaker
from core.workflow.entities.workflow_node_execution import (
WorkflowNodeExecution,
WorkflowNodeExecutionMetadataKey,
WorkflowNodeExecutionStatus,
)
from core.workflow.nodes.enums import NodeType
from core.workflow.repositories.workflow_node_execution_repository import OrderConfig
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
from extensions.ext_database import db
from models import CreatorUserRole, WorkflowNodeExecutionModel
from models.workflow import WorkflowNodeExecutionTriggeredFrom
logger = logging.getLogger(__name__)
@shared_task(queue="workflow_storage", bind=True, max_retries=3, default_retry_delay=60)
def save_workflow_node_execution_task(
self,
execution_data: dict,
tenant_id: str,
app_id: str,
triggered_from: str,
creator_user_id: str,
creator_user_role: str,
) -> bool:
"""
Asynchronously save or update a workflow node execution to the database.
Args:
execution_data: Serialized WorkflowNodeExecution data
tenant_id: Tenant ID for multi-tenancy
app_id: Application ID
triggered_from: Source of the execution trigger
creator_user_id: ID of the user who created the execution
creator_user_role: Role of the user who created the execution
Returns:
True if successful, False otherwise
"""
try:
# Create a new session for this task
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
with session_factory() as session:
# Deserialize execution data
execution = WorkflowNodeExecution.model_validate(execution_data)
# Check if node execution already exists
existing_execution = session.scalar(
select(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.id == execution.id)
)
if existing_execution:
# Update existing node execution
_update_node_execution_from_domain(existing_execution, execution)
logger.debug(f"Updated existing workflow node execution: {execution.id}")
else:
# Create new node execution
node_execution = _create_node_execution_from_domain(
execution=execution,
tenant_id=tenant_id,
app_id=app_id,
triggered_from=WorkflowNodeExecutionTriggeredFrom(triggered_from),
creator_user_id=creator_user_id,
creator_user_role=CreatorUserRole(creator_user_role),
)
session.add(node_execution)
logger.debug(f"Created new workflow node execution: {execution.id}")
session.commit()
return True
except Exception as e:
logger.exception(f"Failed to save workflow node execution {execution_data.get('id', 'unknown')}: {e}")
# Retry the task with exponential backoff
raise self.retry(exc=e, countdown=60 * (2**self.request.retries))
@shared_task(queue="workflow_storage", bind=True, max_retries=3, default_retry_delay=60)
def get_workflow_node_executions_by_workflow_run_task(
self,
workflow_run_id: str,
tenant_id: str,
app_id: str,
order_config: Optional[dict] = None,
) -> list[dict]:
"""
Asynchronously retrieve all workflow node executions for a specific workflow run.
Args:
workflow_run_id: The workflow run ID
tenant_id: Tenant ID for multi-tenancy
app_id: Application ID
order_config: Optional ordering configuration
Returns:
List of serialized WorkflowNodeExecution data
"""
try:
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
with session_factory() as session:
# Build base query
query = select(WorkflowNodeExecutionModel).where(
WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id,
WorkflowNodeExecutionModel.tenant_id == tenant_id,
WorkflowNodeExecutionModel.app_id == app_id,
)
# Apply ordering if specified
if order_config:
order_obj = OrderConfig(
order_by=order_config["order_by"], order_direction=order_config.get("order_direction")
)
for field_name in order_obj.order_by:
field = getattr(WorkflowNodeExecutionModel, field_name, None)
if field is not None:
if order_obj.order_direction == "desc":
query = query.order_by(field.desc())
else:
query = query.order_by(field.asc())
node_executions = session.scalars(query).all()
result = []
for node_execution in node_executions:
execution = _create_domain_from_node_execution(node_execution)
result.append(execution.model_dump())
return result
except Exception as e:
logger.exception(f"Failed to get workflow node executions for run {workflow_run_id}: {e}")
# Retry the task with exponential backoff
raise self.retry(exc=e, countdown=60 * (2**self.request.retries))
def _create_node_execution_from_domain(
execution: WorkflowNodeExecution,
tenant_id: str,
app_id: str,
triggered_from: WorkflowNodeExecutionTriggeredFrom,
creator_user_id: str,
creator_user_role: CreatorUserRole,
) -> WorkflowNodeExecutionModel:
"""
Create a WorkflowNodeExecutionModel database model from a WorkflowNodeExecution domain entity.
"""
node_execution = WorkflowNodeExecutionModel()
node_execution.id = execution.id
node_execution.tenant_id = tenant_id
node_execution.app_id = app_id
node_execution.workflow_id = execution.workflow_id
node_execution.triggered_from = triggered_from.value
node_execution.workflow_run_id = execution.workflow_execution_id
node_execution.index = execution.index
node_execution.predecessor_node_id = execution.predecessor_node_id
node_execution.node_id = execution.node_id
node_execution.node_type = execution.node_type.value
node_execution.title = execution.title
node_execution.node_execution_id = execution.node_execution_id
# Serialize complex data as JSON
json_converter = WorkflowRuntimeTypeConverter()
node_execution.inputs = json.dumps(json_converter.to_json_encodable(execution.inputs)) if execution.inputs else "{}"
node_execution.process_data = (
json.dumps(json_converter.to_json_encodable(execution.process_data)) if execution.process_data else "{}"
)
node_execution.outputs = (
json.dumps(json_converter.to_json_encodable(execution.outputs)) if execution.outputs else "{}"
)
# Convert metadata enum keys to strings for JSON serialization
if execution.metadata:
metadata_for_json = {
key.value if hasattr(key, "value") else str(key): value for key, value in execution.metadata.items()
}
node_execution.execution_metadata = json.dumps(json_converter.to_json_encodable(metadata_for_json))
else:
node_execution.execution_metadata = "{}"
node_execution.status = execution.status.value
node_execution.error = execution.error
node_execution.elapsed_time = execution.elapsed_time
node_execution.created_by_role = creator_user_role.value
node_execution.created_by = creator_user_id
node_execution.created_at = execution.created_at
node_execution.finished_at = execution.finished_at
return node_execution
def _update_node_execution_from_domain(
node_execution: WorkflowNodeExecutionModel, execution: WorkflowNodeExecution
) -> None:
"""
Update a WorkflowNodeExecutionModel database model from a WorkflowNodeExecution domain entity.
"""
# Update serialized data
json_converter = WorkflowRuntimeTypeConverter()
node_execution.inputs = json.dumps(json_converter.to_json_encodable(execution.inputs)) if execution.inputs else "{}"
node_execution.process_data = (
json.dumps(json_converter.to_json_encodable(execution.process_data)) if execution.process_data else "{}"
)
node_execution.outputs = (
json.dumps(json_converter.to_json_encodable(execution.outputs)) if execution.outputs else "{}"
)
# Convert metadata enum keys to strings for JSON serialization
if execution.metadata:
metadata_for_json = {
key.value if hasattr(key, "value") else str(key): value for key, value in execution.metadata.items()
}
node_execution.execution_metadata = json.dumps(json_converter.to_json_encodable(metadata_for_json))
else:
node_execution.execution_metadata = "{}"
# Update other fields
node_execution.status = execution.status.value
node_execution.error = execution.error
node_execution.elapsed_time = execution.elapsed_time
node_execution.finished_at = execution.finished_at
def _create_domain_from_node_execution(node_execution: WorkflowNodeExecutionModel) -> WorkflowNodeExecution:
"""
Create a WorkflowNodeExecution domain entity from a WorkflowNodeExecutionModel database model.
"""
# Deserialize JSON data
inputs = json.loads(node_execution.inputs or "{}")
process_data = json.loads(node_execution.process_data or "{}")
outputs = json.loads(node_execution.outputs or "{}")
metadata = json.loads(node_execution.execution_metadata or "{}")
# Convert metadata keys to enum values
typed_metadata = {}
for key, value in metadata.items():
try:
enum_key = WorkflowNodeExecutionMetadataKey(key)
typed_metadata[enum_key] = value
except ValueError:
# Skip unknown metadata keys
continue
return WorkflowNodeExecution(
id=node_execution.id,
node_execution_id=node_execution.node_execution_id,
workflow_id=node_execution.workflow_id,
workflow_execution_id=node_execution.workflow_run_id,
index=node_execution.index,
predecessor_node_id=node_execution.predecessor_node_id,
node_id=node_execution.node_id,
node_type=NodeType(node_execution.node_type),
title=node_execution.title,
inputs=inputs if inputs else None,
process_data=process_data if process_data else None,
outputs=outputs if outputs else None,
status=WorkflowNodeExecutionStatus(node_execution.status),
error=node_execution.error,
elapsed_time=node_execution.elapsed_time,
metadata=typed_metadata if typed_metadata else None,
created_at=node_execution.created_at,
finished_at=node_execution.finished_at,
)
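
The metadata handling above round-trips enum keys through plain strings: keys are serialized with key.value on write, re-validated with WorkflowNodeExecutionMetadataKey(key) on read, and keys the enum no longer knows are silently dropped. A small self-contained sketch of that behaviour (the enum member is taken generically rather than by name):

import json

from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey

known_key = next(iter(WorkflowNodeExecutionMetadataKey))  # any current member of the enum
stored = json.dumps({known_key.value: 42, "stale_key": "ignored"})

typed_metadata = {}
for key, value in json.loads(stored).items():
    try:
        typed_metadata[WorkflowNodeExecutionMetadataKey(key)] = value
    except ValueError:
        continue  # unknown keys are skipped, as in _create_domain_from_node_execution

assert known_key in typed_metadata
assert all(k.value != "stale_key" for k in typed_metadata)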

@@ -0,0 +1,237 @@
"""
Unit tests for CeleryWorkflowExecutionRepository.
These tests verify the Celery-based asynchronous storage functionality
for workflow execution data.
"""
from datetime import UTC, datetime
from unittest.mock import Mock, patch
from uuid import uuid4
import pytest
from celery.result import AsyncResult
from core.repositories.celery_workflow_execution_repository import CeleryWorkflowExecutionRepository
from core.workflow.entities.workflow_execution import WorkflowExecution, WorkflowType
from models import Account, EndUser
from models.enums import WorkflowRunTriggeredFrom
@pytest.fixture
def mock_session_factory():
"""Mock SQLAlchemy session factory."""
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
# Create a real sessionmaker with in-memory SQLite for testing
engine = create_engine("sqlite:///:memory:")
return sessionmaker(bind=engine)
@pytest.fixture
def mock_account():
"""Mock Account user."""
account = Mock(spec=Account)
account.id = str(uuid4())
account.current_tenant_id = str(uuid4())
return account
@pytest.fixture
def mock_end_user():
"""Mock EndUser."""
user = Mock(spec=EndUser)
user.id = str(uuid4())
user.tenant_id = str(uuid4())
return user
@pytest.fixture
def sample_workflow_execution():
"""Sample WorkflowExecution for testing."""
return WorkflowExecution.new(
id_=str(uuid4()),
workflow_id=str(uuid4()),
workflow_type=WorkflowType.WORKFLOW,
workflow_version="1.0",
graph={"nodes": [], "edges": []},
inputs={"input1": "value1"},
started_at=datetime.now(UTC).replace(tzinfo=None),
)
class TestCeleryWorkflowExecutionRepository:
"""Test cases for CeleryWorkflowExecutionRepository."""
def test_init_with_sessionmaker(self, mock_session_factory, mock_account):
"""Test repository initialization with sessionmaker."""
app_id = "test-app-id"
triggered_from = WorkflowRunTriggeredFrom.APP_RUN
repo = CeleryWorkflowExecutionRepository(
session_factory=mock_session_factory,
user=mock_account,
app_id=app_id,
triggered_from=triggered_from,
)
assert repo._tenant_id == mock_account.current_tenant_id
assert repo._app_id == app_id
assert repo._triggered_from == triggered_from
assert repo._creator_user_id == mock_account.id
assert repo._async_timeout == 30 # default timeout
def test_init_with_custom_timeout(self, mock_session_factory, mock_account):
"""Test repository initialization with custom timeout."""
custom_timeout = 60
repo = CeleryWorkflowExecutionRepository(
session_factory=mock_session_factory,
user=mock_account,
app_id="test-app",
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
async_timeout=custom_timeout,
)
assert repo._async_timeout == custom_timeout
def test_init_with_end_user(self, mock_session_factory, mock_end_user):
"""Test repository initialization with EndUser."""
repo = CeleryWorkflowExecutionRepository(
session_factory=mock_session_factory,
user=mock_end_user,
app_id="test-app",
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
)
assert repo._tenant_id == mock_end_user.tenant_id
def test_init_without_tenant_id_raises_error(self, mock_session_factory):
"""Test that initialization fails without tenant_id."""
user = Mock()
user.current_tenant_id = None
user.tenant_id = None
with pytest.raises(ValueError, match="User must have a tenant_id"):
CeleryWorkflowExecutionRepository(
session_factory=mock_session_factory,
user=user,
app_id="test-app",
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
)
@patch("core.repositories.celery_workflow_execution_repository.save_workflow_execution_task")
def test_save_queues_celery_task(self, mock_task, mock_session_factory, mock_account, sample_workflow_execution):
"""Test that save operation queues a Celery task."""
mock_result = Mock(spec=AsyncResult)
mock_task.delay.return_value = mock_result
repo = CeleryWorkflowExecutionRepository(
session_factory=mock_session_factory,
user=mock_account,
app_id="test-app",
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
)
repo.save(sample_workflow_execution)
# Verify Celery task was queued with correct parameters
mock_task.delay.assert_called_once()
call_args = mock_task.delay.call_args[1]
assert call_args["execution_data"] == sample_workflow_execution.model_dump()
assert call_args["tenant_id"] == mock_account.current_tenant_id
assert call_args["app_id"] == "test-app"
assert call_args["triggered_from"] == WorkflowRunTriggeredFrom.APP_RUN.value
assert call_args["creator_user_id"] == mock_account.id
# Verify task result is stored for tracking
assert sample_workflow_execution.id_ in repo._pending_saves
assert repo._pending_saves[sample_workflow_execution.id_] == mock_result
@patch("core.repositories.celery_workflow_execution_repository.save_workflow_execution_task")
def test_save_handles_celery_failure(
self, mock_task, mock_session_factory, mock_account, sample_workflow_execution
):
"""Test that save operation handles Celery task failures."""
mock_task.delay.side_effect = Exception("Celery is down")
repo = CeleryWorkflowExecutionRepository(
session_factory=mock_session_factory,
user=mock_account,
app_id="test-app",
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
)
with pytest.raises(Exception, match="Celery is down"):
repo.save(sample_workflow_execution)
def test_wait_for_pending_saves(self, mock_session_factory, mock_account, sample_workflow_execution):
"""Test waiting for all pending save operations."""
repo = CeleryWorkflowExecutionRepository(
session_factory=mock_session_factory,
user=mock_account,
app_id="test-app",
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
)
# Add some mock pending saves
mock_result1 = Mock(spec=AsyncResult)
mock_result1.ready.return_value = False
mock_result2 = Mock(spec=AsyncResult)
mock_result2.ready.return_value = True
repo._pending_saves["exec1"] = mock_result1
repo._pending_saves["exec2"] = mock_result2
repo.wait_for_pending_saves(timeout=10)
# Verify that non-ready task was waited for
mock_result1.get.assert_called_once_with(timeout=10)
# Verify pending saves were cleared
assert len(repo._pending_saves) == 0
def test_get_pending_save_count(self, mock_session_factory, mock_account):
"""Test getting the count of pending save operations."""
repo = CeleryWorkflowExecutionRepository(
session_factory=mock_session_factory,
user=mock_account,
app_id="test-app",
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
)
# Add some mock pending saves
mock_result1 = Mock(spec=AsyncResult)
mock_result1.ready.return_value = False
mock_result2 = Mock(spec=AsyncResult)
mock_result2.ready.return_value = True
repo._pending_saves["exec1"] = mock_result1
repo._pending_saves["exec2"] = mock_result2
count = repo.get_pending_save_count()
# Should clean up completed tasks and return count of remaining
assert count == 1
assert "exec1" in repo._pending_saves
assert "exec2" not in repo._pending_saves
def test_clear_pending_saves(self, mock_session_factory, mock_account):
"""Test clearing all pending save operations."""
repo = CeleryWorkflowExecutionRepository(
session_factory=mock_session_factory,
user=mock_account,
app_id="test-app",
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
)
# Add some mock pending saves
repo._pending_saves["exec1"] = Mock(spec=AsyncResult)
repo._pending_saves["exec2"] = Mock(spec=AsyncResult)
repo.clear_pending_saves()
assert len(repo._pending_saves) == 0

@@ -0,0 +1,308 @@
"""
Unit tests for CeleryWorkflowNodeExecutionRepository.
These tests verify the Celery-based asynchronous storage functionality
for workflow node execution data.
"""
from datetime import UTC, datetime
from unittest.mock import Mock, patch
from uuid import uuid4
import pytest
from celery.result import AsyncResult
from core.repositories.celery_workflow_node_execution_repository import CeleryWorkflowNodeExecutionRepository
from core.workflow.entities.workflow_node_execution import (
WorkflowNodeExecution,
WorkflowNodeExecutionStatus,
)
from core.workflow.nodes.enums import NodeType
from core.workflow.repositories.workflow_node_execution_repository import OrderConfig
from models import Account, EndUser
from models.workflow import WorkflowNodeExecutionTriggeredFrom
@pytest.fixture
def mock_session_factory():
"""Mock SQLAlchemy session factory."""
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
# Create a real sessionmaker with in-memory SQLite for testing
engine = create_engine("sqlite:///:memory:")
return sessionmaker(bind=engine)
@pytest.fixture
def mock_account():
"""Mock Account user."""
account = Mock(spec=Account)
account.id = str(uuid4())
account.current_tenant_id = str(uuid4())
return account
@pytest.fixture
def mock_end_user():
"""Mock EndUser."""
user = Mock(spec=EndUser)
user.id = str(uuid4())
user.tenant_id = str(uuid4())
return user
@pytest.fixture
def sample_workflow_node_execution():
"""Sample WorkflowNodeExecution for testing."""
return WorkflowNodeExecution(
id=str(uuid4()),
node_execution_id=str(uuid4()),
workflow_id=str(uuid4()),
workflow_execution_id=str(uuid4()),
index=1,
node_id="test_node",
node_type=NodeType.START,
title="Test Node",
inputs={"input1": "value1"},
status=WorkflowNodeExecutionStatus.RUNNING,
created_at=datetime.now(UTC).replace(tzinfo=None),
)
class TestCeleryWorkflowNodeExecutionRepository:
"""Test cases for CeleryWorkflowNodeExecutionRepository."""
def test_init_with_sessionmaker(self, mock_session_factory, mock_account):
"""Test repository initialization with sessionmaker."""
app_id = "test-app-id"
triggered_from = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN
repo = CeleryWorkflowNodeExecutionRepository(
session_factory=mock_session_factory,
user=mock_account,
app_id=app_id,
triggered_from=triggered_from,
)
assert repo._tenant_id == mock_account.current_tenant_id
assert repo._app_id == app_id
assert repo._triggered_from == triggered_from
assert repo._creator_user_id == mock_account.id
assert repo._async_timeout == 30 # default timeout
def test_init_with_custom_timeout(self, mock_session_factory, mock_account):
"""Test repository initialization with custom timeout."""
custom_timeout = 60
repo = CeleryWorkflowNodeExecutionRepository(
session_factory=mock_session_factory,
user=mock_account,
app_id="test-app",
triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
async_timeout=custom_timeout,
)
assert repo._async_timeout == custom_timeout
def test_init_with_end_user(self, mock_session_factory, mock_end_user):
"""Test repository initialization with EndUser."""
repo = CeleryWorkflowNodeExecutionRepository(
session_factory=mock_session_factory,
user=mock_end_user,
app_id="test-app",
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
assert repo._tenant_id == mock_end_user.tenant_id
def test_init_without_tenant_id_raises_error(self, mock_session_factory):
"""Test that initialization fails without tenant_id."""
user = Mock()
user.current_tenant_id = None
user.tenant_id = None
with pytest.raises(ValueError, match="User must have a tenant_id"):
CeleryWorkflowNodeExecutionRepository(
session_factory=mock_session_factory,
user=user,
app_id="test-app",
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
@patch("core.repositories.celery_workflow_node_execution_repository.save_workflow_node_execution_task")
def test_save_queues_celery_task(
self, mock_task, mock_session_factory, mock_account, sample_workflow_node_execution
):
"""Test that save operation queues a Celery task."""
mock_result = Mock(spec=AsyncResult)
mock_task.delay.return_value = mock_result
repo = CeleryWorkflowNodeExecutionRepository(
session_factory=mock_session_factory,
user=mock_account,
app_id="test-app",
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
repo.save(sample_workflow_node_execution)
# Verify Celery task was queued with correct parameters
mock_task.delay.assert_called_once()
call_args = mock_task.delay.call_args[1]
assert call_args["execution_data"] == sample_workflow_node_execution.model_dump()
assert call_args["tenant_id"] == mock_account.current_tenant_id
assert call_args["app_id"] == "test-app"
assert call_args["triggered_from"] == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value
assert call_args["creator_user_id"] == mock_account.id
# Verify task result is stored for tracking
assert sample_workflow_node_execution.id in repo._pending_saves
assert repo._pending_saves[sample_workflow_node_execution.id] == mock_result
@patch("core.repositories.celery_workflow_node_execution_repository.save_workflow_node_execution_task")
def test_save_handles_celery_failure(
self, mock_task, mock_session_factory, mock_account, sample_workflow_node_execution
):
"""Test that save operation handles Celery task failures."""
mock_task.delay.side_effect = Exception("Celery is down")
repo = CeleryWorkflowNodeExecutionRepository(
session_factory=mock_session_factory,
user=mock_account,
app_id="test-app",
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
with pytest.raises(Exception, match="Celery is down"):
repo.save(sample_workflow_node_execution)
@patch(
"core.repositories.celery_workflow_node_execution_repository.get_workflow_node_executions_by_workflow_run_task"
)
def test_get_by_workflow_run(self, mock_task, mock_session_factory, mock_account, sample_workflow_node_execution):
"""Test that get_by_workflow_run retrieves all executions for a workflow run."""
executions_data = [sample_workflow_node_execution.model_dump()]
mock_result = Mock(spec=AsyncResult)
mock_result.get.return_value = executions_data
mock_task.delay.return_value = mock_result
repo = CeleryWorkflowNodeExecutionRepository(
session_factory=mock_session_factory,
user=mock_account,
app_id="test-app",
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
workflow_run_id = sample_workflow_node_execution.workflow_execution_id
order_config = OrderConfig(order_by=["index"], order_direction="asc")
result = repo.get_by_workflow_run(workflow_run_id, order_config)
# Verify Celery task was queued with correct parameters
mock_task.delay.assert_called_once_with(
workflow_run_id=workflow_run_id,
tenant_id=mock_account.current_tenant_id,
app_id="test-app",
order_config={"order_by": order_config.order_by, "order_direction": order_config.order_direction},
)
# Verify results were properly deserialized
assert len(result) == 1
assert result[0].id == sample_workflow_node_execution.id
@patch(
"core.repositories.celery_workflow_node_execution_repository.get_workflow_node_executions_by_workflow_run_task"
)
def test_get_by_workflow_run_without_order_config(self, mock_task, mock_session_factory, mock_account):
"""Test get_by_workflow_run without order configuration."""
mock_result = Mock(spec=AsyncResult)
mock_result.get.return_value = []
mock_task.delay.return_value = mock_result
repo = CeleryWorkflowNodeExecutionRepository(
session_factory=mock_session_factory,
user=mock_account,
app_id="test-app",
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
result = repo.get_by_workflow_run("workflow-run-id")
# Verify order_config was passed as None
call_args = mock_task.delay.call_args[1]
assert call_args["order_config"] is None
assert len(result) == 0
def test_wait_for_pending_saves(self, mock_session_factory, mock_account):
"""Test waiting for all pending save operations."""
repo = CeleryWorkflowNodeExecutionRepository(
session_factory=mock_session_factory,
user=mock_account,
app_id="test-app",
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
# Add some mock pending saves
mock_result1 = Mock(spec=AsyncResult)
mock_result1.ready.return_value = False
mock_result2 = Mock(spec=AsyncResult)
mock_result2.ready.return_value = True
repo._pending_saves["exec1"] = mock_result1
repo._pending_saves["exec2"] = mock_result2
repo.wait_for_pending_saves(timeout=10)
# Verify that non-ready task was waited for
mock_result1.get.assert_called_once_with(timeout=10)
# Verify pending saves were cleared
assert len(repo._pending_saves) == 0
def test_get_pending_save_count(self, mock_session_factory, mock_account):
"""Test getting the count of pending save operations."""
repo = CeleryWorkflowNodeExecutionRepository(
session_factory=mock_session_factory,
user=mock_account,
app_id="test-app",
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
# Add some mock pending saves
mock_result1 = Mock(spec=AsyncResult)
mock_result1.ready.return_value = False
mock_result2 = Mock(spec=AsyncResult)
mock_result2.ready.return_value = True
repo._pending_saves["exec1"] = mock_result1
repo._pending_saves["exec2"] = mock_result2
count = repo.get_pending_save_count()
# Should clean up completed tasks and return count of remaining
assert count == 1
assert "exec1" in repo._pending_saves
assert "exec2" not in repo._pending_saves
def test_clear_pending_saves(self, mock_session_factory, mock_account):
"""Test clearing all pending save operations."""
repo = CeleryWorkflowNodeExecutionRepository(
session_factory=mock_session_factory,
user=mock_account,
app_id="test-app",
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
# Add some mock pending saves
repo._pending_saves["exec1"] = Mock(spec=AsyncResult)
repo._pending_saves["exec2"] = Mock(spec=AsyncResult)
repo.clear_pending_saves()
assert len(repo._pending_saves) == 0

@@ -8,4 +8,4 @@ cd "$SCRIPT_DIR/.."
uv --directory api run \
celery -A app.celery worker \
--P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion
+-P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion,workflow_storage

@@ -811,16 +811,26 @@ WORKFLOW_NODE_EXECUTION_STORAGE=rdbms
# Repository configuration
# Core workflow execution repository implementation
+# Options:
+# - core.repositories.sqlalchemy_workflow_execution_repository.SQLAlchemyWorkflowExecutionRepository (default)
+# - core.repositories.celery_workflow_execution_repository.CeleryWorkflowExecutionRepository
CORE_WORKFLOW_EXECUTION_REPOSITORY=core.repositories.sqlalchemy_workflow_execution_repository.SQLAlchemyWorkflowExecutionRepository
# Core workflow node execution repository implementation
+# Options:
+# - core.repositories.sqlalchemy_workflow_node_execution_repository.SQLAlchemyWorkflowNodeExecutionRepository (default)
+# - core.repositories.celery_workflow_node_execution_repository.CeleryWorkflowNodeExecutionRepository
CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY=core.repositories.sqlalchemy_workflow_node_execution_repository.SQLAlchemyWorkflowNodeExecutionRepository
-# API workflow run repository implementation
-API_WORKFLOW_RUN_REPOSITORY=repositories.sqlalchemy_api_workflow_run_repository.DifyAPISQLAlchemyWorkflowRunRepository
# API workflow node execution repository implementation
API_WORKFLOW_NODE_EXECUTION_REPOSITORY=repositories.sqlalchemy_api_workflow_node_execution_repository.DifyAPISQLAlchemyWorkflowNodeExecutionRepository
+# API workflow run repository implementation
+API_WORKFLOW_RUN_REPOSITORY=repositories.sqlalchemy_api_workflow_run_repository.DifyAPISQLAlchemyWorkflowRunRepository
+# Celery repository configuration
+# Timeout in seconds for Celery repository async operations
+CELERY_REPOSITORY_ASYNC_TIMEOUT=30
# HTTP request node in workflow configuration
HTTP_REQUEST_NODE_MAX_BINARY_SIZE=10485760

@@ -364,6 +364,7 @@ x-shared-env: &shared-api-worker-env
CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY: ${CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY:-core.repositories.sqlalchemy_workflow_node_execution_repository.SQLAlchemyWorkflowNodeExecutionRepository}
API_WORKFLOW_NODE_EXECUTION_REPOSITORY: ${API_WORKFLOW_NODE_EXECUTION_REPOSITORY:-repositories.sqlalchemy_api_workflow_node_execution_repository.DifyAPISQLAlchemyWorkflowNodeExecutionRepository}
API_WORKFLOW_RUN_REPOSITORY: ${API_WORKFLOW_RUN_REPOSITORY:-repositories.sqlalchemy_api_workflow_run_repository.DifyAPISQLAlchemyWorkflowRunRepository}
+CELERY_REPOSITORY_ASYNC_TIMEOUT: ${CELERY_REPOSITORY_ASYNC_TIMEOUT:-30}
HTTP_REQUEST_NODE_MAX_BINARY_SIZE: ${HTTP_REQUEST_NODE_MAX_BINARY_SIZE:-10485760}
HTTP_REQUEST_NODE_MAX_TEXT_SIZE: ${HTTP_REQUEST_NODE_MAX_TEXT_SIZE:-1048576}
HTTP_REQUEST_NODE_SSL_VERIFY: ${HTTP_REQUEST_NODE_SSL_VERIFY:-True}
