From 72561300ecf0e9e8c046484b1547ecd99ab37e40 Mon Sep 17 00:00:00 2001 From: QuantumGhost Date: Mon, 2 Jun 2025 17:55:32 +0800 Subject: [PATCH] feat(api): Prefill conversation variables in draft workflow and update related services --- .../console/app/workflow_draft_variable.py | 8 +++ .../app/apps/advanced_chat/app_generator.py | 6 +- api/core/app/apps/workflow/app_generator.py | 12 ++-- .../workflow_draft_variable_service.py | 64 ++++++++++++++----- api/services/workflow_service.py | 9 ++- 5 files changed, 75 insertions(+), 24 deletions(-) diff --git a/api/controllers/console/app/workflow_draft_variable.py b/api/controllers/console/app/workflow_draft_variable.py index 028ea84114..2a9f8e1070 100644 --- a/api/controllers/console/app/workflow_draft_variable.py +++ b/api/controllers/console/app/workflow_draft_variable.py @@ -263,6 +263,14 @@ class ConversationVariableCollectionApi(Resource): @_api_prerequisite @marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS) def get(self, app_model: App): + # NOTE(QuantumGhost): Prefill conversation variables into the draft variables table + # so their IDs can be returned to the caller. + workflow_srv = WorkflowService() + draft_workflow = workflow_srv.get_draft_workflow(app_model) + if draft_workflow is None: + raise NotFoundError(description=f"draft workflow not found, id={app_model.id}") + draft_var_srv = WorkflowDraftVariableService(db.session) + draft_var_srv.prefill_conversation_variable_default_values(draft_workflow) return _get_variable_list(app_model, CONVERSATION_VARIABLE_NODE_ID) diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py index 095b42e66a..2a262526ea 100644 --- a/api/core/app/apps/advanced_chat/app_generator.py +++ b/api/core/app/apps/advanced_chat/app_generator.py @@ -36,7 +36,7 @@ from models import Account, App, Conversation, EndUser, Message, Workflow, Workf from models.enums import WorkflowRunTriggeredFrom from services.conversation_service import ConversationService from services.errors.message import MessageNotExistsError -from services.workflow_draft_variable_service import DraftVarLoader +from services.workflow_draft_variable_service import DraftVarLoader, WorkflowDraftVariableService logger = logging.getLogger(__name__) @@ -266,6 +266,8 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): engine=db.engine, app_id=application_generate_entity.app_config.app_id, ) + draft_var_srv = WorkflowDraftVariableService(db.session) + draft_var_srv.prefill_conversation_variable_default_values(workflow) return self._generate( workflow=workflow, @@ -346,6 +348,8 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): engine=db.engine, app_id=application_generate_entity.app_config.app_id, ) + draft_var_srv = WorkflowDraftVariableService(db.session) + draft_var_srv.prefill_conversation_variable_default_values(workflow) return self._generate( workflow=workflow, diff --git a/api/core/app/apps/workflow/app_generator.py b/api/core/app/apps/workflow/app_generator.py index fd66dc9fe5..e2ba9c8266 100644 --- a/api/core/app/apps/workflow/app_generator.py +++ b/api/core/app/apps/workflow/app_generator.py @@ -27,12 +27,12 @@ from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository from core.workflow.repository.workflow_execution_repository import WorkflowExecutionRepository from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository -from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER +from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader from extensions.ext_database import db from factories import file_factory from models import Account, App, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom from models.enums import WorkflowRunTriggeredFrom -from services.workflow_draft_variable_service import DraftVarLoader +from services.workflow_draft_variable_service import DraftVarLoader, WorkflowDraftVariableService logger = logging.getLogger(__name__) @@ -187,7 +187,7 @@ class WorkflowAppGenerator(BaseAppGenerator): workflow_node_execution_repository: WorkflowNodeExecutionRepository, streaming: bool = True, workflow_thread_pool_id: Optional[str] = None, - variable_loader: DraftVarLoader = DUMMY_VARIABLE_LOADER, + variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER, ) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]: """ Generate App response. @@ -308,6 +308,8 @@ class WorkflowAppGenerator(BaseAppGenerator): app_id=application_generate_entity.app_config.app_id, triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP, ) + draft_var_srv = WorkflowDraftVariableService(db.session) + draft_var_srv.prefill_conversation_variable_default_values(workflow) var_loader = DraftVarLoader( engine=db.engine, app_id=application_generate_entity.app_config.app_id, @@ -389,6 +391,8 @@ class WorkflowAppGenerator(BaseAppGenerator): app_id=application_generate_entity.app_config.app_id, triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP, ) + draft_var_srv = WorkflowDraftVariableService(db.session) + draft_var_srv.prefill_conversation_variable_default_values(workflow) var_loader = DraftVarLoader( engine=db.engine, app_id=application_generate_entity.app_config.app_id, @@ -411,7 +415,7 @@ class WorkflowAppGenerator(BaseAppGenerator): application_generate_entity: WorkflowAppGenerateEntity, queue_manager: AppQueueManager, context: contextvars.Context, - variable_loader: DraftVarLoader, + variable_loader: VariableLoader, workflow_thread_pool_id: Optional[str] = None, ) -> None: """ diff --git a/api/services/workflow_draft_variable_service.py b/api/services/workflow_draft_variable_service.py index f1d3b1109a..3c3668fb36 100644 --- a/api/services/workflow_draft_variable_service.py +++ b/api/services/workflow_draft_variable_service.py @@ -2,6 +2,7 @@ import dataclasses import datetime import logging from collections.abc import Mapping, Sequence +from enum import StrEnum from typing import Any, ClassVar from sqlalchemy import Engine, orm @@ -235,12 +236,21 @@ class WorkflowDraftVariableService: return None return segment.value - def create_conversation_and_set_conversation_variables( + def get_or_create_conversation( self, account_id: str, app: App, workflow: Workflow, ) -> str: + """ + get_or_create_conversation creates and returns the ID of a conversation for debugging. + + If a conversation already exists, as determined by the following criteria, its ID is returned: + - The system variable `sys.conversation_id` exists in the draft variable table, and + - A corresponding conversation record is found in the database. + + If no such conversation exists, a new conversation is created and its ID is returned. + """ conv_id = self._get_conversation_id_from_draft_variable(workflow.app_id) if conv_id is not None: @@ -276,6 +286,10 @@ class WorkflowDraftVariableService: self._session.add(conversation) self._session.flush() + return conversation.id + + def prefill_conversation_variable_default_values(self, workflow: Workflow): + """""" draft_conv_vars: list[WorkflowDraftVariable] = [] for conv_var in workflow.conversation_variables: draft_var = WorkflowDraftVariable.new_conversation_variable( @@ -285,14 +299,25 @@ class WorkflowDraftVariableService: description=conv_var.description, ) draft_conv_vars.append(draft_var) + _batch_upsert_draft_varaible( + self._session, + draft_conv_vars, + policy=_UpsertPolicy.IGNORE, + ) - _batch_upsert_draft_varaible(self._session, draft_conv_vars) - return conversation.id +class _UpsertPolicy(StrEnum): + IGNORE = "ignore" + OVERWRITE = "overwrite" -def _batch_upsert_draft_varaible(session: Session, draft_vars: Sequence[WorkflowDraftVariable]): + +def _batch_upsert_draft_varaible( + session: Session, + draft_vars: Sequence[WorkflowDraftVariable], + policy: _UpsertPolicy = _UpsertPolicy.OVERWRITE, +) -> None: if not draft_vars: - return + return None # Although we could use SQLAlchemy ORM operations here, we choose not to for several reasons: # # 1. The variable saving process involves writing multiple rows to the @@ -313,18 +338,23 @@ def _batch_upsert_draft_varaible(session: Session, draft_vars: Sequence[Workflow # For these reasons, we use the SQLAlchemy query builder and rely on dialect-specific # insert operations instead of the ORM layer. stmt = insert(WorkflowDraftVariable).values([_model_to_insertion_dict(v) for v in draft_vars]) - stmt = stmt.on_conflict_do_update( - index_elements=WorkflowDraftVariable.unique_app_id_node_id_name(), - set_={ - "updated_at": stmt.excluded.updated_at, - "last_edited_at": stmt.excluded.last_edited_at, - "description": stmt.excluded.description, - "value_type": stmt.excluded.value_type, - "value": stmt.excluded.value, - "visible": stmt.excluded.visible, - "editable": stmt.excluded.editable, - }, - ) + if policy == _UpsertPolicy.OVERWRITE: + stmt = stmt.on_conflict_do_update( + index_elements=WorkflowDraftVariable.unique_app_id_node_id_name(), + set_={ + "updated_at": stmt.excluded.updated_at, + "last_edited_at": stmt.excluded.last_edited_at, + "description": stmt.excluded.description, + "value_type": stmt.excluded.value_type, + "value": stmt.excluded.value, + "visible": stmt.excluded.visible, + "editable": stmt.excluded.editable, + }, + ) + elif _UpsertPolicy.IGNORE: + stmt = stmt.on_conflict_do_nothing(index_elements=WorkflowDraftVariable.unique_app_id_node_id_name()) + else: + raise Exception("Invalid value for update policy.") session.execute(stmt) diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index e78d00f98c..8a71fcf3aa 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -329,8 +329,10 @@ class WorkflowService: # TODO(QuantumGhost): We may get rid of the `list_conversation_variables` # here, and rely on `DraftVarLoader` to load conversation variables. + with Session(bind=db.engine) as session: draft_var_srv = WorkflowDraftVariableService(session) + draft_var_srv.prefill_conversation_variable_default_values(draft_workflow) conv_vars_models = draft_var_srv.list_conversation_variables(app_id=app_model.id) conv_vars = [ @@ -343,7 +345,7 @@ class WorkflowService: if node_type == NodeType.START: with Session(bind=db.engine) as session, session.begin(): draft_var_srv = WorkflowDraftVariableService(session) - conversation_id = draft_var_srv.create_conversation_and_set_conversation_variables( + conversation_id = draft_var_srv.get_or_create_conversation( account_id=account.id, app=app_model, workflow=draft_workflow, @@ -369,7 +371,10 @@ class WorkflowService: conversation_variables=[], ) - variable_loader = DraftVarLoader(engine=db.engine, app_id=app_model.id) + variable_loader = DraftVarLoader( + engine=db.engine, + app_id=app_model.id, + ) run = WorkflowEntry.single_step_run( workflow=draft_workflow,