refactor(workflow_cycle_manager): Refactors `_handle_workflow_run_start` to use WorkflowExecution

Signed-off-by: -LAN- <laipz8200@outlook.com>
pull/20067/head
-LAN- 1 year ago
parent e7a6942971
commit 5d3cea5b95
No known key found for this signature in database
GPG Key ID: 6BA0D108DED011FF

@ -297,21 +297,19 @@ class AdvancedChatAppGenerateTaskPipeline:
with Session(db.engine, expire_on_commit=False) as session: with Session(db.engine, expire_on_commit=False) as session:
# init workflow run # init workflow run
workflow_run = self._workflow_cycle_manager._handle_workflow_run_start( workflow_execution = self._workflow_cycle_manager._handle_workflow_run_start(
session=session, session=session,
workflow_id=self._workflow_id, workflow_id=self._workflow_id,
user_id=self._user_id,
created_by_role=self._created_by_role,
) )
self._workflow_run_id = workflow_run.id self._workflow_run_id = workflow_execution.id
message = self._get_message(session=session) message = self._get_message(session=session)
if not message: if not message:
raise ValueError(f"Message not found: {self._message_id}") raise ValueError(f"Message not found: {self._message_id}")
message.workflow_run_id = workflow_run.id message.workflow_run_id = workflow_execution.id
workflow_start_resp = self._workflow_cycle_manager._workflow_start_to_stream_response( workflow_start_resp = self._workflow_cycle_manager._workflow_start_to_stream_response(
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run task_id=self._application_generate_entity.task_id,
workflow_execution=workflow_execution,
) )
session.commit()
yield workflow_start_resp yield workflow_start_resp
elif isinstance( elif isinstance(

@ -190,7 +190,7 @@ class WorkflowStartStreamResponse(StreamResponse):
id: str id: str
workflow_id: str workflow_id: str
sequence_number: int sequence_number: int
inputs: dict inputs: Mapping[str, Any]
created_at: int created_at: int
event: StreamEvent = StreamEvent.WORKFLOW_STARTED event: StreamEvent = StreamEvent.WORKFLOW_STARTED

@ -261,17 +261,15 @@ class WorkflowAppGenerateTaskPipeline:
with Session(db.engine, expire_on_commit=False) as session: with Session(db.engine, expire_on_commit=False) as session:
# init workflow run # init workflow run
workflow_run = self._workflow_cycle_manager._handle_workflow_run_start( workflow_execution = self._workflow_cycle_manager._handle_workflow_run_start(
session=session, session=session,
workflow_id=self._workflow_id, workflow_id=self._workflow_id,
user_id=self._user_id,
created_by_role=self._created_by_role,
) )
self._workflow_run_id = workflow_run.id self._workflow_run_id = workflow_execution.id
start_resp = self._workflow_cycle_manager._workflow_start_to_stream_response( start_resp = self._workflow_cycle_manager._workflow_start_to_stream_response(
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run task_id=self._application_generate_entity.task_id,
workflow_execution=workflow_execution,
) )
session.commit()
yield start_resp yield start_resp
elif isinstance( elif isinstance(

@ -8,7 +8,7 @@ from uuid import uuid4
from sqlalchemy import func, select from sqlalchemy import func, select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity
from core.app.entities.queue_entities import ( from core.app.entities.queue_entities import (
QueueAgentLogEvent, QueueAgentLogEvent,
QueueIterationCompletedEvent, QueueIterationCompletedEvent,
@ -54,6 +54,7 @@ from core.workflow.entities.node_execution_entities import (
NodeExecution, NodeExecution,
NodeExecutionStatus, NodeExecutionStatus,
) )
from core.workflow.entities.workflow_execution_entities import WorkflowExecution, WorkflowType
from core.workflow.enums import SystemVariableKey from core.workflow.enums import SystemVariableKey
from core.workflow.nodes import NodeType from core.workflow.nodes import NodeType
from core.workflow.nodes.tool.entities import ToolNodeData from core.workflow.nodes.tool.entities import ToolNodeData
@ -68,7 +69,6 @@ from models import (
WorkflowNodeExecutionStatus, WorkflowNodeExecutionStatus,
WorkflowRun, WorkflowRun,
WorkflowRunStatus, WorkflowRunStatus,
WorkflowRunTriggeredFrom,
) )
@ -92,9 +92,7 @@ class WorkflowCycleManager:
*, *,
session: Session, session: Session,
workflow_id: str, workflow_id: str,
user_id: str, ) -> WorkflowExecution:
created_by_role: CreatorUserRole,
) -> WorkflowRun:
workflow_stmt = select(Workflow).where(Workflow.id == workflow_id) workflow_stmt = select(Workflow).where(Workflow.id == workflow_id)
workflow = session.scalar(workflow_stmt) workflow = session.scalar(workflow_stmt)
if not workflow: if not workflow:
@ -113,38 +111,26 @@ class WorkflowCycleManager:
continue continue
inputs[f"sys.{key.value}"] = value inputs[f"sys.{key.value}"] = value
triggered_from = (
WorkflowRunTriggeredFrom.DEBUGGING
if self._application_generate_entity.invoke_from == InvokeFrom.DEBUGGER
else WorkflowRunTriggeredFrom.APP_RUN
)
# handle special values # handle special values
inputs = dict(WorkflowEntry.handle_special_values(inputs) or {}) inputs = dict(WorkflowEntry.handle_special_values(inputs) or {})
# init workflow run # init workflow run
# TODO: This workflow_run_id should always not be None, maybe we can use a more elegant way to handle this # TODO: This workflow_run_id should always not be None, maybe we can use a more elegant way to handle this
workflow_run_id = str(self._workflow_system_variables.get(SystemVariableKey.WORKFLOW_RUN_ID) or uuid4()) execution_id = str(self._workflow_system_variables.get(SystemVariableKey.WORKFLOW_RUN_ID) or uuid4())
execution = WorkflowExecution.new(
workflow_run = WorkflowRun() id=execution_id,
workflow_run.id = workflow_run_id workflow_id=workflow.id,
workflow_run.tenant_id = workflow.tenant_id sequence_number=new_sequence_number,
workflow_run.app_id = workflow.app_id type=WorkflowType(workflow.type),
workflow_run.sequence_number = new_sequence_number workflow_version=workflow.version,
workflow_run.workflow_id = workflow.id graph=workflow.graph_dict,
workflow_run.type = workflow.type inputs=inputs,
workflow_run.triggered_from = triggered_from.value started_at=datetime.now(UTC).replace(tzinfo=None),
workflow_run.version = workflow.version )
workflow_run.graph = workflow.graph
workflow_run.inputs = json.dumps(inputs)
workflow_run.status = WorkflowRunStatus.RUNNING
workflow_run.created_by_role = created_by_role
workflow_run.created_by = user_id
workflow_run.created_at = datetime.now(UTC).replace(tzinfo=None)
session.add(workflow_run)
return workflow_run self._workflow_execution_repository.save(execution)
return execution
def _handle_workflow_run_success( def _handle_workflow_run_success(
self, self,
@ -462,20 +448,18 @@ class WorkflowCycleManager:
def _workflow_start_to_stream_response( def _workflow_start_to_stream_response(
self, self,
*, *,
session: Session,
task_id: str, task_id: str,
workflow_run: WorkflowRun, workflow_execution: WorkflowExecution,
) -> WorkflowStartStreamResponse: ) -> WorkflowStartStreamResponse:
_ = session
return WorkflowStartStreamResponse( return WorkflowStartStreamResponse(
task_id=task_id, task_id=task_id,
workflow_run_id=workflow_run.id, workflow_run_id=workflow_execution.id,
data=WorkflowStartStreamResponse.Data( data=WorkflowStartStreamResponse.Data(
id=workflow_run.id, id=workflow_execution.id,
workflow_id=workflow_run.workflow_id, workflow_id=workflow_execution.workflow_id,
sequence_number=workflow_run.sequence_number, sequence_number=workflow_execution.sequence_number,
inputs=dict(workflow_run.inputs_dict or {}), inputs=workflow_execution.inputs,
created_at=int(workflow_run.created_at.timestamp()), created_at=int(workflow_execution.started_at.timestamp()),
), ),
) )

Loading…
Cancel
Save