feat(api): Prefill conversation variables in draft workflow and update related services

pull/20699/head
QuantumGhost 12 months ago
parent 43cae8fc0b
commit 72561300ec

@ -263,6 +263,14 @@ class ConversationVariableCollectionApi(Resource):
@_api_prerequisite @_api_prerequisite
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS) @marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
def get(self, app_model: App): def get(self, app_model: App):
# NOTE(QuantumGhost): Prefill conversation variables into the draft variables table
# so their IDs can be returned to the caller.
workflow_srv = WorkflowService()
draft_workflow = workflow_srv.get_draft_workflow(app_model)
if draft_workflow is None:
raise NotFoundError(description=f"draft workflow not found, id={app_model.id}")
draft_var_srv = WorkflowDraftVariableService(db.session)
draft_var_srv.prefill_conversation_variable_default_values(draft_workflow)
return _get_variable_list(app_model, CONVERSATION_VARIABLE_NODE_ID) return _get_variable_list(app_model, CONVERSATION_VARIABLE_NODE_ID)

@ -36,7 +36,7 @@ from models import Account, App, Conversation, EndUser, Message, Workflow, Workf
from models.enums import WorkflowRunTriggeredFrom from models.enums import WorkflowRunTriggeredFrom
from services.conversation_service import ConversationService from services.conversation_service import ConversationService
from services.errors.message import MessageNotExistsError from services.errors.message import MessageNotExistsError
from services.workflow_draft_variable_service import DraftVarLoader from services.workflow_draft_variable_service import DraftVarLoader, WorkflowDraftVariableService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -266,6 +266,8 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
engine=db.engine, engine=db.engine,
app_id=application_generate_entity.app_config.app_id, app_id=application_generate_entity.app_config.app_id,
) )
draft_var_srv = WorkflowDraftVariableService(db.session)
draft_var_srv.prefill_conversation_variable_default_values(workflow)
return self._generate( return self._generate(
workflow=workflow, workflow=workflow,
@ -346,6 +348,8 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
engine=db.engine, engine=db.engine,
app_id=application_generate_entity.app_config.app_id, app_id=application_generate_entity.app_config.app_id,
) )
draft_var_srv = WorkflowDraftVariableService(db.session)
draft_var_srv.prefill_conversation_variable_default_values(workflow)
return self._generate( return self._generate(
workflow=workflow, workflow=workflow,

@ -27,12 +27,12 @@ from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository
from core.workflow.repository.workflow_execution_repository import WorkflowExecutionRepository from core.workflow.repository.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader
from extensions.ext_database import db from extensions.ext_database import db
from factories import file_factory from factories import file_factory
from models import Account, App, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom from models import Account, App, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom
from models.enums import WorkflowRunTriggeredFrom from models.enums import WorkflowRunTriggeredFrom
from services.workflow_draft_variable_service import DraftVarLoader from services.workflow_draft_variable_service import DraftVarLoader, WorkflowDraftVariableService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -187,7 +187,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
workflow_node_execution_repository: WorkflowNodeExecutionRepository, workflow_node_execution_repository: WorkflowNodeExecutionRepository,
streaming: bool = True, streaming: bool = True,
workflow_thread_pool_id: Optional[str] = None, workflow_thread_pool_id: Optional[str] = None,
variable_loader: DraftVarLoader = DUMMY_VARIABLE_LOADER, variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]: ) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]:
""" """
Generate App response. Generate App response.
@ -308,6 +308,8 @@ class WorkflowAppGenerator(BaseAppGenerator):
app_id=application_generate_entity.app_config.app_id, app_id=application_generate_entity.app_config.app_id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP, triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
) )
draft_var_srv = WorkflowDraftVariableService(db.session)
draft_var_srv.prefill_conversation_variable_default_values(workflow)
var_loader = DraftVarLoader( var_loader = DraftVarLoader(
engine=db.engine, engine=db.engine,
app_id=application_generate_entity.app_config.app_id, app_id=application_generate_entity.app_config.app_id,
@ -389,6 +391,8 @@ class WorkflowAppGenerator(BaseAppGenerator):
app_id=application_generate_entity.app_config.app_id, app_id=application_generate_entity.app_config.app_id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP, triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
) )
draft_var_srv = WorkflowDraftVariableService(db.session)
draft_var_srv.prefill_conversation_variable_default_values(workflow)
var_loader = DraftVarLoader( var_loader = DraftVarLoader(
engine=db.engine, engine=db.engine,
app_id=application_generate_entity.app_config.app_id, app_id=application_generate_entity.app_config.app_id,
@ -411,7 +415,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
application_generate_entity: WorkflowAppGenerateEntity, application_generate_entity: WorkflowAppGenerateEntity,
queue_manager: AppQueueManager, queue_manager: AppQueueManager,
context: contextvars.Context, context: contextvars.Context,
variable_loader: DraftVarLoader, variable_loader: VariableLoader,
workflow_thread_pool_id: Optional[str] = None, workflow_thread_pool_id: Optional[str] = None,
) -> None: ) -> None:
""" """

