refactor(workflow_cycle_manager): Refactors success and failed handles to use workflow_execution

Signed-off-by: -LAN- <laipz8200@outlook.com>
pull/20067/head
-LAN- 1 year ago
parent 0fdef95645
commit 485fbc0336
No known key found for this signature in database
GPG Key ID: 6BA0D108DED011FF

@ -536,10 +536,8 @@ class AdvancedChatAppGenerateTaskPipeline:
raise ValueError("workflow run not initialized.")
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._handle_workflow_run_success(
session=session,
workflow_execution = self._workflow_cycle_manager._handle_workflow_run_success(
workflow_run_id=self._workflow_run_id,
start_at=graph_runtime_state.start_at,
total_tokens=graph_runtime_state.total_tokens,
total_steps=graph_runtime_state.node_run_steps,
outputs=event.outputs,
@ -548,9 +546,10 @@ class AdvancedChatAppGenerateTaskPipeline:
)
workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
session=session,
task_id=self._application_generate_entity.task_id,
workflow_execution=workflow_execution,
)
session.commit()
yield workflow_finish_resp
self._base_task_pipeline._queue_manager.publish(
@ -563,10 +562,8 @@ class AdvancedChatAppGenerateTaskPipeline:
raise ValueError("graph runtime state not initialized.")
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._handle_workflow_run_partial_success(
session=session,
workflow_execution = self._workflow_cycle_manager._handle_workflow_run_partial_success(
workflow_run_id=self._workflow_run_id,
start_at=graph_runtime_state.start_at,
total_tokens=graph_runtime_state.total_tokens,
total_steps=graph_runtime_state.node_run_steps,
outputs=event.outputs,
@ -575,9 +572,10 @@ class AdvancedChatAppGenerateTaskPipeline:
trace_manager=trace_manager,
)
workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
session=session,
task_id=self._application_generate_entity.task_id,
workflow_execution=workflow_execution,
)
session.commit()
yield workflow_finish_resp
self._base_task_pipeline._queue_manager.publish(
@ -590,26 +588,25 @@ class AdvancedChatAppGenerateTaskPipeline:
raise ValueError("graph runtime state not initialized.")
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._handle_workflow_run_failed(
session=session,
workflow_execution = self._workflow_cycle_manager._handle_workflow_run_failed(
workflow_run_id=self._workflow_run_id,
start_at=graph_runtime_state.start_at,
total_tokens=graph_runtime_state.total_tokens,
total_steps=graph_runtime_state.node_run_steps,
status=WorkflowRunStatus.FAILED,
error=event.error,
error_message=event.error,
conversation_id=self._conversation_id,
trace_manager=trace_manager,
exceptions_count=event.exceptions_count,
)
workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
session=session,
task_id=self._application_generate_entity.task_id,
workflow_execution=workflow_execution,
)
err_event = QueueErrorEvent(error=ValueError(f"Run failed: {workflow_run.error}"))
err_event = QueueErrorEvent(error=ValueError(f"Run failed: {workflow_execution.error_message}"))
err = self._base_task_pipeline._handle_error(
event=err_event, session=session, message_id=self._message_id
)
session.commit()
yield workflow_finish_resp
yield self._base_task_pipeline._error_to_stream_response(err)
@ -617,21 +614,19 @@ class AdvancedChatAppGenerateTaskPipeline:
elif isinstance(event, QueueStopEvent):
if self._workflow_run_id and graph_runtime_state:
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._handle_workflow_run_failed(
session=session,
workflow_execution = self._workflow_cycle_manager._handle_workflow_run_failed(
workflow_run_id=self._workflow_run_id,
start_at=graph_runtime_state.start_at,
total_tokens=graph_runtime_state.total_tokens,
total_steps=graph_runtime_state.node_run_steps,
status=WorkflowRunStatus.STOPPED,
error=event.get_stop_reason(),
error_message=event.get_stop_reason(),
conversation_id=self._conversation_id,
trace_manager=trace_manager,
)
workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
workflow_execution=workflow_execution,
)
# Save message
self._save_message(session=session, graph_runtime_state=graph_runtime_state)

@ -212,7 +212,7 @@ class WorkflowFinishStreamResponse(StreamResponse):
workflow_id: str
sequence_number: int
status: str
outputs: Optional[dict] = None
outputs: Optional[Mapping[str, Any]] = None
error: Optional[str] = None
elapsed_time: float
total_tokens: int
@ -788,7 +788,7 @@ class WorkflowAppBlockingResponse(AppBlockingResponse):
id: str
workflow_id: str
status: str
outputs: Optional[dict] = None
outputs: Optional[Mapping[str, Any]] = None
error: Optional[str] = None
elapsed_time: float
total_tokens: int

@ -30,6 +30,7 @@ from core.ops.entities.trace_entity import (
WorkflowTraceInfo,
)
from core.ops.utils import get_message_data
from core.workflow.entities.workflow_execution_entities import WorkflowExecution
from extensions.ext_database import db
from extensions.ext_storage import storage
from models.model import App, AppModelConfig, Conversation, Message, MessageFile, TraceAppConfig
@ -373,7 +374,7 @@ class TraceTask:
self,
trace_type: Any,
message_id: Optional[str] = None,
workflow_run: Optional[WorkflowRun] = None,
workflow_execution: Optional[WorkflowExecution] = None,
conversation_id: Optional[str] = None,
user_id: Optional[str] = None,
timer: Optional[Any] = None,
@ -381,7 +382,7 @@ class TraceTask:
):
self.trace_type = trace_type
self.message_id = message_id
self.workflow_run_id = workflow_run.id if workflow_run else None
self.workflow_run_id = workflow_execution.id if workflow_execution else None
self.conversation_id = conversation_id
self.user_id = user_id
self.timer = timer

@ -3,6 +3,7 @@ import time
from collections.abc import Generator
from typing import Optional, Union
from sqlalchemy import select
from sqlalchemy.orm import Session
from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
@ -53,6 +54,7 @@ from core.app.entities.task_entities import (
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
from core.ops.ops_trace_manager import TraceQueueManager
from core.workflow.entities.workflow_execution_entities import WorkflowExecution
from core.workflow.enums import SystemVariableKey
from core.workflow.repository.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
@ -492,10 +494,8 @@ class WorkflowAppGenerateTaskPipeline:
raise ValueError("graph runtime state not initialized.")
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._handle_workflow_run_success(
session=session,
workflow_execution = self._workflow_cycle_manager._handle_workflow_run_success(
workflow_run_id=self._workflow_run_id,
start_at=graph_runtime_state.start_at,
total_tokens=graph_runtime_state.total_tokens,
total_steps=graph_runtime_state.node_run_steps,
outputs=event.outputs,
@ -504,12 +504,12 @@ class WorkflowAppGenerateTaskPipeline:
)
# save workflow app log
self._save_workflow_app_log(session=session, workflow_run=workflow_run)
self._save_workflow_app_log(session=session, workflow_execution=workflow_execution)
workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
workflow_execution=workflow_execution,
)
session.commit()
@ -521,10 +521,8 @@ class WorkflowAppGenerateTaskPipeline:
raise ValueError("graph runtime state not initialized.")
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._handle_workflow_run_partial_success(
session=session,
workflow_execution = self._workflow_cycle_manager._handle_workflow_run_partial_success(
workflow_run_id=self._workflow_run_id,
start_at=graph_runtime_state.start_at,
total_tokens=graph_runtime_state.total_tokens,
total_steps=graph_runtime_state.node_run_steps,
outputs=event.outputs,
@ -534,10 +532,12 @@ class WorkflowAppGenerateTaskPipeline:
)
# save workflow app log
self._save_workflow_app_log(session=session, workflow_run=workflow_run)
self._save_workflow_app_log(session=session, workflow_execution=workflow_execution)
workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
session=session,
task_id=self._application_generate_entity.task_id,
workflow_execution=workflow_execution,
)
session.commit()
@ -549,26 +549,28 @@ class WorkflowAppGenerateTaskPipeline:
raise ValueError("graph runtime state not initialized.")
with Session(db.engine, expire_on_commit=False) as session:
workflow_run = self._workflow_cycle_manager._handle_workflow_run_failed(
session=session,
workflow_execution = self._workflow_cycle_manager._handle_workflow_run_failed(
workflow_run_id=self._workflow_run_id,
start_at=graph_runtime_state.start_at,
total_tokens=graph_runtime_state.total_tokens,
total_steps=graph_runtime_state.node_run_steps,
status=WorkflowRunStatus.FAILED
if isinstance(event, QueueWorkflowFailedEvent)
else WorkflowRunStatus.STOPPED,
error=event.error if isinstance(event, QueueWorkflowFailedEvent) else event.get_stop_reason(),
error_message=event.error
if isinstance(event, QueueWorkflowFailedEvent)
else event.get_stop_reason(),
conversation_id=None,
trace_manager=trace_manager,
exceptions_count=event.exceptions_count if isinstance(event, QueueWorkflowFailedEvent) else 0,
)
# save workflow app log
self._save_workflow_app_log(session=session, workflow_run=workflow_run)
self._save_workflow_app_log(session=session, workflow_execution=workflow_execution)
workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response(
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
session=session,
task_id=self._application_generate_entity.task_id,
workflow_execution=workflow_execution,
)
session.commit()
@ -596,11 +598,9 @@ class WorkflowAppGenerateTaskPipeline:
if tts_publisher:
tts_publisher.publish(None)
def _save_workflow_app_log(self, *, session: Session, workflow_run: WorkflowRun) -> None:
"""
Save workflow app log.
:return:
"""
def _save_workflow_app_log(self, *, session: Session, workflow_execution: WorkflowExecution) -> None:
workflow_run = session.scalar(select(WorkflowRun).where(WorkflowRun.id == workflow_execution.id))
assert workflow_run is not None
invoke_from = self._application_generate_entity.invoke_from
if invoke_from == InvokeFrom.SERVICE_API:
created_from = WorkflowAppLogCreatedFrom.SERVICE_API

@ -1,4 +1,3 @@
import json
import time
from collections.abc import Mapping, Sequence
from datetime import UTC, datetime
@ -54,7 +53,7 @@ from core.workflow.entities.node_execution_entities import (
NodeExecution,
NodeExecutionStatus,
)
from core.workflow.entities.workflow_execution_entities import WorkflowExecution, WorkflowType
from core.workflow.entities.workflow_execution_entities import WorkflowExecution, WorkflowExecutionStatus, WorkflowType
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes import NodeType
from core.workflow.nodes.tool.entities import ToolNodeData
@ -135,121 +134,92 @@ class WorkflowCycleManager:
def _handle_workflow_run_success(
self,
*,
session: Session,
workflow_run_id: str,
start_at: float,
total_tokens: int,
total_steps: int,
outputs: Mapping[str, Any] | None = None,
conversation_id: Optional[str] = None,
trace_manager: Optional[TraceQueueManager] = None,
) -> WorkflowRun:
"""
Workflow run success
:param workflow_run_id: workflow run id
:param start_at: start time
:param total_tokens: total tokens
:param total_steps: total steps
:param outputs: outputs
:param conversation_id: conversation id
:return:
"""
workflow_run = self._get_workflow_run(session=session, workflow_run_id=workflow_run_id)
) -> WorkflowExecution:
workflow_execution = self._get_workflow_execution(workflow_run_id)
outputs = WorkflowEntry.handle_special_values(outputs)
workflow_run.status = WorkflowRunStatus.SUCCEEDED
workflow_run.outputs = json.dumps(outputs or {})
workflow_run.elapsed_time = time.perf_counter() - start_at
workflow_run.total_tokens = total_tokens
workflow_run.total_steps = total_steps
workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None)
workflow_execution.status = WorkflowExecutionStatus.SUCCEEDED
workflow_execution.outputs = outputs or {}
workflow_execution.total_tokens = total_tokens
workflow_execution.total_steps = total_steps
workflow_execution.finished_at = datetime.now(UTC).replace(tzinfo=None)
if trace_manager:
trace_manager.add_trace_task(
TraceTask(
TraceTaskName.WORKFLOW_TRACE,
workflow_run=workflow_run,
workflow_execution=workflow_execution,
conversation_id=conversation_id,
user_id=trace_manager.user_id,
)
)
return workflow_run
return workflow_execution
def _handle_workflow_run_partial_success(
self,
*,
session: Session,
workflow_run_id: str,
start_at: float,
total_tokens: int,
total_steps: int,
outputs: Mapping[str, Any] | None = None,
exceptions_count: int = 0,
conversation_id: Optional[str] = None,
trace_manager: Optional[TraceQueueManager] = None,
) -> WorkflowRun:
workflow_run = self._get_workflow_run(session=session, workflow_run_id=workflow_run_id)
) -> WorkflowExecution:
execution = self._get_workflow_execution(workflow_run_id)
outputs = WorkflowEntry.handle_special_values(dict(outputs) if outputs else None)
workflow_run.status = WorkflowRunStatus.PARTIAL_SUCCEEDED.value
workflow_run.outputs = json.dumps(outputs or {})
workflow_run.elapsed_time = time.perf_counter() - start_at
workflow_run.total_tokens = total_tokens
workflow_run.total_steps = total_steps
workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None)
workflow_run.exceptions_count = exceptions_count
execution.status = WorkflowExecutionStatus.PARTIAL_SUCCEEDED
execution.outputs = outputs or {}
execution.total_tokens = total_tokens
execution.total_steps = total_steps
execution.finished_at = datetime.now(UTC).replace(tzinfo=None)
execution.exceptions_count = exceptions_count
if trace_manager:
trace_manager.add_trace_task(
TraceTask(
TraceTaskName.WORKFLOW_TRACE,
workflow_run=workflow_run,
workflow_execution=execution,
conversation_id=conversation_id,
user_id=trace_manager.user_id,
)
)
return workflow_run
return execution
def _handle_workflow_run_failed(
self,
*,
session: Session,
workflow_run_id: str,
start_at: float,
total_tokens: int,
total_steps: int,
status: WorkflowRunStatus,
error: str,
error_message: str,
conversation_id: Optional[str] = None,
trace_manager: Optional[TraceQueueManager] = None,
exceptions_count: int = 0,
) -> WorkflowRun:
"""
Workflow run failed
:param workflow_run_id: workflow run id
:param start_at: start time
:param total_tokens: total tokens
:param total_steps: total steps
:param status: status
:param error: error message
:return:
"""
workflow_run = self._get_workflow_run(session=session, workflow_run_id=workflow_run_id)
) -> WorkflowExecution:
execution = self._get_workflow_execution(workflow_run_id)
workflow_run.status = status.value
workflow_run.error = error
workflow_run.elapsed_time = time.perf_counter() - start_at
workflow_run.total_tokens = total_tokens
workflow_run.total_steps = total_steps
workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None)
workflow_run.exceptions_count = exceptions_count
execution.status = WorkflowExecutionStatus(status.value)
execution.error_message = error_message
execution.total_tokens = total_tokens
execution.total_steps = total_steps
execution.finished_at = datetime.now(UTC).replace(tzinfo=None)
execution.exceptions_count = exceptions_count
# Use the instance repository to find running executions for a workflow run
running_domain_executions = self._workflow_node_execution_repository.get_running_executions(
workflow_run_id=workflow_run.id
workflow_run_id=execution.id
)
# Update the domain models
@ -258,7 +228,7 @@ class WorkflowCycleManager:
if domain_execution.node_execution_id:
# Update the domain model
domain_execution.status = NodeExecutionStatus.FAILED
domain_execution.error = error
domain_execution.error = error_message
domain_execution.finished_at = now
domain_execution.elapsed_time = (now - domain_execution.created_at).total_seconds()
@ -269,13 +239,13 @@ class WorkflowCycleManager:
trace_manager.add_trace_task(
TraceTask(
TraceTaskName.WORKFLOW_TRACE,
workflow_run=workflow_run,
workflow_execution=execution,
conversation_id=conversation_id,
user_id=trace_manager.user_id,
)
)
return workflow_run
return execution
def _handle_node_execution_start(self, *, workflow_run: WorkflowRun, event: QueueNodeStartedEvent) -> NodeExecution:
# Create a domain model
@ -468,9 +438,11 @@ class WorkflowCycleManager:
*,
session: Session,
task_id: str,
workflow_run: WorkflowRun,
workflow_execution: WorkflowExecution,
) -> WorkflowFinishStreamResponse:
created_by = None
workflow_run = session.scalar(select(WorkflowRun).where(WorkflowRun.id == workflow_execution.id))
assert workflow_run is not None
if workflow_run.created_by_role == CreatorUserRole.ACCOUNT:
stmt = select(Account).where(Account.id == workflow_run.created_by)
account = session.scalar(stmt)
@ -493,29 +465,29 @@ class WorkflowCycleManager:
# Handle the case where finished_at is None by using current time as default
finished_at_timestamp = (
int(workflow_run.finished_at.timestamp())
if workflow_run.finished_at
int(workflow_execution.finished_at.timestamp())
if workflow_execution.finished_at
else int(datetime.now(UTC).timestamp())
)
return WorkflowFinishStreamResponse(
task_id=task_id,
workflow_run_id=workflow_run.id,
workflow_run_id=workflow_execution.id,
data=WorkflowFinishStreamResponse.Data(
id=workflow_run.id,
workflow_id=workflow_run.workflow_id,
sequence_number=workflow_run.sequence_number,
status=workflow_run.status,
outputs=dict(workflow_run.outputs_dict) if workflow_run.outputs_dict else None,
error=workflow_run.error,
elapsed_time=workflow_run.elapsed_time,
total_tokens=workflow_run.total_tokens,
total_steps=workflow_run.total_steps,
id=workflow_execution.id,
workflow_id=workflow_execution.workflow_id,
sequence_number=workflow_execution.sequence_number,
status=workflow_execution.status,
outputs=workflow_execution.outputs,
error=workflow_execution.error_message,
elapsed_time=workflow_execution.elapsed_time,
total_tokens=workflow_execution.total_tokens,
total_steps=workflow_execution.total_steps,
created_by=created_by,
created_at=int(workflow_run.created_at.timestamp()),
created_at=int(workflow_execution.started_at.timestamp()),
finished_at=finished_at_timestamp,
files=self._fetch_files_from_node_outputs(dict(workflow_run.outputs_dict)),
exceptions_count=workflow_run.exceptions_count,
files=self._fetch_files_from_node_outputs(workflow_execution.outputs),
exceptions_count=workflow_execution.exceptions_count,
),
)
@ -847,7 +819,7 @@ class WorkflowCycleManager:
),
)
def _fetch_files_from_node_outputs(self, outputs_dict: Mapping[str, Any]) -> Sequence[Mapping[str, Any]]:
def _fetch_files_from_node_outputs(self, outputs_dict: Mapping[str, Any] | None) -> Sequence[Mapping[str, Any]]:
"""
Fetch files from node outputs
:param outputs_dict: node outputs dict

Loading…
Cancel
Save