add celery based exec repos

parent 74981a65c6
commit 2f7dc7a58a
@@ -0,0 +1,180 @@
"""
Celery-based implementation of the WorkflowExecutionRepository.

This implementation uses Celery tasks for asynchronous storage operations,
improving performance by offloading database writes to background workers.
"""

import logging
from typing import Optional, Union

from celery.result import AsyncResult
from sqlalchemy.engine import Engine
from sqlalchemy.orm import sessionmaker

from core.workflow.entities.workflow_execution import WorkflowExecution
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from libs.helper import extract_tenant_id
from models import Account, CreatorUserRole, EndUser
from models.enums import WorkflowRunTriggeredFrom
from tasks.workflow_execution_tasks import (
    save_workflow_execution_task,
)

logger = logging.getLogger(__name__)


class CeleryWorkflowExecutionRepository(WorkflowExecutionRepository):
    """
    Celery-based implementation of the WorkflowExecutionRepository interface.

    This implementation provides asynchronous storage by dispatching database
    operations to Celery background workers, reducing the time workflow
    execution spends blocked on storage.

    Key features:
    - Asynchronous save operations using Celery tasks
    - Synchronous fallback for read operations when immediate results are needed
    - Multi-tenancy support through tenant/app filtering
    - Automatic retry and error handling through Celery
    - Configurable timeouts for async operations
    """

    def __init__(
        self,
        session_factory: sessionmaker | Engine,
        user: Union[Account, EndUser],
        app_id: Optional[str],
        triggered_from: Optional[WorkflowRunTriggeredFrom],
        async_timeout: int = 30,
    ):
        """
        Initialize the repository with Celery task configuration and context information.

        Args:
            session_factory: SQLAlchemy sessionmaker or engine for fallback operations
            user: Account or EndUser object containing tenant_id, user ID, and role information
            app_id: App ID for filtering by application (can be None)
            triggered_from: Source of the execution trigger (DEBUGGING or APP_RUN)
            async_timeout: Timeout in seconds for async operations (default: 30)
        """
        # Store session factory for fallback operations
        if isinstance(session_factory, Engine):
            self._session_factory = sessionmaker(bind=session_factory, expire_on_commit=False)
        elif isinstance(session_factory, sessionmaker):
            self._session_factory = session_factory
        else:
            raise ValueError(
                f"Invalid session_factory type {type(session_factory).__name__}; expected sessionmaker or Engine"
            )

        # Extract tenant_id from user
        tenant_id = extract_tenant_id(user)
        if not tenant_id:
            raise ValueError("User must have a tenant_id or current_tenant_id")
        self._tenant_id = tenant_id

        # Store app context
        self._app_id = app_id

        # Extract user context
        self._triggered_from = triggered_from
        self._creator_user_id = user.id

        # Determine user role based on user type
        self._creator_user_role = CreatorUserRole.ACCOUNT if isinstance(user, Account) else CreatorUserRole.END_USER

        # Async operation timeout
        self._async_timeout = async_timeout

        # Cache for pending async operations
        self._pending_saves: dict[str, AsyncResult] = {}

        logger.info(
            f"Initialized CeleryWorkflowExecutionRepository for tenant {self._tenant_id}, "
            f"app {self._app_id}, triggered_from {self._triggered_from}"
        )

    def save(self, execution: WorkflowExecution) -> None:
        """
        Save or update a WorkflowExecution instance asynchronously using Celery.

        This method queues the save operation as a Celery task and returns immediately,
        which improves throughput in high-volume scenarios.

        Args:
            execution: The WorkflowExecution instance to save or update
        """
        try:
            # Serialize execution for the Celery task
            execution_data = execution.model_dump()

            # Queue the save operation as a Celery task
            task_result = save_workflow_execution_task.delay(
                execution_data=execution_data,
                tenant_id=self._tenant_id,
                app_id=self._app_id or "",
                triggered_from=self._triggered_from.value if self._triggered_from else "",
                creator_user_id=self._creator_user_id,
                creator_user_role=self._creator_user_role.value,
            )

            # Store the task result for later status checking
            self._pending_saves[execution.id_] = task_result

            logger.debug(f"Queued async save for workflow execution: {execution.id_}")

        except Exception as e:
            logger.exception(f"Failed to queue save operation for execution {execution.id_}: {e}")
            # A fallback to a synchronous save could be implemented here;
            # for now, re-raise the exception.
            raise

    def wait_for_pending_saves(self, timeout: Optional[int] = None) -> None:
        """
        Wait for all pending save operations to complete.

        Useful for ensuring data consistency when immediate persistence is
        required (e.g., during testing or critical operations).

        Args:
            timeout: Maximum time to wait for each operation (uses instance timeout if None)
        """
        wait_timeout = timeout or self._async_timeout

        for execution_id, task_result in list(self._pending_saves.items()):
            try:
                if not task_result.ready():
                    logger.debug(f"Waiting for save operation to complete: {execution_id}")
                    task_result.get(timeout=wait_timeout)
                # Remove completed task
                del self._pending_saves[execution_id]
            except Exception as e:
                logger.exception(f"Failed to wait for save operation {execution_id}: {e}")

    def get_pending_save_count(self) -> int:
        """
        Get the number of pending save operations.

        Returns:
            Number of save operations still in progress
        """
        # Clean up completed tasks
        completed_ids = []
        for execution_id, task_result in self._pending_saves.items():
            if task_result.ready():
                completed_ids.append(execution_id)

        for execution_id in completed_ids:
            del self._pending_saves[execution_id]

        return len(self._pending_saves)

    def clear_pending_saves(self) -> None:
        """
        Clear all pending save operations without waiting for completion.

        Useful for cleanup, or when canceling workflows.
        """
        self._pending_saves.clear()
        logger.debug("Cleared all pending save operations")
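
A minimal usage sketch for the repository above, assuming an application-provided engine and user. Names like current_account and execution are placeholders, not part of this commit:

    from sqlalchemy import create_engine

    from core.repositories.celery_workflow_execution_repository import CeleryWorkflowExecutionRepository
    from models.enums import WorkflowRunTriggeredFrom

    engine = create_engine("postgresql://localhost/dify")  # placeholder DSN
    repo = CeleryWorkflowExecutionRepository(
        session_factory=engine,  # an Engine is wrapped in a sessionmaker internally
        user=current_account,  # placeholder: an Account or EndUser carrying a tenant id
        app_id="my-app-id",  # placeholder
        triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
    )
    repo.save(execution)  # placeholder WorkflowExecution; returns immediately
    repo.wait_for_pending_saves()  # block until all queued writes have landed
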
@@ -0,0 +1,275 @@
"""
Celery-based implementation of the WorkflowNodeExecutionRepository.

This implementation uses Celery tasks for asynchronous storage operations,
improving performance by offloading database writes to background workers.
"""

import logging
from collections.abc import Sequence
from typing import Optional, Union

from celery.result import AsyncResult
from sqlalchemy.engine import Engine
from sqlalchemy.orm import sessionmaker

from core.workflow.entities.workflow_node_execution import WorkflowNodeExecution
from core.workflow.repositories.workflow_node_execution_repository import (
    OrderConfig,
    WorkflowNodeExecutionRepository,
)
from libs.helper import extract_tenant_id
from models import Account, CreatorUserRole, EndUser
from models.workflow import WorkflowNodeExecutionTriggeredFrom
from tasks.workflow_node_execution_tasks import (
    get_workflow_node_executions_by_workflow_run_task,
    save_workflow_node_execution_task,
)

logger = logging.getLogger(__name__)


class CeleryWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository):
    """
    Celery-based implementation of the WorkflowNodeExecutionRepository interface.

    This implementation provides asynchronous storage by dispatching database
    operations to Celery background workers, reducing the time workflow node
    execution spends blocked on storage.

    Key features:
    - Asynchronous save operations using Celery tasks
    - Synchronous fallback for read operations when immediate results are needed
    - Multi-tenancy support through tenant/app filtering
    - Automatic retry and error handling through Celery
    - Configurable timeouts for async operations
    - Batch operations for improved efficiency
    """

    def __init__(
        self,
        session_factory: sessionmaker | Engine,
        user: Union[Account, EndUser],
        app_id: Optional[str],
        triggered_from: Optional[WorkflowNodeExecutionTriggeredFrom],
        async_timeout: int = 30,
    ):
        """
        Initialize the repository with Celery task configuration and context information.

        Args:
            session_factory: SQLAlchemy sessionmaker or engine for fallback operations
            user: Account or EndUser object containing tenant_id, user ID, and role information
            app_id: App ID for filtering by application (can be None)
            triggered_from: Source of the execution trigger (SINGLE_STEP or WORKFLOW_RUN)
            async_timeout: Timeout in seconds for async operations (default: 30)
        """
        # Store session factory for fallback operations
        if isinstance(session_factory, Engine):
            self._session_factory = sessionmaker(bind=session_factory, expire_on_commit=False)
        elif isinstance(session_factory, sessionmaker):
            self._session_factory = session_factory
        else:
            raise ValueError(
                f"Invalid session_factory type {type(session_factory).__name__}; expected sessionmaker or Engine"
            )

        # Extract tenant_id from user
        tenant_id = extract_tenant_id(user)
        if not tenant_id:
            raise ValueError("User must have a tenant_id or current_tenant_id")
        self._tenant_id = tenant_id

        # Store app context
        self._app_id = app_id

        # Extract user context
        self._triggered_from = triggered_from
        self._creator_user_id = user.id

        # Determine user role based on user type
        self._creator_user_role = CreatorUserRole.ACCOUNT if isinstance(user, Account) else CreatorUserRole.END_USER

        # Async operation timeout
        self._async_timeout = async_timeout

        # Cache for pending async operations
        self._pending_saves: dict[str, AsyncResult] = {}

        # Maps execution IDs to workflow_execution_ids so reads can wait on
        # only the saves that belong to a specific workflow run
        self._workflow_execution_mapping: dict[str, str] = {}

        logger.info(
            f"Initialized CeleryWorkflowNodeExecutionRepository for tenant {self._tenant_id}, "
            f"app {self._app_id}, triggered_from {self._triggered_from}"
        )

    def save(self, execution: WorkflowNodeExecution) -> None:
        """
        Save or update a WorkflowNodeExecution instance asynchronously using Celery.

        This method queues the save operation as a Celery task and returns immediately,
        which improves throughput in high-volume scenarios.

        Args:
            execution: The WorkflowNodeExecution instance to save or update
        """
        try:
            # Serialize execution for the Celery task
            execution_data = execution.model_dump()

            # Queue the save operation as a Celery task
            task_result = save_workflow_node_execution_task.delay(
                execution_data=execution_data,
                tenant_id=self._tenant_id,
                app_id=self._app_id or "",
                triggered_from=self._triggered_from.value if self._triggered_from else "",
                creator_user_id=self._creator_user_id,
                creator_user_role=self._creator_user_role.value,
            )

            # Store the task result for later status checking
            self._pending_saves[execution.id] = task_result

            # Record the workflow_execution_id so workflow-specific waits stay cheap
            self._workflow_execution_mapping[execution.id] = execution.workflow_execution_id

            logger.debug(f"Queued async save for workflow node execution: {execution.id}")

        except Exception as e:
            logger.exception(f"Failed to queue save operation for node execution {execution.id}: {e}")
            # A fallback to a synchronous save could be implemented here;
            # for now, re-raise the exception.
            raise

    def get_by_workflow_run(
        self,
        workflow_run_id: str,
        order_config: Optional[OrderConfig] = None,
    ) -> Sequence[WorkflowNodeExecution]:
        """
        Retrieve all WorkflowNodeExecution instances for a specific workflow run.

        Args:
            workflow_run_id: The workflow run ID
            order_config: Optional configuration for ordering results

        Returns:
            A sequence of WorkflowNodeExecution instances
        """
        try:
            # Wait for any pending saves that might affect this workflow run
            self._wait_for_pending_saves_by_workflow_run(workflow_run_id)

            # Serialize order config for the Celery task
            if order_config:
                order_config_data = {
                    "order_by": order_config.order_by,
                    "order_direction": order_config.order_direction,
                }
            else:
                order_config_data = None

            # Queue the get operation as a Celery task
            task_result = get_workflow_node_executions_by_workflow_run_task.delay(
                workflow_run_id=workflow_run_id,
                tenant_id=self._tenant_id,
                app_id=self._app_id or "",
                order_config=order_config_data,
            )

            # Wait for the result (synchronous for read operations)
            executions_data = task_result.get(timeout=self._async_timeout)

            result = []
            for execution_data in executions_data:
                execution = WorkflowNodeExecution.model_validate(execution_data)
                result.append(execution)

            return result

        except Exception as e:
            logger.exception(f"Failed to get workflow node executions for run {workflow_run_id}: {e}")
            # A fallback to direct database access could be implemented here
            return []

    def _wait_for_pending_saves_by_workflow_run(self, workflow_run_id: str) -> None:
        """
        Wait for any pending save operations that might affect the given workflow run.

        Uses the cached workflow_execution_id mapping to wait only on tasks that
        belong to the specific workflow run, rather than on every pending save.

        Args:
            workflow_run_id: The workflow run ID to check
        """
        # Find execution IDs that belong to this workflow run
        relevant_execution_ids = [
            execution_id
            for execution_id, cached_workflow_id in self._workflow_execution_mapping.items()
            if cached_workflow_id == workflow_run_id and execution_id in self._pending_saves
        ]

        logger.debug(f"Found {len(relevant_execution_ids)} pending saves for workflow run {workflow_run_id}")

        for execution_id in relevant_execution_ids:
            task_result = self._pending_saves.get(execution_id)
            if task_result and not task_result.ready():
                try:
                    logger.debug(f"Waiting for pending save to complete before read: {execution_id}")
                    task_result.get(timeout=self._async_timeout)
                except Exception as e:
                    logger.exception(f"Failed to wait for pending save {execution_id}: {e}")

            # Clean up completed tasks from both caches
            if task_result and task_result.ready():
                self._pending_saves.pop(execution_id, None)
                self._workflow_execution_mapping.pop(execution_id, None)

    def wait_for_pending_saves(self, timeout: Optional[int] = None) -> None:
        """
        Wait for all pending save operations to complete.

        Useful for ensuring data consistency when immediate persistence is
        required (e.g., during testing or critical operations).

        Args:
            timeout: Maximum time to wait for each operation (uses instance timeout if None)
        """
        wait_timeout = timeout or self._async_timeout

        for execution_id, task_result in list(self._pending_saves.items()):
            try:
                if not task_result.ready():
                    logger.debug(f"Waiting for save operation to complete: {execution_id}")
                    task_result.get(timeout=wait_timeout)
                # Remove completed task from both caches
                del self._pending_saves[execution_id]
                self._workflow_execution_mapping.pop(execution_id, None)
            except Exception as e:
                logger.exception(f"Failed to wait for save operation {execution_id}: {e}")

    def get_pending_save_count(self) -> int:
        """
        Get the number of pending save operations.

        Returns:
            Number of save operations still in progress
        """
        # Clean up completed tasks
        completed_ids = []
        for execution_id, task_result in self._pending_saves.items():
            if task_result.ready():
                completed_ids.append(execution_id)

        for execution_id in completed_ids:
            del self._pending_saves[execution_id]
            self._workflow_execution_mapping.pop(execution_id, None)

        return len(self._pending_saves)

    def clear_pending_saves(self) -> None:
        """
        Clear all pending save operations without waiting for completion.

        Useful for cleanup, or when canceling workflows.
        """
        self._pending_saves.clear()
        self._workflow_execution_mapping.clear()
        logger.debug("Cleared all pending save operations and workflow execution mappings")
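
A sketch of the read-after-write behavior this repository provides; engine, current_account, and node_execution are placeholders:

    from core.repositories.celery_workflow_node_execution_repository import CeleryWorkflowNodeExecutionRepository
    from models.workflow import WorkflowNodeExecutionTriggeredFrom

    repo = CeleryWorkflowNodeExecutionRepository(
        session_factory=engine,  # placeholder engine or sessionmaker
        user=current_account,  # placeholder Account or EndUser
        app_id="my-app-id",  # placeholder
        triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
    )
    repo.save(node_execution)  # queued; its workflow_execution_id is cached
    # The read first drains pending saves for this run only, then fetches via a
    # Celery task, so it observes the save queued above.
    executions = repo.get_by_workflow_run(node_execution.workflow_execution_id)
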
@@ -0,0 +1,157 @@
"""
Celery tasks for asynchronous workflow execution storage operations.

These tasks provide asynchronous storage for workflow execution data,
improving performance by offloading storage operations to background workers.
"""

import json
import logging

from celery import shared_task
from sqlalchemy import select
from sqlalchemy.orm import sessionmaker

from core.workflow.entities.workflow_execution import WorkflowExecution, WorkflowExecutionStatus, WorkflowType
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
from extensions.ext_database import db
from models import CreatorUserRole, WorkflowRun
from models.enums import WorkflowRunTriggeredFrom

logger = logging.getLogger(__name__)


@shared_task(queue="workflow_storage", bind=True, max_retries=3, default_retry_delay=60)
def save_workflow_execution_task(
    self,
    execution_data: dict,
    tenant_id: str,
    app_id: str,
    triggered_from: str,
    creator_user_id: str,
    creator_user_role: str,
) -> bool:
    """
    Asynchronously save or update a workflow execution in the database.

    Args:
        execution_data: Serialized WorkflowExecution data
        tenant_id: Tenant ID for multi-tenancy
        app_id: Application ID
        triggered_from: Source of the execution trigger
        creator_user_id: ID of the user who created the execution
        creator_user_role: Role of the user who created the execution

    Returns:
        True on success. On failure the task is retried with exponential
        backoff; once retries are exhausted, the exception propagates.
    """
    try:
        # Create a new session for this task
        session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)

        with session_factory() as session:
            # Deserialize execution data
            execution = WorkflowExecution.model_validate(execution_data)

            # Check if the workflow run already exists
            existing_run = session.scalar(select(WorkflowRun).where(WorkflowRun.id == execution.id_))

            if existing_run:
                # Update the existing workflow run
                _update_workflow_run_from_execution(existing_run, execution)
                logger.debug(f"Updated existing workflow run: {execution.id_}")
            else:
                # Create a new workflow run
                workflow_run = _create_workflow_run_from_execution(
                    execution=execution,
                    tenant_id=tenant_id,
                    app_id=app_id,
                    triggered_from=WorkflowRunTriggeredFrom(triggered_from),
                    creator_user_id=creator_user_id,
                    creator_user_role=CreatorUserRole(creator_user_role),
                )
                session.add(workflow_run)
                logger.debug(f"Created new workflow run: {execution.id_}")

            session.commit()
            return True

    except Exception as e:
        logger.exception(f"Failed to save workflow execution {execution_data.get('id_', 'unknown')}: {e}")
        # Retry the task with exponential backoff
        raise self.retry(exc=e, countdown=60 * (2**self.request.retries))


def _create_workflow_run_from_execution(
    execution: WorkflowExecution,
    tenant_id: str,
    app_id: str,
    triggered_from: WorkflowRunTriggeredFrom,
    creator_user_id: str,
    creator_user_role: CreatorUserRole,
) -> WorkflowRun:
    """
    Create a WorkflowRun database model from a WorkflowExecution domain entity.
    """
    workflow_run = WorkflowRun()
    workflow_run.id = execution.id_
    workflow_run.tenant_id = tenant_id
    workflow_run.app_id = app_id
    workflow_run.workflow_id = execution.workflow_id
    workflow_run.type = execution.workflow_type.value
    workflow_run.triggered_from = triggered_from.value
    workflow_run.version = execution.workflow_version
    json_converter = WorkflowRuntimeTypeConverter()
    workflow_run.graph = json.dumps(json_converter.to_json_encodable(execution.graph))
    workflow_run.inputs = json.dumps(json_converter.to_json_encodable(execution.inputs))
    workflow_run.status = execution.status.value
    workflow_run.outputs = (
        json.dumps(json_converter.to_json_encodable(execution.outputs)) if execution.outputs else "{}"
    )
    workflow_run.error = execution.error_message
    workflow_run.elapsed_time = execution.elapsed_time
    workflow_run.total_tokens = execution.total_tokens
    workflow_run.total_steps = execution.total_steps
    workflow_run.created_by_role = creator_user_role.value
    workflow_run.created_by = creator_user_id
    workflow_run.created_at = execution.started_at
    workflow_run.finished_at = execution.finished_at

    return workflow_run


def _update_workflow_run_from_execution(workflow_run: WorkflowRun, execution: WorkflowExecution) -> None:
    """
    Update a WorkflowRun database model from a WorkflowExecution domain entity.
    """
    json_converter = WorkflowRuntimeTypeConverter()
    workflow_run.status = execution.status.value
    workflow_run.outputs = (
        json.dumps(json_converter.to_json_encodable(execution.outputs)) if execution.outputs else "{}"
    )
    workflow_run.error = execution.error_message
    workflow_run.elapsed_time = execution.elapsed_time
    workflow_run.total_tokens = execution.total_tokens
    workflow_run.total_steps = execution.total_steps
    workflow_run.finished_at = execution.finished_at


def _create_execution_from_workflow_run(workflow_run: WorkflowRun) -> WorkflowExecution:
    """
    Create a WorkflowExecution domain entity from a WorkflowRun database model.
    """
    return WorkflowExecution(
        id_=workflow_run.id,
        workflow_id=workflow_run.workflow_id,
        workflow_type=WorkflowType(workflow_run.type),
        workflow_version=workflow_run.version,
        graph=json.loads(workflow_run.graph or "{}"),
        inputs=json.loads(workflow_run.inputs or "{}"),
        outputs=json.loads(workflow_run.outputs or "{}"),
        status=WorkflowExecutionStatus(workflow_run.status),
        error_message=workflow_run.error or "",
        total_tokens=workflow_run.total_tokens,
        total_steps=workflow_run.total_steps,
        started_at=workflow_run.created_at,
        finished_at=workflow_run.finished_at,
    )
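
Both task modules bind to a dedicated "workflow_storage" queue, so a worker must consume that queue for queued saves to ever land. A minimal sketch of the operational side; the "app.celery" path is a placeholder for the project's Celery app:

    # Start a worker that consumes the dedicated storage queue:
    #
    #   celery -A app.celery worker -Q workflow_storage --loglevel=INFO
    #
    # Retry schedule implied by max_retries=3 and countdown=60 * 2**retries:
    for retries in range(3):
        print(f"retry {retries + 1} scheduled after {60 * 2**retries}s")  # 60s, 120s, 240s
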
@@ -0,0 +1,277 @@
"""
Celery tasks for asynchronous workflow node execution storage operations.

These tasks provide asynchronous storage for workflow node execution data,
improving performance by offloading storage operations to background workers.
"""

import json
import logging
from typing import Optional

from celery import shared_task
from sqlalchemy import select
from sqlalchemy.orm import sessionmaker

from core.workflow.entities.workflow_node_execution import (
    WorkflowNodeExecution,
    WorkflowNodeExecutionMetadataKey,
    WorkflowNodeExecutionStatus,
)
from core.workflow.nodes.enums import NodeType
from core.workflow.repositories.workflow_node_execution_repository import OrderConfig
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
from extensions.ext_database import db
from models import CreatorUserRole, WorkflowNodeExecutionModel
from models.workflow import WorkflowNodeExecutionTriggeredFrom

logger = logging.getLogger(__name__)


@shared_task(queue="workflow_storage", bind=True, max_retries=3, default_retry_delay=60)
def save_workflow_node_execution_task(
    self,
    execution_data: dict,
    tenant_id: str,
    app_id: str,
    triggered_from: str,
    creator_user_id: str,
    creator_user_role: str,
) -> bool:
    """
    Asynchronously save or update a workflow node execution in the database.

    Args:
        execution_data: Serialized WorkflowNodeExecution data
        tenant_id: Tenant ID for multi-tenancy
        app_id: Application ID
        triggered_from: Source of the execution trigger
        creator_user_id: ID of the user who created the execution
        creator_user_role: Role of the user who created the execution

    Returns:
        True on success. On failure the task is retried with exponential
        backoff; once retries are exhausted, the exception propagates.
    """
    try:
        # Create a new session for this task
        session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)

        with session_factory() as session:
            # Deserialize execution data
            execution = WorkflowNodeExecution.model_validate(execution_data)

            # Check if the node execution already exists
            existing_execution = session.scalar(
                select(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.id == execution.id)
            )

            if existing_execution:
                # Update the existing node execution
                _update_node_execution_from_domain(existing_execution, execution)
                logger.debug(f"Updated existing workflow node execution: {execution.id}")
            else:
                # Create a new node execution
                node_execution = _create_node_execution_from_domain(
                    execution=execution,
                    tenant_id=tenant_id,
                    app_id=app_id,
                    triggered_from=WorkflowNodeExecutionTriggeredFrom(triggered_from),
                    creator_user_id=creator_user_id,
                    creator_user_role=CreatorUserRole(creator_user_role),
                )
                session.add(node_execution)
                logger.debug(f"Created new workflow node execution: {execution.id}")

            session.commit()
            return True

    except Exception as e:
        logger.exception(f"Failed to save workflow node execution {execution_data.get('id', 'unknown')}: {e}")
        # Retry the task with exponential backoff
        raise self.retry(exc=e, countdown=60 * (2**self.request.retries))


@shared_task(queue="workflow_storage", bind=True, max_retries=3, default_retry_delay=60)
def get_workflow_node_executions_by_workflow_run_task(
    self,
    workflow_run_id: str,
    tenant_id: str,
    app_id: str,
    order_config: Optional[dict] = None,
) -> list[dict]:
    """
    Asynchronously retrieve all workflow node executions for a specific workflow run.

    Args:
        workflow_run_id: The workflow run ID
        tenant_id: Tenant ID for multi-tenancy
        app_id: Application ID
        order_config: Optional ordering configuration

    Returns:
        List of serialized WorkflowNodeExecution data
    """
    try:
        session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)

        with session_factory() as session:
            # Build the base query
            query = select(WorkflowNodeExecutionModel).where(
                WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id,
                WorkflowNodeExecutionModel.tenant_id == tenant_id,
                WorkflowNodeExecutionModel.app_id == app_id,
            )

            # Apply ordering if specified
            if order_config:
                order_obj = OrderConfig(
                    order_by=order_config["order_by"], order_direction=order_config.get("order_direction")
                )
                for field_name in order_obj.order_by:
                    field = getattr(WorkflowNodeExecutionModel, field_name, None)
                    if field is not None:
                        if order_obj.order_direction == "desc":
                            query = query.order_by(field.desc())
                        else:
                            query = query.order_by(field.asc())

            node_executions = session.scalars(query).all()

            result = []
            for node_execution in node_executions:
                execution = _create_domain_from_node_execution(node_execution)
                result.append(execution.model_dump())

            return result

    except Exception as e:
        logger.exception(f"Failed to get workflow node executions for run {workflow_run_id}: {e}")
        # Retry the task with exponential backoff
        raise self.retry(exc=e, countdown=60 * (2**self.request.retries))


def _create_node_execution_from_domain(
    execution: WorkflowNodeExecution,
    tenant_id: str,
    app_id: str,
    triggered_from: WorkflowNodeExecutionTriggeredFrom,
    creator_user_id: str,
    creator_user_role: CreatorUserRole,
) -> WorkflowNodeExecutionModel:
    """
    Create a WorkflowNodeExecutionModel database model from a WorkflowNodeExecution domain entity.
    """
    node_execution = WorkflowNodeExecutionModel()
    node_execution.id = execution.id
    node_execution.tenant_id = tenant_id
    node_execution.app_id = app_id
    node_execution.workflow_id = execution.workflow_id
    node_execution.triggered_from = triggered_from.value
    node_execution.workflow_run_id = execution.workflow_execution_id
    node_execution.index = execution.index
    node_execution.predecessor_node_id = execution.predecessor_node_id
    node_execution.node_id = execution.node_id
    node_execution.node_type = execution.node_type.value
    node_execution.title = execution.title
    node_execution.node_execution_id = execution.node_execution_id

    # Serialize complex data as JSON
    json_converter = WorkflowRuntimeTypeConverter()
    node_execution.inputs = (
        json.dumps(json_converter.to_json_encodable(execution.inputs)) if execution.inputs else "{}"
    )
    node_execution.process_data = (
        json.dumps(json_converter.to_json_encodable(execution.process_data)) if execution.process_data else "{}"
    )
    node_execution.outputs = (
        json.dumps(json_converter.to_json_encodable(execution.outputs)) if execution.outputs else "{}"
    )
    # Convert metadata enum keys to strings for JSON serialization
    if execution.metadata:
        metadata_for_json = {
            key.value if hasattr(key, "value") else str(key): value for key, value in execution.metadata.items()
        }
        node_execution.execution_metadata = json.dumps(json_converter.to_json_encodable(metadata_for_json))
    else:
        node_execution.execution_metadata = "{}"

    node_execution.status = execution.status.value
    node_execution.error = execution.error
    node_execution.elapsed_time = execution.elapsed_time
    node_execution.created_by_role = creator_user_role.value
    node_execution.created_by = creator_user_id
    node_execution.created_at = execution.created_at
    node_execution.finished_at = execution.finished_at

    return node_execution


def _update_node_execution_from_domain(
    node_execution: WorkflowNodeExecutionModel, execution: WorkflowNodeExecution
) -> None:
    """
    Update a WorkflowNodeExecutionModel database model from a WorkflowNodeExecution domain entity.
    """
    # Update serialized data
    json_converter = WorkflowRuntimeTypeConverter()
    node_execution.inputs = (
        json.dumps(json_converter.to_json_encodable(execution.inputs)) if execution.inputs else "{}"
    )
    node_execution.process_data = (
        json.dumps(json_converter.to_json_encodable(execution.process_data)) if execution.process_data else "{}"
    )
    node_execution.outputs = (
        json.dumps(json_converter.to_json_encodable(execution.outputs)) if execution.outputs else "{}"
    )
    # Convert metadata enum keys to strings for JSON serialization
    if execution.metadata:
        metadata_for_json = {
            key.value if hasattr(key, "value") else str(key): value for key, value in execution.metadata.items()
        }
        node_execution.execution_metadata = json.dumps(json_converter.to_json_encodable(metadata_for_json))
    else:
        node_execution.execution_metadata = "{}"

    # Update other fields
    node_execution.status = execution.status.value
    node_execution.error = execution.error
    node_execution.elapsed_time = execution.elapsed_time
    node_execution.finished_at = execution.finished_at


def _create_domain_from_node_execution(node_execution: WorkflowNodeExecutionModel) -> WorkflowNodeExecution:
    """
    Create a WorkflowNodeExecution domain entity from a WorkflowNodeExecutionModel database model.
    """
    # Deserialize JSON data
    inputs = json.loads(node_execution.inputs or "{}")
    process_data = json.loads(node_execution.process_data or "{}")
    outputs = json.loads(node_execution.outputs or "{}")
    metadata = json.loads(node_execution.execution_metadata or "{}")

    # Convert metadata keys back to enum values
    typed_metadata = {}
    for key, value in metadata.items():
        try:
            enum_key = WorkflowNodeExecutionMetadataKey(key)
            typed_metadata[enum_key] = value
        except ValueError:
            # Skip unknown metadata keys
            continue

    return WorkflowNodeExecution(
        id=node_execution.id,
        node_execution_id=node_execution.node_execution_id,
        workflow_id=node_execution.workflow_id,
        workflow_execution_id=node_execution.workflow_run_id,
        index=node_execution.index,
        predecessor_node_id=node_execution.predecessor_node_id,
        node_id=node_execution.node_id,
        node_type=NodeType(node_execution.node_type),
        title=node_execution.title,
        inputs=inputs if inputs else None,
        process_data=process_data if process_data else None,
        outputs=outputs if outputs else None,
        status=WorkflowNodeExecutionStatus(node_execution.status),
        error=node_execution.error,
        elapsed_time=node_execution.elapsed_time,
        metadata=typed_metadata if typed_metadata else None,
        created_at=node_execution.created_at,
        finished_at=node_execution.finished_at,
    )
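
The metadata helpers above stringify enum keys for JSON storage and map them back to enum members on read, skipping unknown keys. The same round-trip in isolation, using a stand-in enum rather than the real WorkflowNodeExecutionMetadataKey members:

    import json
    from enum import Enum

    class MetaKey(Enum):  # stand-in for WorkflowNodeExecutionMetadataKey
        TOTAL_TOKENS = "total_tokens"

    metadata = {MetaKey.TOTAL_TOKENS: 42}
    # write: enum keys become their string values
    stored = json.dumps({k.value: v for k, v in metadata.items()})
    # read: strings map back to enum members; unknown keys are skipped
    restored = {}
    for key, value in json.loads(stored).items():
        try:
            restored[MetaKey(key)] = value
        except ValueError:
            continue
    assert restored == metadata
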
@@ -0,0 +1,237 @@
"""
Unit tests for CeleryWorkflowExecutionRepository.

These tests verify the Celery-based asynchronous storage functionality
for workflow execution data.
"""

from datetime import UTC, datetime
from unittest.mock import Mock, patch
from uuid import uuid4

import pytest
from celery.result import AsyncResult

from core.repositories.celery_workflow_execution_repository import CeleryWorkflowExecutionRepository
from core.workflow.entities.workflow_execution import WorkflowExecution, WorkflowType
from models import Account, EndUser
from models.enums import WorkflowRunTriggeredFrom


@pytest.fixture
def mock_session_factory():
    """Session factory backed by an in-memory SQLite engine."""
    from sqlalchemy import create_engine
    from sqlalchemy.orm import sessionmaker

    # Use a real sessionmaker with in-memory SQLite for testing
    engine = create_engine("sqlite:///:memory:")
    return sessionmaker(bind=engine)


@pytest.fixture
def mock_account():
    """Mock Account user."""
    account = Mock(spec=Account)
    account.id = str(uuid4())
    account.current_tenant_id = str(uuid4())
    return account


@pytest.fixture
def mock_end_user():
    """Mock EndUser."""
    user = Mock(spec=EndUser)
    user.id = str(uuid4())
    user.tenant_id = str(uuid4())
    return user


@pytest.fixture
def sample_workflow_execution():
    """Sample WorkflowExecution for testing."""
    return WorkflowExecution.new(
        id_=str(uuid4()),
        workflow_id=str(uuid4()),
        workflow_type=WorkflowType.WORKFLOW,
        workflow_version="1.0",
        graph={"nodes": [], "edges": []},
        inputs={"input1": "value1"},
        started_at=datetime.now(UTC).replace(tzinfo=None),
    )


class TestCeleryWorkflowExecutionRepository:
    """Test cases for CeleryWorkflowExecutionRepository."""

    def test_init_with_sessionmaker(self, mock_session_factory, mock_account):
        """Test repository initialization with sessionmaker."""
        app_id = "test-app-id"
        triggered_from = WorkflowRunTriggeredFrom.APP_RUN

        repo = CeleryWorkflowExecutionRepository(
            session_factory=mock_session_factory,
            user=mock_account,
            app_id=app_id,
            triggered_from=triggered_from,
        )

        assert repo._tenant_id == mock_account.current_tenant_id
        assert repo._app_id == app_id
        assert repo._triggered_from == triggered_from
        assert repo._creator_user_id == mock_account.id
        assert repo._async_timeout == 30  # default timeout

    def test_init_with_custom_timeout(self, mock_session_factory, mock_account):
        """Test repository initialization with custom timeout."""
        custom_timeout = 60

        repo = CeleryWorkflowExecutionRepository(
            session_factory=mock_session_factory,
            user=mock_account,
            app_id="test-app",
            triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
            async_timeout=custom_timeout,
        )

        assert repo._async_timeout == custom_timeout

    def test_init_with_end_user(self, mock_session_factory, mock_end_user):
        """Test repository initialization with EndUser."""
        repo = CeleryWorkflowExecutionRepository(
            session_factory=mock_session_factory,
            user=mock_end_user,
            app_id="test-app",
            triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
        )

        assert repo._tenant_id == mock_end_user.tenant_id

    def test_init_without_tenant_id_raises_error(self, mock_session_factory):
        """Test that initialization fails without tenant_id."""
        user = Mock()
        user.current_tenant_id = None
        user.tenant_id = None

        with pytest.raises(ValueError, match="User must have a tenant_id"):
            CeleryWorkflowExecutionRepository(
                session_factory=mock_session_factory,
                user=user,
                app_id="test-app",
                triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
            )

    @patch("core.repositories.celery_workflow_execution_repository.save_workflow_execution_task")
    def test_save_queues_celery_task(self, mock_task, mock_session_factory, mock_account, sample_workflow_execution):
        """Test that save operation queues a Celery task."""
        mock_result = Mock(spec=AsyncResult)
        mock_task.delay.return_value = mock_result

        repo = CeleryWorkflowExecutionRepository(
            session_factory=mock_session_factory,
            user=mock_account,
            app_id="test-app",
            triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
        )

        repo.save(sample_workflow_execution)

        # Verify the Celery task was queued with the correct parameters
        mock_task.delay.assert_called_once()
        call_args = mock_task.delay.call_args[1]

        assert call_args["execution_data"] == sample_workflow_execution.model_dump()
        assert call_args["tenant_id"] == mock_account.current_tenant_id
        assert call_args["app_id"] == "test-app"
        assert call_args["triggered_from"] == WorkflowRunTriggeredFrom.APP_RUN.value
        assert call_args["creator_user_id"] == mock_account.id

        # Verify the task result is stored for tracking
        assert sample_workflow_execution.id_ in repo._pending_saves
        assert repo._pending_saves[sample_workflow_execution.id_] == mock_result

    @patch("core.repositories.celery_workflow_execution_repository.save_workflow_execution_task")
    def test_save_handles_celery_failure(
        self, mock_task, mock_session_factory, mock_account, sample_workflow_execution
    ):
        """Test that save operation handles Celery task failures."""
        mock_task.delay.side_effect = Exception("Celery is down")

        repo = CeleryWorkflowExecutionRepository(
            session_factory=mock_session_factory,
            user=mock_account,
            app_id="test-app",
            triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
        )

        with pytest.raises(Exception, match="Celery is down"):
            repo.save(sample_workflow_execution)

    def test_wait_for_pending_saves(self, mock_session_factory, mock_account, sample_workflow_execution):
        """Test waiting for all pending save operations."""
        repo = CeleryWorkflowExecutionRepository(
            session_factory=mock_session_factory,
            user=mock_account,
            app_id="test-app",
            triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
        )

        # Add some mock pending saves
        mock_result1 = Mock(spec=AsyncResult)
        mock_result1.ready.return_value = False
        mock_result2 = Mock(spec=AsyncResult)
        mock_result2.ready.return_value = True

        repo._pending_saves["exec1"] = mock_result1
        repo._pending_saves["exec2"] = mock_result2

        repo.wait_for_pending_saves(timeout=10)

        # Verify that the non-ready task was waited for
        mock_result1.get.assert_called_once_with(timeout=10)

        # Verify pending saves were cleared
        assert len(repo._pending_saves) == 0

    def test_get_pending_save_count(self, mock_session_factory, mock_account):
        """Test getting the count of pending save operations."""
        repo = CeleryWorkflowExecutionRepository(
            session_factory=mock_session_factory,
            user=mock_account,
            app_id="test-app",
            triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
        )

        # Add some mock pending saves
        mock_result1 = Mock(spec=AsyncResult)
        mock_result1.ready.return_value = False
        mock_result2 = Mock(spec=AsyncResult)
        mock_result2.ready.return_value = True

        repo._pending_saves["exec1"] = mock_result1
        repo._pending_saves["exec2"] = mock_result2

        count = repo.get_pending_save_count()

        # Completed tasks are cleaned up; only in-progress saves are counted
        assert count == 1
        assert "exec1" in repo._pending_saves
        assert "exec2" not in repo._pending_saves

    def test_clear_pending_saves(self, mock_session_factory, mock_account):
        """Test clearing all pending save operations."""
        repo = CeleryWorkflowExecutionRepository(
            session_factory=mock_session_factory,
            user=mock_account,
            app_id="test-app",
            triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
        )

        # Add some mock pending saves
        repo._pending_saves["exec1"] = Mock(spec=AsyncResult)
        repo._pending_saves["exec2"] = Mock(spec=AsyncResult)

        repo.clear_pending_saves()

        assert len(repo._pending_saves) == 0
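
A note on the @patch targets used throughout these tests: unittest.mock replaces a name where it is looked up, and the repository module imported the task into its own namespace, so the tests patch the name inside core.repositories.celery_workflow_execution_repository rather than inside tasks.workflow_execution_tasks. The standalone pattern:

    from unittest.mock import patch

    target = "core.repositories.celery_workflow_execution_repository.save_workflow_execution_task"
    with patch(target) as mock_task:
        mock_task.delay.return_value = object()  # stand-in for an AsyncResult
        # Inside this block, repo.save() dispatches to the mock instead of Celery.
        pass
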
@@ -0,0 +1,308 @@
"""
Unit tests for CeleryWorkflowNodeExecutionRepository.

These tests verify the Celery-based asynchronous storage functionality
for workflow node execution data.
"""

from datetime import UTC, datetime
from unittest.mock import Mock, patch
from uuid import uuid4

import pytest
from celery.result import AsyncResult

from core.repositories.celery_workflow_node_execution_repository import CeleryWorkflowNodeExecutionRepository
from core.workflow.entities.workflow_node_execution import (
    WorkflowNodeExecution,
    WorkflowNodeExecutionStatus,
)
from core.workflow.nodes.enums import NodeType
from core.workflow.repositories.workflow_node_execution_repository import OrderConfig
from models import Account, EndUser
from models.workflow import WorkflowNodeExecutionTriggeredFrom


@pytest.fixture
def mock_session_factory():
    """Session factory backed by an in-memory SQLite engine."""
    from sqlalchemy import create_engine
    from sqlalchemy.orm import sessionmaker

    # Use a real sessionmaker with in-memory SQLite for testing
    engine = create_engine("sqlite:///:memory:")
    return sessionmaker(bind=engine)


@pytest.fixture
def mock_account():
    """Mock Account user."""
    account = Mock(spec=Account)
    account.id = str(uuid4())
    account.current_tenant_id = str(uuid4())
    return account


@pytest.fixture
def mock_end_user():
    """Mock EndUser."""
    user = Mock(spec=EndUser)
    user.id = str(uuid4())
    user.tenant_id = str(uuid4())
    return user


@pytest.fixture
def sample_workflow_node_execution():
    """Sample WorkflowNodeExecution for testing."""
    return WorkflowNodeExecution(
        id=str(uuid4()),
        node_execution_id=str(uuid4()),
        workflow_id=str(uuid4()),
        workflow_execution_id=str(uuid4()),
        index=1,
        node_id="test_node",
        node_type=NodeType.START,
        title="Test Node",
        inputs={"input1": "value1"},
        status=WorkflowNodeExecutionStatus.RUNNING,
        created_at=datetime.now(UTC).replace(tzinfo=None),
    )


class TestCeleryWorkflowNodeExecutionRepository:
    """Test cases for CeleryWorkflowNodeExecutionRepository."""

    def test_init_with_sessionmaker(self, mock_session_factory, mock_account):
        """Test repository initialization with sessionmaker."""
        app_id = "test-app-id"
        triggered_from = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN

        repo = CeleryWorkflowNodeExecutionRepository(
            session_factory=mock_session_factory,
            user=mock_account,
            app_id=app_id,
            triggered_from=triggered_from,
        )

        assert repo._tenant_id == mock_account.current_tenant_id
        assert repo._app_id == app_id
        assert repo._triggered_from == triggered_from
        assert repo._creator_user_id == mock_account.id
        assert repo._async_timeout == 30  # default timeout

    def test_init_with_custom_timeout(self, mock_session_factory, mock_account):
        """Test repository initialization with custom timeout."""
        custom_timeout = 60

        repo = CeleryWorkflowNodeExecutionRepository(
            session_factory=mock_session_factory,
            user=mock_account,
            app_id="test-app",
            triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
            async_timeout=custom_timeout,
        )

        assert repo._async_timeout == custom_timeout

    def test_init_with_end_user(self, mock_session_factory, mock_end_user):
        """Test repository initialization with EndUser."""
        repo = CeleryWorkflowNodeExecutionRepository(
            session_factory=mock_session_factory,
            user=mock_end_user,
            app_id="test-app",
            triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
        )

        assert repo._tenant_id == mock_end_user.tenant_id

    def test_init_without_tenant_id_raises_error(self, mock_session_factory):
        """Test that initialization fails without tenant_id."""
        user = Mock()
        user.current_tenant_id = None
        user.tenant_id = None

        with pytest.raises(ValueError, match="User must have a tenant_id"):
            CeleryWorkflowNodeExecutionRepository(
                session_factory=mock_session_factory,
                user=user,
                app_id="test-app",
                triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
            )

    @patch("core.repositories.celery_workflow_node_execution_repository.save_workflow_node_execution_task")
    def test_save_queues_celery_task(
        self, mock_task, mock_session_factory, mock_account, sample_workflow_node_execution
    ):
        """Test that save operation queues a Celery task."""
        mock_result = Mock(spec=AsyncResult)
        mock_task.delay.return_value = mock_result

        repo = CeleryWorkflowNodeExecutionRepository(
            session_factory=mock_session_factory,
            user=mock_account,
            app_id="test-app",
            triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
        )

        repo.save(sample_workflow_node_execution)

        # Verify the Celery task was queued with the correct parameters
        mock_task.delay.assert_called_once()
        call_args = mock_task.delay.call_args[1]

        assert call_args["execution_data"] == sample_workflow_node_execution.model_dump()
        assert call_args["tenant_id"] == mock_account.current_tenant_id
        assert call_args["app_id"] == "test-app"
        assert call_args["triggered_from"] == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value
        assert call_args["creator_user_id"] == mock_account.id

        # Verify the task result is stored for tracking
        assert sample_workflow_node_execution.id in repo._pending_saves
        assert repo._pending_saves[sample_workflow_node_execution.id] == mock_result

    @patch("core.repositories.celery_workflow_node_execution_repository.save_workflow_node_execution_task")
    def test_save_handles_celery_failure(
        self, mock_task, mock_session_factory, mock_account, sample_workflow_node_execution
    ):
        """Test that save operation handles Celery task failures."""
        mock_task.delay.side_effect = Exception("Celery is down")

        repo = CeleryWorkflowNodeExecutionRepository(
            session_factory=mock_session_factory,
            user=mock_account,
            app_id="test-app",
            triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
        )

        with pytest.raises(Exception, match="Celery is down"):
            repo.save(sample_workflow_node_execution)

    @patch(
        "core.repositories.celery_workflow_node_execution_repository.get_workflow_node_executions_by_workflow_run_task"
    )
    def test_get_by_workflow_run(self, mock_task, mock_session_factory, mock_account, sample_workflow_node_execution):
        """Test that get_by_workflow_run retrieves all executions for a workflow run."""
        executions_data = [sample_workflow_node_execution.model_dump()]
        mock_result = Mock(spec=AsyncResult)
        mock_result.get.return_value = executions_data
        mock_task.delay.return_value = mock_result

        repo = CeleryWorkflowNodeExecutionRepository(
            session_factory=mock_session_factory,
            user=mock_account,
            app_id="test-app",
            triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
        )

        workflow_run_id = sample_workflow_node_execution.workflow_execution_id
        order_config = OrderConfig(order_by=["index"], order_direction="asc")

        result = repo.get_by_workflow_run(workflow_run_id, order_config)

        # Verify the Celery task was queued with the correct parameters
        mock_task.delay.assert_called_once_with(
            workflow_run_id=workflow_run_id,
            tenant_id=mock_account.current_tenant_id,
            app_id="test-app",
            order_config={"order_by": order_config.order_by, "order_direction": order_config.order_direction},
        )

        # Verify results were properly deserialized
        assert len(result) == 1
        assert result[0].id == sample_workflow_node_execution.id

    @patch(
        "core.repositories.celery_workflow_node_execution_repository.get_workflow_node_executions_by_workflow_run_task"
    )
    def test_get_by_workflow_run_without_order_config(self, mock_task, mock_session_factory, mock_account):
        """Test get_by_workflow_run without order configuration."""
        mock_result = Mock(spec=AsyncResult)
        mock_result.get.return_value = []
        mock_task.delay.return_value = mock_result

        repo = CeleryWorkflowNodeExecutionRepository(
            session_factory=mock_session_factory,
            user=mock_account,
            app_id="test-app",
            triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
        )

        result = repo.get_by_workflow_run("workflow-run-id")

        # Verify order_config was passed as None
        call_args = mock_task.delay.call_args[1]
        assert call_args["order_config"] is None

        assert len(result) == 0

    def test_wait_for_pending_saves(self, mock_session_factory, mock_account):
        """Test waiting for all pending save operations."""
        repo = CeleryWorkflowNodeExecutionRepository(
            session_factory=mock_session_factory,
            user=mock_account,
            app_id="test-app",
            triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
        )

        # Add some mock pending saves
        mock_result1 = Mock(spec=AsyncResult)
        mock_result1.ready.return_value = False
        mock_result2 = Mock(spec=AsyncResult)
        mock_result2.ready.return_value = True

        repo._pending_saves["exec1"] = mock_result1
        repo._pending_saves["exec2"] = mock_result2

        repo.wait_for_pending_saves(timeout=10)

        # Verify that the non-ready task was waited for
        mock_result1.get.assert_called_once_with(timeout=10)

        # Verify pending saves were cleared
        assert len(repo._pending_saves) == 0

    def test_get_pending_save_count(self, mock_session_factory, mock_account):
        """Test getting the count of pending save operations."""
        repo = CeleryWorkflowNodeExecutionRepository(
            session_factory=mock_session_factory,
            user=mock_account,
            app_id="test-app",
            triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
        )

        # Add some mock pending saves
        mock_result1 = Mock(spec=AsyncResult)
        mock_result1.ready.return_value = False
        mock_result2 = Mock(spec=AsyncResult)
        mock_result2.ready.return_value = True

        repo._pending_saves["exec1"] = mock_result1
        repo._pending_saves["exec2"] = mock_result2

        count = repo.get_pending_save_count()

        # Completed tasks are cleaned up; only in-progress saves are counted
        assert count == 1
        assert "exec1" in repo._pending_saves
        assert "exec2" not in repo._pending_saves

    def test_clear_pending_saves(self, mock_session_factory, mock_account):
        """Test clearing all pending save operations."""
        repo = CeleryWorkflowNodeExecutionRepository(
            session_factory=mock_session_factory,
            user=mock_account,
            app_id="test-app",
            triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
        )

        # Add some mock pending saves
        repo._pending_saves["exec1"] = Mock(spec=AsyncResult)
        repo._pending_saves["exec2"] = Mock(spec=AsyncResult)

        repo.clear_pending_saves()

        assert len(repo._pending_saves) == 0