@ -2,6 +2,7 @@ import dataclasses
import datetime import datetime
import logging import logging
from collections.abc import Mapping, Sequence from collections.abc import Mapping, Sequence
from enum import StrEnum
from typing import Any, ClassVar from typing import Any, ClassVar
from sqlalchemy import Engine, orm from sqlalchemy import Engine, orm
@ -235,12 +236,21 @@ class WorkflowDraftVariableService:
return None return None
return segment.value return segment.value
def create_conversation_and_set_conversation_variables( def get_or_create_conversation(
self, self,
account_id: str, account_id: str,
app: App, app: App,
workflow: Workflow, workflow: Workflow,
) -> str: ) -> str:
"""
get_or_create_conversation creates and returns the ID of a conversation for debugging.
If a conversation already exists, as determined by the following criteria, its ID is returned:
- The system variable `sys.conversation_id` exists in the draft variable table, and
- A corresponding conversation record is found in the database.
If no such conversation exists, a new conversation is created and its ID is returned.
"""
conv_id = self._get_conversation_id_from_draft_variable(workflow.app_id) conv_id = self._get_conversation_id_from_draft_variable(workflow.app_id)
if conv_id is not None: if conv_id is not None:
@ -276,6 +286,10 @@ class WorkflowDraftVariableService:
self._session.add(conversation) self._session.add(conversation)
self._session.flush() self._session.flush()
return conversation.id
def prefill_conversation_variable_default_values(self, workflow: Workflow):
""""""
draft_conv_vars: list[WorkflowDraftVariable] = [] draft_conv_vars: list[WorkflowDraftVariable] = []
for conv_var in workflow.conversation_variables: for conv_var in workflow.conversation_variables:
draft_var = WorkflowDraftVariable.new_conversation_variable( draft_var = WorkflowDraftVariable.new_conversation_variable(
@ -285,14 +299,25 @@ class WorkflowDraftVariableService:
description=conv_var.description, description=conv_var.description,
) )
draft_conv_vars.append(draft_var) draft_conv_vars.append(draft_var)
_batch_upsert_draft_varaible(
self._session,
draft_conv_vars,
policy=_UpsertPolicy.IGNORE,
)
_batch_upsert_draft_varaible(self._session, draft_conv_vars)
return conversation.id
class _UpsertPolicy(StrEnum):
IGNORE = "ignore"
OVERWRITE = "overwrite"
def _batch_upsert_draft_varaible(session: Session, draft_vars: Sequence[WorkflowDraftVariable]):
def _batch_upsert_draft_varaible(
session: Session,
draft_vars: Sequence[WorkflowDraftVariable],
policy: _UpsertPolicy = _UpsertPolicy.OVERWRITE,
) -> None:
if not draft_vars: if not draft_vars:
return return None
# Although we could use SQLAlchemy ORM operations here, we choose not to for several reasons: # Although we could use SQLAlchemy ORM operations here, we choose not to for several reasons:
# #
# 1. The variable saving process involves writing multiple rows to the # 1. The variable saving process involves writing multiple rows to the
@ -313,6 +338,7 @@ def _batch_upsert_draft_varaible(session: Session, draft_vars: Sequence[Workflow
# For these reasons, we use the SQLAlchemy query builder and rely on dialect-specific # For these reasons, we use the SQLAlchemy query builder and rely on dialect-specific
# insert operations instead of the ORM layer. # insert operations instead of the ORM layer.
stmt = insert(WorkflowDraftVariable).values([_model_to_insertion_dict(v) for v in draft_vars]) stmt = insert(WorkflowDraftVariable).values([_model_to_insertion_dict(v) for v in draft_vars])
if policy == _UpsertPolicy.OVERWRITE:
stmt = stmt.on_conflict_do_update( stmt = stmt.on_conflict_do_update(
index_elements=WorkflowDraftVariable.unique_app_id_node_id_name(), index_elements=WorkflowDraftVariable.unique_app_id_node_id_name(),
set_={ set_={
@ -325,6 +351,10 @@ def _batch_upsert_draft_varaible(session: Session, draft_vars: Sequence[Workflow
"editable": stmt.excluded.editable, "editable": stmt.excluded.editable,
}, },
) )
elif _UpsertPolicy.IGNORE:
stmt = stmt.on_conflict_do_nothing(index_elements=WorkflowDraftVariable.unique_app_id_node_id_name())
else:
raise Exception("Invalid value for update policy.")
session.execute(stmt) session.execute(stmt)

@ -329,8 +329,10 @@ class WorkflowService:
# TODO(QuantumGhost): We may get rid of the `list_conversation_variables` # TODO(QuantumGhost): We may get rid of the `list_conversation_variables`
# here, and rely on `DraftVarLoader` to load conversation variables. # here, and rely on `DraftVarLoader` to load conversation variables.
with Session(bind=db.engine) as session: with Session(bind=db.engine) as session:
draft_var_srv = WorkflowDraftVariableService(session) draft_var_srv = WorkflowDraftVariableService(session)
draft_var_srv.prefill_conversation_variable_default_values(draft_workflow)
conv_vars_models = draft_var_srv.list_conversation_variables(app_id=app_model.id) conv_vars_models = draft_var_srv.list_conversation_variables(app_id=app_model.id)
conv_vars = [ conv_vars = [
@ -343,7 +345,7 @@ class WorkflowService:
if node_type == NodeType.START: if node_type == NodeType.START:
with Session(bind=db.engine) as session, session.begin(): with Session(bind=db.engine) as session, session.begin():
draft_var_srv = WorkflowDraftVariableService(session) draft_var_srv = WorkflowDraftVariableService(session)
conversation_id = draft_var_srv.create_conversation_and_set_conversation_variables( conversation_id = draft_var_srv.get_or_create_conversation(
account_id=account.id, account_id=account.id,
app=app_model, app=app_model,
workflow=draft_workflow, workflow=draft_workflow,
@ -369,7 +371,10 @@ class WorkflowService:
conversation_variables=[], conversation_variables=[],
) )
variable_loader = DraftVarLoader(engine=db.engine, app_id=app_model.id) variable_loader = DraftVarLoader(
engine=db.engine,
app_id=app_model.id,
)
run = WorkflowEntry.single_step_run( run = WorkflowEntry.single_step_run(
workflow=draft_workflow, workflow=draft_workflow,

Loading…
Cancel
Save