diff --git a/api/core/repositories/sqlalchemy_workflow_execution_repository.py b/api/core/repositories/sqlalchemy_workflow_execution_repository.py index 0b3e5eb424..c579ff4028 100644 --- a/api/core/repositories/sqlalchemy_workflow_execution_repository.py +++ b/api/core/repositories/sqlalchemy_workflow_execution_repository.py @@ -6,7 +6,6 @@ import json import logging from typing import Optional, Union -from sqlalchemy import select from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker @@ -206,44 +205,3 @@ class SQLAlchemyWorkflowExecutionRepository(WorkflowExecutionRepository): # Update the in-memory cache for faster subsequent lookups logger.debug(f"Updating cache for execution_id: {db_model.id}") self._execution_cache[db_model.id] = db_model - - def get(self, execution_id: str) -> Optional[WorkflowExecution]: - """ - Retrieve a WorkflowExecution by its ID. - - First checks the in-memory cache, and if not found, queries the database. - If found in the database, adds it to the cache for future lookups. - - Args: - execution_id: The workflow execution ID - - Returns: - The WorkflowExecution instance if found, None otherwise - """ - # First check the cache - if execution_id in self._execution_cache: - logger.debug(f"Cache hit for execution_id: {execution_id}") - # Convert cached DB model to domain model - cached_db_model = self._execution_cache[execution_id] - return self._to_domain_model(cached_db_model) - - # If not in cache, query the database - logger.debug(f"Cache miss for execution_id: {execution_id}, querying database") - with self._session_factory() as session: - stmt = select(WorkflowRun).where( - WorkflowRun.id == execution_id, - WorkflowRun.tenant_id == self._tenant_id, - ) - - if self._app_id: - stmt = stmt.where(WorkflowRun.app_id == self._app_id) - - db_model = session.scalar(stmt) - if db_model: - # Add DB model to cache - self._execution_cache[execution_id] = db_model - - # Convert to domain model and return - return self._to_domain_model(db_model) - - return None diff --git a/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py b/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py index a5feeb0d7c..d4a31390f8 100644 --- a/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py +++ b/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py @@ -7,7 +7,7 @@ import logging from collections.abc import Sequence from typing import Optional, Union -from sqlalchemy import UnaryExpression, asc, delete, desc, select +from sqlalchemy import UnaryExpression, asc, desc, select from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker @@ -218,47 +218,6 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository) logger.debug(f"Updating cache for node_execution_id: {db_model.node_execution_id}") self._node_execution_cache[db_model.node_execution_id] = db_model - def get_by_node_execution_id(self, node_execution_id: str) -> Optional[WorkflowNodeExecution]: - """ - Retrieve a NodeExecution by its node_execution_id. - - First checks the in-memory cache, and if not found, queries the database. - If found in the database, adds it to the cache for future lookups. - - Args: - node_execution_id: The node execution ID - - Returns: - The NodeExecution instance if found, None otherwise - """ - # First check the cache - if node_execution_id in self._node_execution_cache: - logger.debug(f"Cache hit for node_execution_id: {node_execution_id}") - # Convert cached DB model to domain model - cached_db_model = self._node_execution_cache[node_execution_id] - return self._to_domain_model(cached_db_model) - - # If not in cache, query the database - logger.debug(f"Cache miss for node_execution_id: {node_execution_id}, querying database") - with self._session_factory() as session: - stmt = select(WorkflowNodeExecutionModel).where( - WorkflowNodeExecutionModel.node_execution_id == node_execution_id, - WorkflowNodeExecutionModel.tenant_id == self._tenant_id, - ) - - if self._app_id: - stmt = stmt.where(WorkflowNodeExecutionModel.app_id == self._app_id) - - db_model = session.scalar(stmt) - if db_model: - # Add DB model to cache - self._node_execution_cache[node_execution_id] = db_model - - # Convert to domain model and return - return self._to_domain_model(db_model) - - return None - def get_db_models_by_workflow_run( self, workflow_run_id: str, @@ -344,68 +303,3 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository) domain_models.append(domain_model) return domain_models - - def get_running_executions(self, workflow_run_id: str) -> Sequence[WorkflowNodeExecution]: - """ - Retrieve all running NodeExecution instances for a specific workflow run. - - This method queries the database directly and updates the cache with any - retrieved executions that have a node_execution_id. - - Args: - workflow_run_id: The workflow run ID - - Returns: - A list of running NodeExecution instances - """ - with self._session_factory() as session: - stmt = select(WorkflowNodeExecutionModel).where( - WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id, - WorkflowNodeExecutionModel.tenant_id == self._tenant_id, - WorkflowNodeExecutionModel.status == WorkflowNodeExecutionStatus.RUNNING, - WorkflowNodeExecutionModel.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, - ) - - if self._app_id: - stmt = stmt.where(WorkflowNodeExecutionModel.app_id == self._app_id) - - db_models = session.scalars(stmt).all() - domain_models = [] - - for model in db_models: - # Update cache if node_execution_id is present - if model.node_execution_id: - self._node_execution_cache[model.node_execution_id] = model - - # Convert to domain model - domain_model = self._to_domain_model(model) - domain_models.append(domain_model) - - return domain_models - - def clear(self) -> None: - """ - Clear all WorkflowNodeExecution records for the current tenant_id and app_id. - - This method deletes all WorkflowNodeExecution records that match the tenant_id - and app_id (if provided) associated with this repository instance. - It also clears the in-memory cache. - """ - with self._session_factory() as session: - stmt = delete(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.tenant_id == self._tenant_id) - - if self._app_id: - stmt = stmt.where(WorkflowNodeExecutionModel.app_id == self._app_id) - - result = session.execute(stmt) - session.commit() - - deleted_count = result.rowcount - logger.info( - f"Cleared {deleted_count} workflow node execution records for tenant {self._tenant_id}" - + (f" and app {self._app_id}" if self._app_id else "") - ) - - # Clear the in-memory cache - self._node_execution_cache.clear() - logger.info("Cleared in-memory node execution cache") diff --git a/api/core/workflow/repositories/workflow_execution_repository.py b/api/core/workflow/repositories/workflow_execution_repository.py index 5917310c8b..bcbd253392 100644 --- a/api/core/workflow/repositories/workflow_execution_repository.py +++ b/api/core/workflow/repositories/workflow_execution_repository.py @@ -1,4 +1,4 @@ -from typing import Optional, Protocol +from typing import Protocol from core.workflow.entities.workflow_execution import WorkflowExecution @@ -28,15 +28,3 @@ class WorkflowExecutionRepository(Protocol): execution: The WorkflowExecution instance to save or update """ ... - - def get(self, execution_id: str) -> Optional[WorkflowExecution]: - """ - Retrieve a WorkflowExecution by its ID. - - Args: - execution_id: The workflow execution ID - - Returns: - The WorkflowExecution instance if found, None otherwise - """ - ... diff --git a/api/core/workflow/repositories/workflow_node_execution_repository.py b/api/core/workflow/repositories/workflow_node_execution_repository.py index 1908a6b190..8bf81f5442 100644 --- a/api/core/workflow/repositories/workflow_node_execution_repository.py +++ b/api/core/workflow/repositories/workflow_node_execution_repository.py @@ -39,18 +39,6 @@ class WorkflowNodeExecutionRepository(Protocol): """ ... - def get_by_node_execution_id(self, node_execution_id: str) -> Optional[WorkflowNodeExecution]: - """ - Retrieve a NodeExecution by its node_execution_id. - - Args: - node_execution_id: The node execution ID - - Returns: - The NodeExecution instance if found, None otherwise - """ - ... - def get_by_workflow_run( self, workflow_run_id: str, @@ -69,24 +57,3 @@ class WorkflowNodeExecutionRepository(Protocol): A list of NodeExecution instances """ ... - - def get_running_executions(self, workflow_run_id: str) -> Sequence[WorkflowNodeExecution]: - """ - Retrieve all running NodeExecution instances for a specific workflow run. - - Args: - workflow_run_id: The workflow run ID - - Returns: - A list of running NodeExecution instances - """ - ... - - def clear(self) -> None: - """ - Clear all NodeExecution records based on implementation-specific criteria. - - This method is intended to be used for bulk deletion operations, such as removing - all records associated with a specific app_id and tenant_id in multi-tenant implementations. - """ - ... diff --git a/api/core/workflow/workflow_cycle_manager.py b/api/core/workflow/workflow_cycle_manager.py index 50ff733979..26cbd3fed6 100644 --- a/api/core/workflow/workflow_cycle_manager.py +++ b/api/core/workflow/workflow_cycle_manager.py @@ -55,6 +55,11 @@ class WorkflowCycleManager: self._workflow_execution_repository = workflow_execution_repository self._workflow_node_execution_repository = workflow_node_execution_repository + # Initialize caches for workflow execution cycle + # These caches avoid redundant repository calls during a single workflow execution + self._workflow_execution_cache: dict[str, WorkflowExecution] = {} + self._node_execution_cache: dict[str, WorkflowNodeExecution] = {} + def handle_workflow_run_start(self) -> WorkflowExecution: inputs = {**self._application_generate_entity.inputs} @@ -85,6 +90,9 @@ class WorkflowCycleManager: self._workflow_execution_repository.save(execution) + # Cache the execution + self._workflow_execution_cache[execution.id_] = execution + return execution def handle_workflow_run_success( @@ -176,10 +184,13 @@ class WorkflowCycleManager: workflow_execution.finished_at = now workflow_execution.exceptions_count = exceptions_count - # Use the instance repository to find running executions for a workflow run - running_node_executions = self._workflow_node_execution_repository.get_running_executions( - workflow_run_id=workflow_execution.id_ - ) + # First check cached node executions for running status + running_node_executions = [ + node_exec + for node_exec in self._node_execution_cache.values() + if node_exec.workflow_execution_id == workflow_execution.id_ + and node_exec.status == WorkflowNodeExecutionStatus.RUNNING + ] # Update the domain models for node_execution in running_node_executions: @@ -240,11 +251,16 @@ class WorkflowCycleManager: # Use the instance repository to save the domain model self._workflow_node_execution_repository.save(domain_execution) + # Cache the node execution + if domain_execution.node_execution_id: + self._node_execution_cache[domain_execution.node_execution_id] = domain_execution + return domain_execution def handle_workflow_node_execution_success(self, *, event: QueueNodeSucceededEvent) -> WorkflowNodeExecution: - # Get the domain model from repository - domain_execution = self._workflow_node_execution_repository.get_by_node_execution_id(event.node_execution_id) + # Check cache first + domain_execution = self._node_execution_cache.get(event.node_execution_id) + if not domain_execution: raise ValueError(f"Domain node execution not found: {event.node_execution_id}") @@ -288,8 +304,9 @@ class WorkflowCycleManager: :param event: queue node failed event :return: """ - # Get the domain model from repository - domain_execution = self._workflow_node_execution_repository.get_by_node_execution_id(event.node_execution_id) + # Check cache first + domain_execution = self._node_execution_cache.get(event.node_execution_id) + if not domain_execution: raise ValueError(f"Domain node execution not found: {event.node_execution_id}") @@ -374,10 +391,15 @@ class WorkflowCycleManager: # Use the instance repository to save the domain model self._workflow_node_execution_repository.save(domain_execution) + # Cache the node execution + if domain_execution.node_execution_id: + self._node_execution_cache[domain_execution.node_execution_id] = domain_execution + return domain_execution def _get_workflow_execution_or_raise_error(self, id: str, /) -> WorkflowExecution: - execution = self._workflow_execution_repository.get(id) - if not execution: - raise WorkflowRunNotFoundError(id) - return execution + # Check cache first + if id in self._workflow_execution_cache: + return self._workflow_execution_cache[id] + + raise WorkflowRunNotFoundError(id)