feat(workflow_cycle_manager): Removes redundant repository methods and adds caching

Signed-off-by: -LAN- <laipz8200@outlook.com>
pull/22597/head
-LAN- 10 months ago
parent 62586719b3
commit c6dde2f5a3
No known key found for this signature in database
GPG Key ID: 6BA0D108DED011FF

@ -6,7 +6,6 @@ import json
import logging
from typing import Optional, Union
from sqlalchemy import select
from sqlalchemy.engine import Engine
from sqlalchemy.orm import sessionmaker
@ -206,44 +205,3 @@ class SQLAlchemyWorkflowExecutionRepository(WorkflowExecutionRepository):
# Update the in-memory cache for faster subsequent lookups
logger.debug(f"Updating cache for execution_id: {db_model.id}")
self._execution_cache[db_model.id] = db_model
def get(self, execution_id: str) -> Optional[WorkflowExecution]:
"""
Retrieve a WorkflowExecution by its ID.
First checks the in-memory cache, and if not found, queries the database.
If found in the database, adds it to the cache for future lookups.
Args:
execution_id: The workflow execution ID
Returns:
The WorkflowExecution instance if found, None otherwise
"""
# First check the cache
if execution_id in self._execution_cache:
logger.debug(f"Cache hit for execution_id: {execution_id}")
# Convert cached DB model to domain model
cached_db_model = self._execution_cache[execution_id]
return self._to_domain_model(cached_db_model)
# If not in cache, query the database
logger.debug(f"Cache miss for execution_id: {execution_id}, querying database")
with self._session_factory() as session:
stmt = select(WorkflowRun).where(
WorkflowRun.id == execution_id,
WorkflowRun.tenant_id == self._tenant_id,
)
if self._app_id:
stmt = stmt.where(WorkflowRun.app_id == self._app_id)
db_model = session.scalar(stmt)
if db_model:
# Add DB model to cache
self._execution_cache[execution_id] = db_model
# Convert to domain model and return
return self._to_domain_model(db_model)
return None

