feat(api): Prefill conversation variables in draft workflow and update related services

pull/20699/head
QuantumGhost 12 months ago
parent 43cae8fc0b
commit 72561300ec

@ -263,6 +263,14 @@ class ConversationVariableCollectionApi(Resource):
@_api_prerequisite
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
def get(self, app_model: App):
# NOTE(QuantumGhost): Prefill conversation variables into the draft variables table
# so their IDs can be returned to the caller.
workflow_srv = WorkflowService()
draft_workflow = workflow_srv.get_draft_workflow(app_model)
if draft_workflow is None:
raise NotFoundError(description=f"draft workflow not found, id={app_model.id}")
draft_var_srv = WorkflowDraftVariableService(db.session)
draft_var_srv.prefill_conversation_variable_default_values(draft_workflow)
return _get_variable_list(app_model, CONVERSATION_VARIABLE_NODE_ID)

@ -36,7 +36,7 @@ from models import Account, App, Conversation, EndUser, Message, Workflow, Workf
from models.enums import WorkflowRunTriggeredFrom
from services.conversation_service import ConversationService
from services.errors.message import MessageNotExistsError
from services.workflow_draft_variable_service import DraftVarLoader
from services.workflow_draft_variable_service import DraftVarLoader, WorkflowDraftVariableService
logger = logging.getLogger(__name__)
@ -266,6 +266,8 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
engine=db.engine,
app_id=application_generate_entity.app_config.app_id,
)
draft_var_srv = WorkflowDraftVariableService(db.session)
draft_var_srv.prefill_conversation_variable_default_values(workflow)
return self._generate(
workflow=workflow,
@ -346,6 +348,8 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
engine=db.engine,
app_id=application_generate_entity.app_config.app_id,
)
draft_var_srv = WorkflowDraftVariableService(db.session)
draft_var_srv.prefill_conversation_variable_default_values(workflow)
return self._generate(
workflow=workflow,

@ -27,12 +27,12 @@ from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository
from core.workflow.repository.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER
from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader
from extensions.ext_database import db
from factories import file_factory
from models import Account, App, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom
from models.enums import WorkflowRunTriggeredFrom
from services.workflow_draft_variable_service import DraftVarLoader
from services.workflow_draft_variable_service import DraftVarLoader, WorkflowDraftVariableService
logger = logging.getLogger(__name__)
@ -187,7 +187,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
streaming: bool = True,
workflow_thread_pool_id: Optional[str] = None,
variable_loader: DraftVarLoader = DUMMY_VARIABLE_LOADER,
variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]:
"""
Generate App response.
@ -308,6 +308,8 @@ class WorkflowAppGenerator(BaseAppGenerator):
app_id=application_generate_entity.app_config.app_id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
)
draft_var_srv = WorkflowDraftVariableService(db.session)
draft_var_srv.prefill_conversation_variable_default_values(workflow)
var_loader = DraftVarLoader(
engine=db.engine,
app_id=application_generate_entity.app_config.app_id,
@ -389,6 +391,8 @@ class WorkflowAppGenerator(BaseAppGenerator):
app_id=application_generate_entity.app_config.app_id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
)
draft_var_srv = WorkflowDraftVariableService(db.session)
draft_var_srv.prefill_conversation_variable_default_values(workflow)
var_loader = DraftVarLoader(
engine=db.engine,
app_id=application_generate_entity.app_config.app_id,
@ -411,7 +415,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
application_generate_entity: WorkflowAppGenerateEntity,
queue_manager: AppQueueManager,
context: contextvars.Context,
variable_loader: DraftVarLoader,
variable_loader: VariableLoader,
workflow_thread_pool_id: Optional[str] = None,
) -> None:
"""

