Refactor workflow cycle manager to use domain model instead of DB model

pull/19430/head
-LAN- 1 year ago
parent 4deba55453
commit 9d79fdfe51
No known key found for this signature in database
GPG Key ID: 6BA0D108DED011FF

@ -46,26 +46,29 @@ from core.app.entities.task_entities import (
)
from core.app.task_pipeline.exc import WorkflowRunNotFoundError
from core.file import FILE_MODEL_IDENTITY, File
from core.model_runtime.utils.encoders import jsonable_encoder
from core.ops.entities.trace_entity import TraceTaskName
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
from core.tools.tool_manager import ToolManager
from core.workflow.entities.node_entities import NodeRunMetadataKey
from core.workflow.entities.node_execution_entities import (
NodeExecution,
NodeExecutionStatus,
)
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes import NodeType
from core.workflow.nodes.tool.entities import ToolNodeData
from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.workflow_entry import WorkflowEntry
from models.account import Account
from models.enums import CreatedByRole, WorkflowRunTriggeredFrom
from models.model import EndUser
from models.workflow import (
from models import (
Account,
CreatedByRole,
EndUser,
Workflow,
WorkflowNodeExecution,
WorkflowNodeExecutionStatus,
WorkflowNodeExecutionTriggeredFrom,
WorkflowRun,
WorkflowRunStatus,
WorkflowRunTriggeredFrom,
)
@ -78,7 +81,6 @@ class WorkflowCycleManager:
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
) -> None:
self._workflow_run: WorkflowRun | None = None
self._workflow_node_executions: dict[str, WorkflowNodeExecution] = {}
self._application_generate_entity = application_generate_entity
self._workflow_system_variables = workflow_system_variables
self._workflow_node_execution_repository = workflow_node_execution_repository
@ -258,21 +260,22 @@ class WorkflowCycleManager:
workflow_run.exceptions_count = exceptions_count
# Use the instance repository to find running executions for a workflow run
running_workflow_node_executions = self._workflow_node_execution_repository.get_running_executions(
running_domain_executions = self._workflow_node_execution_repository.get_running_executions(
workflow_run_id=workflow_run.id
)
# Update the cache with the retrieved executions
for execution in running_workflow_node_executions:
if execution.node_execution_id:
self._workflow_node_executions[execution.node_execution_id] = execution
for workflow_node_execution in running_workflow_node_executions:
# Update the domain models
now = datetime.now(UTC).replace(tzinfo=None)
workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value
workflow_node_execution.error = error
workflow_node_execution.finished_at = now
workflow_node_execution.elapsed_time = (now - workflow_node_execution.created_at).total_seconds()
for domain_execution in running_domain_executions:
if domain_execution.node_execution_id:
# Update the domain model
domain_execution.status = NodeExecutionStatus.FAILED
domain_execution.error = error
domain_execution.finished_at = now
domain_execution.elapsed_time = (now - domain_execution.created_at).total_seconds()
# Update the repository with the domain model
self._workflow_node_execution_repository.save(domain_execution)
if trace_manager:
trace_manager.add_trace_task(
@ -286,63 +289,70 @@ class WorkflowCycleManager:
return workflow_run
def _handle_node_execution_start(
self, *, workflow_run: WorkflowRun, event: QueueNodeStartedEvent
) -> WorkflowNodeExecution:
workflow_node_execution = WorkflowNodeExecution()
workflow_node_execution.id = str(uuid4())
workflow_node_execution.tenant_id = workflow_run.tenant_id
workflow_node_execution.app_id = workflow_run.app_id
workflow_node_execution.workflow_id = workflow_run.workflow_id
workflow_node_execution.triggered_from = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value
workflow_node_execution.workflow_run_id = workflow_run.id
workflow_node_execution.predecessor_node_id = event.predecessor_node_id
workflow_node_execution.index = event.node_run_index
workflow_node_execution.node_execution_id = event.node_execution_id
workflow_node_execution.node_id = event.node_id
workflow_node_execution.node_type = event.node_type.value
workflow_node_execution.title = event.node_data.title
workflow_node_execution.status = WorkflowNodeExecutionStatus.RUNNING.value
workflow_node_execution.created_by_role = workflow_run.created_by_role
workflow_node_execution.created_by = workflow_run.created_by
workflow_node_execution.execution_metadata = json.dumps(
{
NodeRunMetadataKey.PARALLEL_MODE_RUN_ID: event.parallel_mode_run_id,
NodeRunMetadataKey.ITERATION_ID: event.in_iteration_id,
NodeRunMetadataKey.LOOP_ID: event.in_loop_id,
def _handle_node_execution_start(self, *, workflow_run: WorkflowRun, event: QueueNodeStartedEvent) -> NodeExecution:
# Create a domain model
created_at = datetime.now(UTC).replace(tzinfo=None)
metadata = {
str(NodeRunMetadataKey.PARALLEL_MODE_RUN_ID): event.parallel_mode_run_id,
str(NodeRunMetadataKey.ITERATION_ID): event.in_iteration_id,
str(NodeRunMetadataKey.LOOP_ID): event.in_loop_id,
}
domain_execution = NodeExecution(
id=str(uuid4()),
workflow_id=workflow_run.workflow_id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
workflow_run_id=workflow_run.id,
predecessor_node_id=event.predecessor_node_id,
index=event.node_run_index,
node_execution_id=event.node_execution_id,
node_id=event.node_id,
node_type=event.node_type.value,
title=event.node_data.title,
status=NodeExecutionStatus.RUNNING,
created_by_role=workflow_run.created_by_role,
created_by=workflow_run.created_by,
metadata=metadata,
created_at=created_at,
)
workflow_node_execution.created_at = datetime.now(UTC).replace(tzinfo=None)
# Use the instance repository to save the workflow node execution
self._workflow_node_execution_repository.save(workflow_node_execution)
# Use the instance repository to save the domain model
self._workflow_node_execution_repository.save(domain_execution)
self._workflow_node_executions[event.node_execution_id] = workflow_node_execution
return workflow_node_execution
return domain_execution
def _handle_workflow_node_execution_success(self, *, event: QueueNodeSucceededEvent) -> WorkflowNodeExecution:
workflow_node_execution = self._get_workflow_node_execution(node_execution_id=event.node_execution_id)
def _handle_workflow_node_execution_success(self, *, event: QueueNodeSucceededEvent) -> NodeExecution:
# Get the domain model from repository
domain_execution = self._workflow_node_execution_repository.get_by_node_execution_id(event.node_execution_id)
if not domain_execution:
raise ValueError(f"Domain node execution not found: {event.node_execution_id}")
# Process data
inputs = WorkflowEntry.handle_special_values(event.inputs)
process_data = WorkflowEntry.handle_special_values(event.process_data)
outputs = WorkflowEntry.handle_special_values(event.outputs)
execution_metadata_dict = dict(event.execution_metadata or {})
execution_metadata = json.dumps(jsonable_encoder(execution_metadata_dict)) if execution_metadata_dict else None
# Convert metadata keys to strings
execution_metadata_dict = {}
if event.execution_metadata:
for key, value in event.execution_metadata.items():
execution_metadata_dict[str(key)] = value
finished_at = datetime.now(UTC).replace(tzinfo=None)
elapsed_time = (finished_at - event.start_at).total_seconds()
process_data = WorkflowEntry.handle_special_values(event.process_data)
# Update domain model
domain_execution.status = NodeExecutionStatus.SUCCEEDED
domain_execution.update_from_mapping(
inputs=inputs, process_data=process_data, outputs=outputs, metadata=execution_metadata_dict
)
domain_execution.finished_at = finished_at
domain_execution.elapsed_time = elapsed_time
workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value
workflow_node_execution.inputs = json.dumps(inputs) if inputs else None
workflow_node_execution.process_data = json.dumps(process_data) if process_data else None
workflow_node_execution.outputs = json.dumps(outputs) if outputs else None
workflow_node_execution.execution_metadata = execution_metadata
workflow_node_execution.finished_at = finished_at
workflow_node_execution.elapsed_time = elapsed_time
# Update the repository with the domain model
self._workflow_node_execution_repository.save(domain_execution)
# Use the instance repository to update the workflow node execution
self._workflow_node_execution_repository.update(workflow_node_execution)
return workflow_node_execution
return domain_execution
def _handle_workflow_node_execution_failed(
self,
@ -351,43 +361,52 @@ class WorkflowCycleManager:
| QueueNodeInIterationFailedEvent
| QueueNodeInLoopFailedEvent
| QueueNodeExceptionEvent,
) -> WorkflowNodeExecution:
) -> NodeExecution:
"""
Workflow node execution failed
:param event: queue node failed event
:return:
"""
workflow_node_execution = self._get_workflow_node_execution(node_execution_id=event.node_execution_id)
# Get the domain model from repository
domain_execution = self._workflow_node_execution_repository.get_by_node_execution_id(event.node_execution_id)
if not domain_execution:
raise ValueError(f"Domain node execution not found: {event.node_execution_id}")
# Process data
inputs = WorkflowEntry.handle_special_values(event.inputs)
process_data = WorkflowEntry.handle_special_values(event.process_data)
outputs = WorkflowEntry.handle_special_values(event.outputs)
# Convert metadata keys to strings
execution_metadata_dict = {}
if event.execution_metadata:
for key, value in event.execution_metadata.items():
execution_metadata_dict[str(key)] = value
finished_at = datetime.now(UTC).replace(tzinfo=None)
elapsed_time = (finished_at - event.start_at).total_seconds()
execution_metadata = (
json.dumps(jsonable_encoder(event.execution_metadata)) if event.execution_metadata else None
)
process_data = WorkflowEntry.handle_special_values(event.process_data)
workflow_node_execution.status = (
WorkflowNodeExecutionStatus.FAILED.value
# Update domain model
domain_execution.status = (
NodeExecutionStatus.FAILED
if not isinstance(event, QueueNodeExceptionEvent)
else WorkflowNodeExecutionStatus.EXCEPTION.value
else NodeExecutionStatus.EXCEPTION
)
workflow_node_execution.error = event.error
workflow_node_execution.inputs = json.dumps(inputs) if inputs else None
workflow_node_execution.process_data = json.dumps(process_data) if process_data else None
workflow_node_execution.outputs = json.dumps(outputs) if outputs else None
workflow_node_execution.finished_at = finished_at
workflow_node_execution.elapsed_time = elapsed_time
workflow_node_execution.execution_metadata = execution_metadata
domain_execution.error = event.error
domain_execution.update_from_mapping(
inputs=inputs, process_data=process_data, outputs=outputs, metadata=execution_metadata_dict
)
domain_execution.finished_at = finished_at
domain_execution.elapsed_time = elapsed_time
self._workflow_node_execution_repository.update(workflow_node_execution)
# Update the repository with the domain model
self._workflow_node_execution_repository.save(domain_execution)
return workflow_node_execution
return domain_execution
def _handle_workflow_node_execution_retried(
self, *, workflow_run: WorkflowRun, event: QueueNodeRetryEvent
) -> WorkflowNodeExecution:
) -> NodeExecution:
"""
Workflow node execution failed
:param workflow_run: workflow run
@ -399,47 +418,50 @@ class WorkflowCycleManager:
elapsed_time = (finished_at - created_at).total_seconds()
inputs = WorkflowEntry.handle_special_values(event.inputs)
outputs = WorkflowEntry.handle_special_values(event.outputs)
# Convert metadata keys to strings
origin_metadata = {
NodeRunMetadataKey.ITERATION_ID: event.in_iteration_id,
NodeRunMetadataKey.PARALLEL_MODE_RUN_ID: event.parallel_mode_run_id,
NodeRunMetadataKey.LOOP_ID: event.in_loop_id,
str(NodeRunMetadataKey.ITERATION_ID): event.in_iteration_id,
str(NodeRunMetadataKey.PARALLEL_MODE_RUN_ID): event.parallel_mode_run_id,
str(NodeRunMetadataKey.LOOP_ID): event.in_loop_id,
}
merged_metadata = (
{**jsonable_encoder(event.execution_metadata), **origin_metadata}
if event.execution_metadata is not None
else origin_metadata
# Convert execution metadata keys to strings
execution_metadata_dict = {}
if event.execution_metadata:
for key, value in event.execution_metadata.items():
execution_metadata_dict[str(key)] = value
merged_metadata = {**execution_metadata_dict, **origin_metadata} if execution_metadata_dict else origin_metadata
# Create a domain model
domain_execution = NodeExecution(
id=str(uuid4()),
workflow_id=workflow_run.workflow_id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
workflow_run_id=workflow_run.id,
predecessor_node_id=event.predecessor_node_id,
node_execution_id=event.node_execution_id,
node_id=event.node_id,
node_type=event.node_type.value,
title=event.node_data.title,
status=NodeExecutionStatus.RETRY,
created_by_role=workflow_run.created_by_role,
created_by=workflow_run.created_by,
created_at=created_at,
finished_at=finished_at,
elapsed_time=elapsed_time,
error=event.error,
index=event.node_run_index,
)
execution_metadata = json.dumps(merged_metadata)
workflow_node_execution = WorkflowNodeExecution()
workflow_node_execution.id = str(uuid4())
workflow_node_execution.tenant_id = workflow_run.tenant_id
workflow_node_execution.app_id = workflow_run.app_id
workflow_node_execution.workflow_id = workflow_run.workflow_id
workflow_node_execution.triggered_from = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value
workflow_node_execution.workflow_run_id = workflow_run.id
workflow_node_execution.predecessor_node_id = event.predecessor_node_id
workflow_node_execution.node_execution_id = event.node_execution_id
workflow_node_execution.node_id = event.node_id
workflow_node_execution.node_type = event.node_type.value
workflow_node_execution.title = event.node_data.title
workflow_node_execution.status = WorkflowNodeExecutionStatus.RETRY.value
workflow_node_execution.created_by_role = workflow_run.created_by_role
workflow_node_execution.created_by = workflow_run.created_by
workflow_node_execution.created_at = created_at
workflow_node_execution.finished_at = finished_at
workflow_node_execution.elapsed_time = elapsed_time
workflow_node_execution.error = event.error
workflow_node_execution.inputs = json.dumps(inputs) if inputs else None
workflow_node_execution.outputs = json.dumps(outputs) if outputs else None
workflow_node_execution.execution_metadata = execution_metadata
workflow_node_execution.index = event.node_run_index
# Use the instance repository to save the workflow node execution
self._workflow_node_execution_repository.save(workflow_node_execution)
self._workflow_node_executions[event.node_execution_id] = workflow_node_execution
return workflow_node_execution
# Update with mappings
domain_execution.update_from_mapping(inputs=inputs, outputs=outputs, metadata=merged_metadata)
# Use the instance repository to save the domain model
self._workflow_node_execution_repository.save(domain_execution)
return domain_execution
def _workflow_start_to_stream_response(
self,
@ -515,7 +537,7 @@ class WorkflowCycleManager:
*,
event: QueueNodeStartedEvent,
task_id: str,
workflow_node_execution: WorkflowNodeExecution,
workflow_node_execution: NodeExecution,
) -> Optional[NodeStartStreamResponse]:
if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}:
return None
@ -532,7 +554,7 @@ class WorkflowCycleManager:
title=workflow_node_execution.title,
index=workflow_node_execution.index,
predecessor_node_id=workflow_node_execution.predecessor_node_id,
inputs=workflow_node_execution.inputs_dict,
inputs=workflow_node_execution.inputs,
created_at=int(workflow_node_execution.created_at.timestamp()),
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
@ -565,7 +587,7 @@ class WorkflowCycleManager:
| QueueNodeInLoopFailedEvent
| QueueNodeExceptionEvent,
task_id: str,
workflow_node_execution: WorkflowNodeExecution,
workflow_node_execution: NodeExecution,
) -> Optional[NodeFinishStreamResponse]:
if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}:
return None
@ -584,16 +606,16 @@ class WorkflowCycleManager:
index=workflow_node_execution.index,
title=workflow_node_execution.title,
predecessor_node_id=workflow_node_execution.predecessor_node_id,
inputs=workflow_node_execution.inputs_dict,
process_data=workflow_node_execution.process_data_dict,
outputs=workflow_node_execution.outputs_dict,
inputs=workflow_node_execution.inputs,
process_data=workflow_node_execution.process_data,
outputs=workflow_node_execution.outputs,
status=workflow_node_execution.status,
error=workflow_node_execution.error,
elapsed_time=workflow_node_execution.elapsed_time,
execution_metadata=workflow_node_execution.execution_metadata_dict,
execution_metadata=workflow_node_execution.metadata,
created_at=int(workflow_node_execution.created_at.timestamp()),
finished_at=int(workflow_node_execution.finished_at.timestamp()),
files=self._fetch_files_from_node_outputs(workflow_node_execution.outputs_dict or {}),
files=self._fetch_files_from_node_outputs(workflow_node_execution.outputs or {}),
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
@ -608,7 +630,7 @@ class WorkflowCycleManager:
*,
event: QueueNodeRetryEvent,
task_id: str,
workflow_node_execution: WorkflowNodeExecution,
workflow_node_execution: NodeExecution,
) -> Optional[Union[NodeRetryStreamResponse, NodeFinishStreamResponse]]:
if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}:
return None
@ -627,16 +649,16 @@ class WorkflowCycleManager:
index=workflow_node_execution.index,
title=workflow_node_execution.title,
predecessor_node_id=workflow_node_execution.predecessor_node_id,
inputs=workflow_node_execution.inputs_dict,
process_data=workflow_node_execution.process_data_dict,
outputs=workflow_node_execution.outputs_dict,
inputs=workflow_node_execution.inputs,
process_data=workflow_node_execution.process_data,
outputs=workflow_node_execution.outputs,
status=workflow_node_execution.status,
error=workflow_node_execution.error,
elapsed_time=workflow_node_execution.elapsed_time,
execution_metadata=workflow_node_execution.execution_metadata_dict,
execution_metadata=workflow_node_execution.metadata,
created_at=int(workflow_node_execution.created_at.timestamp()),
finished_at=int(workflow_node_execution.finished_at.timestamp()),
files=self._fetch_files_from_node_outputs(workflow_node_execution.outputs_dict or {}),
files=self._fetch_files_from_node_outputs(workflow_node_execution.outputs or {}),
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
@ -908,23 +930,6 @@ class WorkflowCycleManager:
return workflow_run
def _get_workflow_node_execution(self, node_execution_id: str) -> WorkflowNodeExecution:
# First check the cache for performance
if node_execution_id in self._workflow_node_executions:
cached_execution = self._workflow_node_executions[node_execution_id]
# No need to merge with session since expire_on_commit=False
return cached_execution
# If not in cache, use the instance repository to get by node_execution_id
execution = self._workflow_node_execution_repository.get_by_node_execution_id(node_execution_id)
if not execution:
raise ValueError(f"Workflow node execution not found: {node_execution_id}")
# Update cache
self._workflow_node_executions[node_execution_id] = execution
return execution
def _handle_agent_log(self, task_id: str, event: QueueAgentLogEvent) -> AgentLogStreamResponse:
"""
Handle agent log

@ -1417,7 +1417,7 @@ class EndUser(Base, UserMixin):
)
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
tenant_id = db.Column(StringUUID, nullable=False)
tenant_id: Mapped[str] = db.Column(StringUUID, nullable=False)
app_id = db.Column(StringUUID, nullable=True)
type = db.Column(db.String(255), nullable=False)
external_user_id = db.Column(db.String(255), nullable=True)

Loading…
Cancel
Save