@ -7,7 +7,7 @@ import logging
from collections.abc import Sequence
from typing import Optional, Union
from sqlalchemy import UnaryExpression, asc, delete, desc, select
from sqlalchemy import UnaryExpression, asc, desc, select
from sqlalchemy.engine import Engine
from sqlalchemy.orm import sessionmaker
@ -218,47 +218,6 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
logger.debug(f"Updating cache for node_execution_id: {db_model.node_execution_id}")
self._node_execution_cache[db_model.node_execution_id] = db_model
def get_by_node_execution_id(self, node_execution_id: str) -> Optional[WorkflowNodeExecution]:
"""
Retrieve a NodeExecution by its node_execution_id.
First checks the in-memory cache, and if not found, queries the database.
If found in the database, adds it to the cache for future lookups.
Args:
node_execution_id: The node execution ID
Returns:
The NodeExecution instance if found, None otherwise
"""
# First check the cache
if node_execution_id in self._node_execution_cache:
logger.debug(f"Cache hit for node_execution_id: {node_execution_id}")
# Convert cached DB model to domain model
cached_db_model = self._node_execution_cache[node_execution_id]
return self._to_domain_model(cached_db_model)
# If not in cache, query the database
logger.debug(f"Cache miss for node_execution_id: {node_execution_id}, querying database")
with self._session_factory() as session:
stmt = select(WorkflowNodeExecutionModel).where(
WorkflowNodeExecutionModel.node_execution_id == node_execution_id,
WorkflowNodeExecutionModel.tenant_id == self._tenant_id,
)
if self._app_id:
stmt = stmt.where(WorkflowNodeExecutionModel.app_id == self._app_id)
db_model = session.scalar(stmt)
if db_model:
# Add DB model to cache
self._node_execution_cache[node_execution_id] = db_model
# Convert to domain model and return
return self._to_domain_model(db_model)
return None
def get_db_models_by_workflow_run(
self,
workflow_run_id: str,
@ -344,68 +303,3 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
domain_models.append(domain_model)
return domain_models
def get_running_executions(self, workflow_run_id: str) -> Sequence[WorkflowNodeExecution]:
"""
Retrieve all running NodeExecution instances for a specific workflow run.
This method queries the database directly and updates the cache with any
retrieved executions that have a node_execution_id.
Args:
workflow_run_id: The workflow run ID
Returns:
A list of running NodeExecution instances
"""
with self._session_factory() as session:
stmt = select(WorkflowNodeExecutionModel).where(
WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id,
WorkflowNodeExecutionModel.tenant_id == self._tenant_id,
WorkflowNodeExecutionModel.status == WorkflowNodeExecutionStatus.RUNNING,
WorkflowNodeExecutionModel.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
if self._app_id:
stmt = stmt.where(WorkflowNodeExecutionModel.app_id == self._app_id)
db_models = session.scalars(stmt).all()
domain_models = []
for model in db_models:
# Update cache if node_execution_id is present
if model.node_execution_id:
self._node_execution_cache[model.node_execution_id] = model
# Convert to domain model
domain_model = self._to_domain_model(model)
domain_models.append(domain_model)
return domain_models
def clear(self) -> None:
"""
Clear all WorkflowNodeExecution records for the current tenant_id and app_id.
This method deletes all WorkflowNodeExecution records that match the tenant_id
and app_id (if provided) associated with this repository instance.
It also clears the in-memory cache.
"""
with self._session_factory() as session:
stmt = delete(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.tenant_id == self._tenant_id)
if self._app_id:
stmt = stmt.where(WorkflowNodeExecutionModel.app_id == self._app_id)
result = session.execute(stmt)
session.commit()
deleted_count = result.rowcount
logger.info(
f"Cleared {deleted_count} workflow node execution records for tenant {self._tenant_id}"
+ (f" and app {self._app_id}" if self._app_id else "")
)
# Clear the in-memory cache
self._node_execution_cache.clear()
logger.info("Cleared in-memory node execution cache")

@ -1,4 +1,4 @@
from typing import Optional, Protocol
from typing import Protocol
from core.workflow.entities.workflow_execution import WorkflowExecution
@ -28,15 +28,3 @@ class WorkflowExecutionRepository(Protocol):
execution: The WorkflowExecution instance to save or update
"""
...
def get(self, execution_id: str) -> Optional[WorkflowExecution]:
"""
Retrieve a WorkflowExecution by its ID.
Args:
execution_id: The workflow execution ID
Returns:
The WorkflowExecution instance if found, None otherwise
"""
...

@ -39,18 +39,6 @@ class WorkflowNodeExecutionRepository(Protocol):
"""
...
def get_by_node_execution_id(self, node_execution_id: str) -> Optional[WorkflowNodeExecution]:
"""
Retrieve a NodeExecution by its node_execution_id.
Args:
node_execution_id: The node execution ID
Returns:
The NodeExecution instance if found, None otherwise
"""
...
def get_by_workflow_run(
self,
workflow_run_id: str,
@ -69,24 +57,3 @@ class WorkflowNodeExecutionRepository(Protocol):
A list of NodeExecution instances
"""
...
def get_running_executions(self, workflow_run_id: str) -> Sequence[WorkflowNodeExecution]:
"""
Retrieve all running NodeExecution instances for a specific workflow run.
Args:
workflow_run_id: The workflow run ID
Returns:
A list of running NodeExecution instances
"""
...
def clear(self) -> None:
"""
Clear all NodeExecution records based on implementation-specific criteria.
This method is intended to be used for bulk deletion operations, such as removing
all records associated with a specific app_id and tenant_id in multi-tenant implementations.
"""
...

@ -55,6 +55,11 @@ class WorkflowCycleManager:
self._workflow_execution_repository = workflow_execution_repository
self._workflow_node_execution_repository = workflow_node_execution_repository
# Initialize caches for workflow execution cycle
# These caches avoid redundant repository calls during a single workflow execution
self._workflow_execution_cache: dict[str, WorkflowExecution] = {}
self._node_execution_cache: dict[str, WorkflowNodeExecution] = {}
def handle_workflow_run_start(self) -> WorkflowExecution:
inputs = {**self._application_generate_entity.inputs}
@ -85,6 +90,9 @@ class WorkflowCycleManager:
self._workflow_execution_repository.save(execution)
# Cache the execution
self._workflow_execution_cache[execution.id_] = execution
return execution
def handle_workflow_run_success(
@ -176,10 +184,13 @@ class WorkflowCycleManager:
workflow_execution.finished_at = now
workflow_execution.exceptions_count = exceptions_count
# Use the instance repository to find running executions for a workflow run
running_node_executions = self._workflow_node_execution_repository.get_running_executions(
workflow_run_id=workflow_execution.id_
)
# First check cached node executions for running status
running_node_executions = [
node_exec
for node_exec in self._node_execution_cache.values()
if node_exec.workflow_execution_id == workflow_execution.id_
and node_exec.status == WorkflowNodeExecutionStatus.RUNNING
]
# Update the domain models
for node_execution in running_node_executions:
@ -240,11 +251,16 @@ class WorkflowCycleManager:
# Use the instance repository to save the domain model
self._workflow_node_execution_repository.save(domain_execution)
# Cache the node execution
if domain_execution.node_execution_id:
self._node_execution_cache[domain_execution.node_execution_id] = domain_execution
return domain_execution
def handle_workflow_node_execution_success(self, *, event: QueueNodeSucceededEvent) -> WorkflowNodeExecution:
# Get the domain model from repository
domain_execution = self._workflow_node_execution_repository.get_by_node_execution_id(event.node_execution_id)
# Check cache first
domain_execution = self._node_execution_cache.get(event.node_execution_id)
if not domain_execution:
raise ValueError(f"Domain node execution not found: {event.node_execution_id}")
@ -288,8 +304,9 @@ class WorkflowCycleManager:
:param event: queue node failed event
:return:
"""
# Get the domain model from repository
domain_execution = self._workflow_node_execution_repository.get_by_node_execution_id(event.node_execution_id)
# Check cache first
domain_execution = self._node_execution_cache.get(event.node_execution_id)
if not domain_execution:
raise ValueError(f"Domain node execution not found: {event.node_execution_id}")
@ -374,10 +391,15 @@ class WorkflowCycleManager:
# Use the instance repository to save the domain model
self._workflow_node_execution_repository.save(domain_execution)
# Cache the node execution
if domain_execution.node_execution_id:
self._node_execution_cache[domain_execution.node_execution_id] = domain_execution
return domain_execution
def _get_workflow_execution_or_raise_error(self, id: str, /) -> WorkflowExecution:
execution = self._workflow_execution_repository.get(id)
if not execution:
raise WorkflowRunNotFoundError(id)
return execution
# Check cache first
if id in self._workflow_execution_cache:
return self._workflow_execution_cache[id]
raise WorkflowRunNotFoundError(id)

Loading…
Cancel
Save