@ -2,6 +2,7 @@ import dataclasses
import datetime
import logging
from collections.abc import Mapping, Sequence
from enum import StrEnum
from typing import Any, ClassVar
from sqlalchemy import Engine, orm
@ -235,12 +236,21 @@ class WorkflowDraftVariableService:
return None
return segment.value
def create_conversation_and_set_conversation_variables(
def get_or_create_conversation(
self,
account_id: str,
app: App,
workflow: Workflow,
) -> str:
"""
get_or_create_conversation creates and returns the ID of a conversation for debugging.
If a conversation already exists, as determined by the following criteria, its ID is returned:
- The system variable `sys.conversation_id` exists in the draft variable table, and
- A corresponding conversation record is found in the database.
If no such conversation exists, a new conversation is created and its ID is returned.
"""
conv_id = self._get_conversation_id_from_draft_variable(workflow.app_id)
if conv_id is not None:
@ -276,6 +286,10 @@ class WorkflowDraftVariableService:
self._session.add(conversation)
self._session.flush()
return conversation.id
def prefill_conversation_variable_default_values(self, workflow: Workflow):
""""""
draft_conv_vars: list[WorkflowDraftVariable] = []
for conv_var in workflow.conversation_variables:
draft_var = WorkflowDraftVariable.new_conversation_variable(
@ -285,14 +299,25 @@ class WorkflowDraftVariableService:
description=conv_var.description,
)
draft_conv_vars.append(draft_var)
_batch_upsert_draft_varaible(
self._session,
draft_conv_vars,
policy=_UpsertPolicy.IGNORE,
)
_batch_upsert_draft_varaible(self._session, draft_conv_vars)
return conversation.id
class _UpsertPolicy(StrEnum):
IGNORE = "ignore"
OVERWRITE = "overwrite"
def _batch_upsert_draft_varaible(session: Session, draft_vars: Sequence[WorkflowDraftVariable]):
def _batch_upsert_draft_varaible(
session: Session,
draft_vars: Sequence[WorkflowDraftVariable],
policy: _UpsertPolicy = _UpsertPolicy.OVERWRITE,
) -> None:
if not draft_vars:
return
return None
# Although we could use SQLAlchemy ORM operations here, we choose not to for several reasons:
#
# 1. The variable saving process involves writing multiple rows to the
@ -313,18 +338,23 @@ def _batch_upsert_draft_varaible(session: Session, draft_vars: Sequence[Workflow
# For these reasons, we use the SQLAlchemy query builder and rely on dialect-specific
# insert operations instead of the ORM layer.
stmt = insert(WorkflowDraftVariable).values([_model_to_insertion_dict(v) for v in draft_vars])
stmt = stmt.on_conflict_do_update(
index_elements=WorkflowDraftVariable.unique_app_id_node_id_name(),
set_={
"updated_at": stmt.excluded.updated_at,
"last_edited_at": stmt.excluded.last_edited_at,
"description": stmt.excluded.description,
"value_type": stmt.excluded.value_type,
"value": stmt.excluded.value,
"visible": stmt.excluded.visible,
"editable": stmt.excluded.editable,
},
)
if policy == _UpsertPolicy.OVERWRITE:
stmt = stmt.on_conflict_do_update(
index_elements=WorkflowDraftVariable.unique_app_id_node_id_name(),
set_={
"updated_at": stmt.excluded.updated_at,
"last_edited_at": stmt.excluded.last_edited_at,
"description": stmt.excluded.description,
"value_type": stmt.excluded.value_type,
"value": stmt.excluded.value,
"visible": stmt.excluded.visible,
"editable": stmt.excluded.editable,
},
)
elif _UpsertPolicy.IGNORE:
stmt = stmt.on_conflict_do_nothing(index_elements=WorkflowDraftVariable.unique_app_id_node_id_name())
else:
raise Exception("Invalid value for update policy.")
session.execute(stmt)

@ -329,8 +329,10 @@ class WorkflowService:
# TODO(QuantumGhost): We may get rid of the `list_conversation_variables`
# here, and rely on `DraftVarLoader` to load conversation variables.
with Session(bind=db.engine) as session:
draft_var_srv = WorkflowDraftVariableService(session)
draft_var_srv.prefill_conversation_variable_default_values(draft_workflow)
conv_vars_models = draft_var_srv.list_conversation_variables(app_id=app_model.id)
conv_vars = [
@ -343,7 +345,7 @@ class WorkflowService:
if node_type == NodeType.START:
with Session(bind=db.engine) as session, session.begin():
draft_var_srv = WorkflowDraftVariableService(session)
conversation_id = draft_var_srv.create_conversation_and_set_conversation_variables(
conversation_id = draft_var_srv.get_or_create_conversation(
account_id=account.id,
app=app_model,
workflow=draft_workflow,
@ -369,7 +371,10 @@ class WorkflowService:
conversation_variables=[],
)
variable_loader = DraftVarLoader(engine=db.engine, app_id=app_model.id)
variable_loader = DraftVarLoader(
engine=db.engine,
app_id=app_model.id,
)
run = WorkflowEntry.single_step_run(
workflow=draft_workflow,

Loading…
Cancel
Save