refactor(workflow_cycle_manager): Improve readability

Signed-off-by: -LAN- <laipz8200@outlook.com>
pull/22597/head
-LAN- 10 months ago
parent c6dde2f5a3
commit 640c0625d1
No known key found for this signature in database
GPG Key ID: 6BA0D108DED011FF

@ -61,23 +61,9 @@ class WorkflowCycleManager:
self._node_execution_cache: dict[str, WorkflowNodeExecution] = {}
def handle_workflow_run_start(self) -> WorkflowExecution:
inputs = {**self._application_generate_entity.inputs}
inputs = self._prepare_workflow_inputs()
execution_id = self._get_or_generate_execution_id()
# Iterate over SystemVariable fields using Pydantic's model_fields
if self._workflow_system_variables:
for field_name, value in self._workflow_system_variables.to_dict().items():
if field_name == SystemVariableKey.CONVERSATION_ID:
continue
inputs[f"sys.{field_name}"] = value
# handle special values
inputs = dict(WorkflowEntry.handle_special_values(inputs) or {})
# init workflow run
# TODO: This workflow_run_id should always not be None, maybe we can use a more elegant way to handle this
execution_id = str(
self._workflow_system_variables.workflow_execution_id if self._workflow_system_variables else None
) or str(uuid4())
execution = WorkflowExecution.new(
id_=execution_id,
workflow_id=self._workflow_info.workflow_id,
@ -88,12 +74,7 @@ class WorkflowCycleManager:
started_at=datetime.now(UTC).replace(tzinfo=None),
)
self._workflow_execution_repository.save(execution)
# Cache the execution
self._workflow_execution_cache[execution.id_] = execution
return execution
return self._save_and_cache_workflow_execution(execution)
def handle_workflow_run_success(
self,
@ -107,23 +88,15 @@ class WorkflowCycleManager:
) -> WorkflowExecution:
workflow_execution = self._get_workflow_execution_or_raise_error(workflow_run_id)
# outputs = WorkflowEntry.handle_special_values(outputs)
workflow_execution.status = WorkflowExecutionStatus.SUCCEEDED
workflow_execution.outputs = outputs or {}
workflow_execution.total_tokens = total_tokens
workflow_execution.total_steps = total_steps
workflow_execution.finished_at = datetime.now(UTC).replace(tzinfo=None)
self._update_workflow_execution_completion(
workflow_execution,
status=WorkflowExecutionStatus.SUCCEEDED,
outputs=outputs,
total_tokens=total_tokens,
total_steps=total_steps,
)
if trace_manager:
trace_manager.add_trace_task(
TraceTask(
TraceTaskName.WORKFLOW_TRACE,
workflow_execution=workflow_execution,
conversation_id=conversation_id,
user_id=trace_manager.user_id,
)
)
self._add_trace_task_if_needed(trace_manager, workflow_execution, conversation_id)
self._workflow_execution_repository.save(workflow_execution)
return workflow_execution
@ -140,24 +113,17 @@ class WorkflowCycleManager:
trace_manager: Optional[TraceQueueManager] = None,
) -> WorkflowExecution:
execution = self._get_workflow_execution_or_raise_error(workflow_run_id)
# outputs = WorkflowEntry.handle_special_values(dict(outputs) if outputs else None)
execution.status = WorkflowExecutionStatus.PARTIAL_SUCCEEDED
execution.outputs = outputs or {}
execution.total_tokens = total_tokens
execution.total_steps = total_steps
execution.finished_at = datetime.now(UTC).replace(tzinfo=None)
execution.exceptions_count = exceptions_count
self._update_workflow_execution_completion(
execution,
status=WorkflowExecutionStatus.PARTIAL_SUCCEEDED,
outputs=outputs,
total_tokens=total_tokens,
total_steps=total_steps,
exceptions_count=exceptions_count,
)
if trace_manager:
trace_manager.add_trace_task(
TraceTask(
TraceTaskName.WORKFLOW_TRACE,
workflow_execution=execution,
conversation_id=conversation_id,
user_id=trace_manager.user_id,
)
)
self._add_trace_task_if_needed(trace_manager, execution, conversation_id)
self._workflow_execution_repository.save(execution)
return execution
@ -177,42 +143,18 @@ class WorkflowCycleManager:
workflow_execution = self._get_workflow_execution_or_raise_error(workflow_run_id)
now = naive_utc_now()
workflow_execution.status = WorkflowExecutionStatus(status.value)
workflow_execution.error_message = error_message
workflow_execution.total_tokens = total_tokens
workflow_execution.total_steps = total_steps
workflow_execution.finished_at = now
workflow_execution.exceptions_count = exceptions_count
# First check cached node executions for running status
running_node_executions = [
node_exec
for node_exec in self._node_execution_cache.values()
if node_exec.workflow_execution_id == workflow_execution.id_
and node_exec.status == WorkflowNodeExecutionStatus.RUNNING
]
# Update the domain models
for node_execution in running_node_executions:
if node_execution.node_execution_id:
# Update the domain model
node_execution.status = WorkflowNodeExecutionStatus.FAILED
node_execution.error = error_message
node_execution.finished_at = now
node_execution.elapsed_time = (now - node_execution.created_at).total_seconds()
# Update the repository with the domain model
self._workflow_node_execution_repository.save(node_execution)
self._update_workflow_execution_completion(
workflow_execution,
status=status,
total_tokens=total_tokens,
total_steps=total_steps,
error_message=error_message,
exceptions_count=exceptions_count,
finished_at=now,
)
if trace_manager:
trace_manager.add_trace_task(
TraceTask(
TraceTaskName.WORKFLOW_TRACE,
workflow_execution=workflow_execution,
conversation_id=conversation_id,
user_id=trace_manager.user_id,
)
)
self._fail_running_node_executions(workflow_execution.id_, error_message, now)
self._add_trace_task_if_needed(trace_manager, workflow_execution, conversation_id)
self._workflow_execution_repository.save(workflow_execution)
return workflow_execution
@ -225,70 +167,24 @@ class WorkflowCycleManager:
) -> WorkflowNodeExecution:
workflow_execution = self._get_workflow_execution_or_raise_error(workflow_execution_id)
# Create a domain model
created_at = datetime.now(UTC).replace(tzinfo=None)
metadata = {
WorkflowNodeExecutionMetadataKey.PARALLEL_MODE_RUN_ID: event.parallel_mode_run_id,
WorkflowNodeExecutionMetadataKey.ITERATION_ID: event.in_iteration_id,
WorkflowNodeExecutionMetadataKey.LOOP_ID: event.in_loop_id,
}
domain_execution = WorkflowNodeExecution(
id=str(uuid4()),
workflow_id=workflow_execution.workflow_id,
workflow_execution_id=workflow_execution.id_,
predecessor_node_id=event.predecessor_node_id,
index=event.node_run_index,
node_execution_id=event.node_execution_id,
node_id=event.node_id,
node_type=event.node_type,
title=event.node_data.title,
domain_execution = self._create_node_execution_from_event(
workflow_execution=workflow_execution,
event=event,
status=WorkflowNodeExecutionStatus.RUNNING,
metadata=metadata,
created_at=created_at,
)
# Use the instance repository to save the domain model
self._workflow_node_execution_repository.save(domain_execution)
# Cache the node execution
if domain_execution.node_execution_id:
self._node_execution_cache[domain_execution.node_execution_id] = domain_execution
return domain_execution
return self._save_and_cache_node_execution(domain_execution)
def handle_workflow_node_execution_success(self, *, event: QueueNodeSucceededEvent) -> WorkflowNodeExecution:
# Check cache first
domain_execution = self._node_execution_cache.get(event.node_execution_id)
if not domain_execution:
raise ValueError(f"Domain node execution not found: {event.node_execution_id}")
# Process data
inputs = event.inputs
process_data = event.process_data
outputs = event.outputs
# Convert metadata keys to strings
execution_metadata_dict = {}
if event.execution_metadata:
for key, value in event.execution_metadata.items():
execution_metadata_dict[key] = value
finished_at = datetime.now(UTC).replace(tzinfo=None)
elapsed_time = (finished_at - event.start_at).total_seconds()
domain_execution = self._get_node_execution_from_cache(event.node_execution_id)
# Update domain model
domain_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED
domain_execution.update_from_mapping(
inputs=inputs, process_data=process_data, outputs=outputs, metadata=execution_metadata_dict
self._update_node_execution_completion(
domain_execution,
event=event,
status=WorkflowNodeExecutionStatus.SUCCEEDED,
)
domain_execution.finished_at = finished_at
domain_execution.elapsed_time = elapsed_time
# Update the repository with the domain model
self._workflow_node_execution_repository.save(domain_execution)
return domain_execution
def handle_workflow_node_execution_failed(
@ -304,102 +200,251 @@ class WorkflowCycleManager:
:param event: queue node failed event
:return:
"""
# Check cache first
domain_execution = self._node_execution_cache.get(event.node_execution_id)
if not domain_execution:
raise ValueError(f"Domain node execution not found: {event.node_execution_id}")
# Process data
inputs = WorkflowEntry.handle_special_values(event.inputs)
process_data = WorkflowEntry.handle_special_values(event.process_data)
outputs = event.outputs
# Convert metadata keys to strings
execution_metadata_dict = {}
if event.execution_metadata:
for key, value in event.execution_metadata.items():
execution_metadata_dict[key] = value
finished_at = datetime.now(UTC).replace(tzinfo=None)
elapsed_time = (finished_at - event.start_at).total_seconds()
domain_execution = self._get_node_execution_from_cache(event.node_execution_id)
# Update domain model
domain_execution.status = (
WorkflowNodeExecutionStatus.FAILED
if not isinstance(event, QueueNodeExceptionEvent)
else WorkflowNodeExecutionStatus.EXCEPTION
status = (
WorkflowNodeExecutionStatus.EXCEPTION
if isinstance(event, QueueNodeExceptionEvent)
else WorkflowNodeExecutionStatus.FAILED
)
domain_execution.error = event.error
domain_execution.update_from_mapping(
inputs=inputs, process_data=process_data, outputs=outputs, metadata=execution_metadata_dict
self._update_node_execution_completion(
domain_execution,
event=event,
status=status,
error=event.error,
handle_special_values=True,
)
domain_execution.finished_at = finished_at
domain_execution.elapsed_time = elapsed_time
# Update the repository with the domain model
self._workflow_node_execution_repository.save(domain_execution)
return domain_execution
def handle_workflow_node_execution_retried(
self, *, workflow_execution_id: str, event: QueueNodeRetryEvent
) -> WorkflowNodeExecution:
workflow_execution = self._get_workflow_execution_or_raise_error(workflow_execution_id)
created_at = event.start_at
finished_at = datetime.now(UTC).replace(tzinfo=None)
elapsed_time = (finished_at - created_at).total_seconds()
domain_execution = self._create_node_execution_from_event(
workflow_execution=workflow_execution,
event=event,
status=WorkflowNodeExecutionStatus.RETRY,
error=event.error,
created_at=event.start_at,
)
# Handle inputs and outputs
inputs = WorkflowEntry.handle_special_values(event.inputs)
outputs = event.outputs
metadata = self._merge_event_metadata(event)
# Convert metadata keys to strings
origin_metadata = {
WorkflowNodeExecutionMetadataKey.ITERATION_ID: event.in_iteration_id,
domain_execution.update_from_mapping(inputs=inputs, outputs=outputs, metadata=metadata)
return self._save_and_cache_node_execution(domain_execution)
def _get_workflow_execution_or_raise_error(self, id: str, /) -> WorkflowExecution:
# Check cache first
if id in self._workflow_execution_cache:
return self._workflow_execution_cache[id]
raise WorkflowRunNotFoundError(id)
def _prepare_workflow_inputs(self) -> dict[str, Any]:
    """
    Build the input mapping recorded for a workflow run.

    Starts from a copy of the application inputs, then adds every system
    variable under the key ``sys.<field_name>``. The field equal to
    ``SystemVariableKey.CONVERSATION_ID`` is left out. The merged mapping is
    passed through ``WorkflowEntry.handle_special_values`` before returning.

    :return: a new dict; empty when ``handle_special_values`` returns a falsy value
    """
    merged = dict(self._application_generate_entity.inputs)

    system_variables = self._workflow_system_variables
    if system_variables:
        merged.update(
            (f"sys.{name}", val)
            for name, val in system_variables.to_dict().items()
            if name != SystemVariableKey.CONVERSATION_ID
        )

    normalized = WorkflowEntry.handle_special_values(merged)
    return dict(normalized or {})
def _get_or_generate_execution_id(self) -> str:
"""Get execution ID from system variables or generate a new one."""
if self._workflow_system_variables and self._workflow_system_variables.workflow_execution_id:
return str(self._workflow_system_variables.workflow_execution_id)
return str(uuid4())
def _save_and_cache_workflow_execution(self, execution: WorkflowExecution) -> WorkflowExecution:
"""Save workflow execution to repository and cache it."""
self._workflow_execution_repository.save(execution)
self._workflow_execution_cache[execution.id_] = execution
return execution
def _save_and_cache_node_execution(self, execution: WorkflowNodeExecution) -> WorkflowNodeExecution:
"""Save node execution to repository and cache it if it has an ID."""
self._workflow_node_execution_repository.save(execution)
if execution.node_execution_id:
self._node_execution_cache[execution.node_execution_id] = execution
return execution
def _get_node_execution_from_cache(self, node_execution_id: str) -> WorkflowNodeExecution:
"""Get node execution from cache or raise error if not found."""
domain_execution = self._node_execution_cache.get(node_execution_id)
if not domain_execution:
raise ValueError(f"Domain node execution not found: {node_execution_id}")
return domain_execution
def _update_workflow_execution_completion(
self,
execution: WorkflowExecution,
*,
status: WorkflowExecutionStatus,
total_tokens: int,
total_steps: int,
outputs: Mapping[str, Any] | None = None,
error_message: Optional[str] = None,
exceptions_count: int = 0,
finished_at: Optional[datetime] = None,
) -> None:
"""Update workflow execution with completion data."""
execution.status = status
execution.outputs = outputs or {}
execution.total_tokens = total_tokens
execution.total_steps = total_steps
execution.finished_at = finished_at or naive_utc_now()
execution.exceptions_count = exceptions_count
if error_message:
execution.error_message = error_message
def _add_trace_task_if_needed(
self,
trace_manager: Optional[TraceQueueManager],
workflow_execution: WorkflowExecution,
conversation_id: Optional[str],
) -> None:
"""Add trace task if trace manager is provided."""
if trace_manager:
trace_manager.add_trace_task(
TraceTask(
TraceTaskName.WORKFLOW_TRACE,
workflow_execution=workflow_execution,
conversation_id=conversation_id,
user_id=trace_manager.user_id,
)
)
def _fail_running_node_executions(
self,
workflow_execution_id: str,
error_message: str,
now: datetime,
) -> None:
"""Fail all running node executions for a workflow."""
running_node_executions = [
node_exec
for node_exec in self._node_execution_cache.values()
if node_exec.workflow_execution_id == workflow_execution_id
and node_exec.status == WorkflowNodeExecutionStatus.RUNNING
]
for node_execution in running_node_executions:
if node_execution.node_execution_id:
node_execution.status = WorkflowNodeExecutionStatus.FAILED
node_execution.error = error_message
node_execution.finished_at = now
node_execution.elapsed_time = (now - node_execution.created_at).total_seconds()
self._workflow_node_execution_repository.save(node_execution)
def _create_node_execution_from_event(
self,
*,
workflow_execution: WorkflowExecution,
event: Union[QueueNodeStartedEvent, QueueNodeRetryEvent],
status: WorkflowNodeExecutionStatus,
error: Optional[str] = None,
created_at: Optional[datetime] = None,
) -> WorkflowNodeExecution:
"""Create a node execution from an event."""
now = datetime.now(UTC).replace(tzinfo=None)
created_at = created_at or now
metadata = {
WorkflowNodeExecutionMetadataKey.PARALLEL_MODE_RUN_ID: event.parallel_mode_run_id,
WorkflowNodeExecutionMetadataKey.ITERATION_ID: event.in_iteration_id,
WorkflowNodeExecutionMetadataKey.LOOP_ID: event.in_loop_id,
}
# Convert execution metadata keys to strings
execution_metadata_dict: dict[WorkflowNodeExecutionMetadataKey, str | None] = {}
if event.execution_metadata:
for key, value in event.execution_metadata.items():
execution_metadata_dict[key] = value
merged_metadata = {**execution_metadata_dict, **origin_metadata} if execution_metadata_dict else origin_metadata
# Create a domain model
domain_execution = WorkflowNodeExecution(
id=str(uuid4()),
workflow_id=workflow_execution.workflow_id,
workflow_execution_id=workflow_execution.id_,
predecessor_node_id=event.predecessor_node_id,
index=event.node_run_index,
node_execution_id=event.node_execution_id,
node_id=event.node_id,
node_type=event.node_type,
title=event.node_data.title,
status=WorkflowNodeExecutionStatus.RETRY,
status=status,
metadata=metadata,
created_at=created_at,
finished_at=finished_at,
elapsed_time=elapsed_time,
error=event.error,
index=event.node_run_index,
error=error,
)
# Update with mappings
domain_execution.update_from_mapping(inputs=inputs, outputs=outputs, metadata=merged_metadata)
if status == WorkflowNodeExecutionStatus.RETRY:
domain_execution.finished_at = now
domain_execution.elapsed_time = (now - created_at).total_seconds()
# Use the instance repository to save the domain model
self._workflow_node_execution_repository.save(domain_execution)
return domain_execution
# Cache the node execution
if domain_execution.node_execution_id:
self._node_execution_cache[domain_execution.node_execution_id] = domain_execution
def _update_node_execution_completion(
self,
domain_execution: WorkflowNodeExecution,
*,
event: Union[
QueueNodeSucceededEvent,
QueueNodeFailedEvent,
QueueNodeInIterationFailedEvent,
QueueNodeInLoopFailedEvent,
QueueNodeExceptionEvent,
],
status: WorkflowNodeExecutionStatus,
error: Optional[str] = None,
handle_special_values: bool = False,
) -> None:
"""Update node execution with completion data."""
finished_at = datetime.now(UTC).replace(tzinfo=None)
elapsed_time = (finished_at - event.start_at).total_seconds()
return domain_execution
# Process data
if handle_special_values:
inputs = WorkflowEntry.handle_special_values(event.inputs)
process_data = WorkflowEntry.handle_special_values(event.process_data)
else:
inputs = event.inputs
process_data = event.process_data
def _get_workflow_execution_or_raise_error(self, id: str, /) -> WorkflowExecution:
# Check cache first
if id in self._workflow_execution_cache:
return self._workflow_execution_cache[id]
outputs = event.outputs
raise WorkflowRunNotFoundError(id)
# Convert metadata
execution_metadata_dict: dict[WorkflowNodeExecutionMetadataKey, Any] = {}
if event.execution_metadata:
execution_metadata_dict.update(event.execution_metadata)
# Update domain model
domain_execution.status = status
domain_execution.update_from_mapping(
inputs=inputs,
process_data=process_data,
outputs=outputs,
metadata=execution_metadata_dict,
)
domain_execution.finished_at = finished_at
domain_execution.elapsed_time = elapsed_time
if error:
domain_execution.error = error
def _merge_event_metadata(self, event: QueueNodeRetryEvent) -> dict[WorkflowNodeExecutionMetadataKey, str | None]:
    """
    Combine a retry event's execution metadata with its run-context IDs.

    The run-context entries (iteration ID, parallel-mode run ID, loop ID) are
    always present in the result. On a key clash they take precedence over
    the values in ``event.execution_metadata``.

    :param event: retry event supplying the metadata and the context IDs
    :return: the merged metadata mapping
    """
    run_context = {
        WorkflowNodeExecutionMetadataKey.ITERATION_ID: event.in_iteration_id,
        WorkflowNodeExecutionMetadataKey.PARALLEL_MODE_RUN_ID: event.parallel_mode_run_id,
        WorkflowNodeExecutionMetadataKey.LOOP_ID: event.in_loop_id,
    }

    event_metadata: dict[WorkflowNodeExecutionMetadataKey, str | None] = {}
    if event.execution_metadata:
        event_metadata.update(event.execution_metadata)
    if not event_metadata:
        return run_context

    # Context IDs are unpacked last so they override clashing event metadata keys.
    return {**event_metadata, **run_context}

@ -80,15 +80,12 @@ def real_workflow_system_variables():
@pytest.fixture
def mock_node_execution_repository():
repo = MagicMock(spec=WorkflowNodeExecutionRepository)
repo.get_by_node_execution_id.return_value = None
repo.get_running_executions.return_value = []
return repo
@pytest.fixture
def mock_workflow_execution_repository():
repo = MagicMock(spec=WorkflowExecutionRepository)
repo.get.return_value = None
return repo
@ -217,8 +214,8 @@ def test_handle_workflow_run_success(workflow_cycle_manager, mock_workflow_execu
started_at=datetime.now(UTC).replace(tzinfo=None),
)
# Mock _get_workflow_execution_or_raise_error to return the real workflow_execution
workflow_cycle_manager._workflow_execution_repository.get.return_value = workflow_execution
# Pre-populate the cache with the workflow execution
workflow_cycle_manager._workflow_execution_cache[workflow_execution.id_] = workflow_execution
# Call the method
result = workflow_cycle_manager.handle_workflow_run_success(
@ -251,11 +248,10 @@ def test_handle_workflow_run_failed(workflow_cycle_manager, mock_workflow_execut
started_at=datetime.now(UTC).replace(tzinfo=None),
)
# Mock _get_workflow_execution_or_raise_error to return the real workflow_execution
workflow_cycle_manager._workflow_execution_repository.get.return_value = workflow_execution
# Pre-populate the cache with the workflow execution
workflow_cycle_manager._workflow_execution_cache[workflow_execution.id_] = workflow_execution
# Mock get_running_executions to return an empty list
workflow_cycle_manager._workflow_node_execution_repository.get_running_executions.return_value = []
# No running node executions in cache (empty cache)
# Call the method
result = workflow_cycle_manager.handle_workflow_run_failed(
@ -289,8 +285,8 @@ def test_handle_node_execution_start(workflow_cycle_manager, mock_workflow_execu
started_at=datetime.now(UTC).replace(tzinfo=None),
)
# Mock _get_workflow_execution_or_raise_error to return the real workflow_execution
workflow_cycle_manager._workflow_execution_repository.get.return_value = workflow_execution
# Pre-populate the cache with the workflow execution
workflow_cycle_manager._workflow_execution_cache[workflow_execution.id_] = workflow_execution
# Create a mock event
event = MagicMock(spec=QueueNodeStartedEvent)
@ -342,8 +338,8 @@ def test_get_workflow_execution_or_raise_error(workflow_cycle_manager, mock_work
started_at=datetime.now(UTC).replace(tzinfo=None),
)
# Mock the repository get method to return the real execution
workflow_cycle_manager._workflow_execution_repository.get.return_value = workflow_execution
# Pre-populate the cache with the workflow execution
workflow_cycle_manager._workflow_execution_cache["test-workflow-run-id"] = workflow_execution
# Call the method
result = workflow_cycle_manager._get_workflow_execution_or_raise_error("test-workflow-run-id")
@ -351,11 +347,13 @@ def test_get_workflow_execution_or_raise_error(workflow_cycle_manager, mock_work
# Verify the result
assert result == workflow_execution
# Test error case
workflow_cycle_manager._workflow_execution_repository.get.return_value = None
# Test error case - clear cache
workflow_cycle_manager._workflow_execution_cache.clear()
# Expect an error when execution is not found
with pytest.raises(ValueError):
from core.app.task_pipeline.exc import WorkflowRunNotFoundError
with pytest.raises(WorkflowRunNotFoundError):
workflow_cycle_manager._get_workflow_execution_or_raise_error("non-existent-id")
@ -384,8 +382,8 @@ def test_handle_workflow_node_execution_success(workflow_cycle_manager):
created_at=datetime.now(UTC).replace(tzinfo=None),
)
# Mock the repository to return the node execution
workflow_cycle_manager._workflow_node_execution_repository.get_by_node_execution_id.return_value = node_execution
# Pre-populate the cache with the node execution
workflow_cycle_manager._node_execution_cache["test-node-execution-id"] = node_execution
# Call the method
result = workflow_cycle_manager.handle_workflow_node_execution_success(
@ -414,8 +412,8 @@ def test_handle_workflow_run_partial_success(workflow_cycle_manager, mock_workfl
started_at=datetime.now(UTC).replace(tzinfo=None),
)
# Mock _get_workflow_execution_or_raise_error to return the real workflow_execution
workflow_cycle_manager._workflow_execution_repository.get.return_value = workflow_execution
# Pre-populate the cache with the workflow execution
workflow_cycle_manager._workflow_execution_cache[workflow_execution.id_] = workflow_execution
# Call the method
result = workflow_cycle_manager.handle_workflow_run_partial_success(
@ -462,8 +460,8 @@ def test_handle_workflow_node_execution_failed(workflow_cycle_manager):
created_at=datetime.now(UTC).replace(tzinfo=None),
)
# Mock the repository to return the node execution
workflow_cycle_manager._workflow_node_execution_repository.get_by_node_execution_id.return_value = node_execution
# Pre-populate the cache with the node execution
workflow_cycle_manager._node_execution_cache["test-node-execution-id"] = node_execution
# Call the method
result = workflow_cycle_manager.handle_workflow_node_execution_failed(

Loading…
Cancel
Save