refactor(app_runner): Move DB queries out of the app runners and into the app generators

Signed-off-by: -LAN- <laipz8200@outlook.com>
pull/21739/head
-LAN- 7 months ago
parent 7edd48146b
commit a3997933f5
No known key found for this signature in database
GPG Key ID: 6BA0D108DED011FF

@ -7,7 +7,8 @@ from typing import Any, Literal, Optional, Union, overload
from flask import Flask, current_app from flask import Flask, current_app
from pydantic import ValidationError from pydantic import ValidationError
from sqlalchemy.orm import sessionmaker from sqlalchemy import select
from sqlalchemy.orm import Session, sessionmaker
import contexts import contexts
from configs import dify_config from configs import dify_config
@ -486,21 +487,53 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
""" """
with preserve_flask_contexts(flask_app, context_vars=context): with preserve_flask_contexts(flask_app, context_vars=context):
try: # get conversation and message
# get conversation and message conversation = self._get_conversation(conversation_id)
conversation = self._get_conversation(conversation_id) message = self._get_message(message_id)
message = self._get_message(message_id)
with Session(db.engine, expire_on_commit=False) as session:
# chatbot app workflow = session.scalar(
runner = AdvancedChatAppRunner( select(Workflow).where(
application_generate_entity=application_generate_entity, Workflow.tenant_id == application_generate_entity.app_config.tenant_id,
queue_manager=queue_manager, Workflow.app_id == application_generate_entity.app_config.app_id,
conversation=conversation, Workflow.id == application_generate_entity.app_config.workflow_id,
message=message, )
dialogue_count=self._dialogue_count,
variable_loader=variable_loader,
) )
if workflow is None:
raise ValueError("Workflow not found")
# Determine system_user_id based on invocation source
is_external_api_call = application_generate_entity.invoke_from in {
InvokeFrom.WEB_APP,
InvokeFrom.SERVICE_API,
}
if is_external_api_call:
# For external API calls, use end user's session ID
end_user = session.scalar(select(EndUser).where(EndUser.id == application_generate_entity.user_id))
system_user_id = end_user.session_id if end_user else ""
else:
# For internal calls, use the original user ID
system_user_id = application_generate_entity.user_id
app = session.scalar(select(App).where(App.id == application_generate_entity.app_config.app_id))
if app is None:
raise ValueError("App not found")
# chatbot app
runner = AdvancedChatAppRunner(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
conversation=conversation,
message=message,
dialogue_count=self._dialogue_count,
variable_loader=variable_loader,
workflow=workflow,
system_user_id=system_user_id,
app=app,
)
try:
runner.run() runner.run()
except GenerateTaskStoppedError: except GenerateTaskStoppedError:
pass pass

@ -29,8 +29,9 @@ from core.workflow.system_variable import SystemVariable
from core.workflow.variable_loader import VariableLoader from core.workflow.variable_loader import VariableLoader
from core.workflow.workflow_entry import WorkflowEntry from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db from extensions.ext_database import db
from models import Workflow
from models.enums import UserFrom from models.enums import UserFrom
from models.model import App, Conversation, EndUser, Message, MessageAnnotation from models.model import App, Conversation, Message, MessageAnnotation
from models.workflow import ConversationVariable, WorkflowType from models.workflow import ConversationVariable, WorkflowType
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -43,21 +44,29 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
def __init__( def __init__(
self, self,
*,
application_generate_entity: AdvancedChatAppGenerateEntity, application_generate_entity: AdvancedChatAppGenerateEntity,
queue_manager: AppQueueManager, queue_manager: AppQueueManager,
conversation: Conversation, conversation: Conversation,
message: Message, message: Message,
dialogue_count: int, dialogue_count: int,
variable_loader: VariableLoader, variable_loader: VariableLoader,
workflow: Workflow,
system_user_id: str,
app: App,
) -> None: ) -> None:
super().__init__(queue_manager, variable_loader) super().__init__(
queue_manager=queue_manager,
variable_loader=variable_loader,
app_id=application_generate_entity.app_config.app_id,
)
self.application_generate_entity = application_generate_entity self.application_generate_entity = application_generate_entity
self.conversation = conversation self.conversation = conversation
self.message = message self.message = message
self._dialogue_count = dialogue_count self._dialogue_count = dialogue_count
self._workflow = workflow
def _get_app_id(self) -> str: self.system_user_id = system_user_id
return self.application_generate_entity.app_config.app_id self._app = app
def run(self) -> None: def run(self) -> None:
app_config = self.application_generate_entity.app_config app_config = self.application_generate_entity.app_config
@ -86,14 +95,14 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
if self.application_generate_entity.single_iteration_run: if self.application_generate_entity.single_iteration_run:
# if only single iteration run is requested # if only single iteration run is requested
graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration( graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
workflow=workflow, workflow=self._workflow,
node_id=self.application_generate_entity.single_iteration_run.node_id, node_id=self.application_generate_entity.single_iteration_run.node_id,
user_inputs=dict(self.application_generate_entity.single_iteration_run.inputs), user_inputs=dict(self.application_generate_entity.single_iteration_run.inputs),
) )
elif self.application_generate_entity.single_loop_run: elif self.application_generate_entity.single_loop_run:
# if only single loop run is requested # if only single loop run is requested
graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop( graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop(
workflow=workflow, workflow=self._workflow,
node_id=self.application_generate_entity.single_loop_run.node_id, node_id=self.application_generate_entity.single_loop_run.node_id,
user_inputs=dict(self.application_generate_entity.single_loop_run.inputs), user_inputs=dict(self.application_generate_entity.single_loop_run.inputs),
) )
@ -104,7 +113,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
# moderation # moderation
if self.handle_input_moderation( if self.handle_input_moderation(
app_record=app_record, app_record=self._app,
app_generate_entity=self.application_generate_entity, app_generate_entity=self.application_generate_entity,
inputs=inputs, inputs=inputs,
query=query, query=query,
@ -114,7 +123,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
# annotation reply # annotation reply
if self.handle_annotation_reply( if self.handle_annotation_reply(
app_record=app_record, app_record=self._app,
message=self.message, message=self.message,
query=query, query=query,
app_generate_entity=self.application_generate_entity, app_generate_entity=self.application_generate_entity,
@ -134,7 +143,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
ConversationVariable.from_variable( ConversationVariable.from_variable(
app_id=self.conversation.app_id, conversation_id=self.conversation.id, variable=variable app_id=self.conversation.app_id, conversation_id=self.conversation.id, variable=variable
) )
for variable in workflow.conversation_variables for variable in self._workflow.conversation_variables
] ]
session.add_all(db_conversation_variables) session.add_all(db_conversation_variables)
# Convert database entities to variables. # Convert database entities to variables.
@ -147,7 +156,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
query=query, query=query,
files=files, files=files,
conversation_id=self.conversation.id, conversation_id=self.conversation.id,
user_id=user_id, user_id=self.system_user_id,
dialogue_count=self._dialogue_count, dialogue_count=self._dialogue_count,
app_id=app_config.app_id, app_id=app_config.app_id,
workflow_id=app_config.workflow_id, workflow_id=app_config.workflow_id,
@ -158,25 +167,25 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
variable_pool = VariablePool( variable_pool = VariablePool(
system_variables=system_inputs, system_variables=system_inputs,
user_inputs=inputs, user_inputs=inputs,
environment_variables=workflow.environment_variables, environment_variables=self._workflow.environment_variables,
# Based on the definition of `VariableUnion`, # Based on the definition of `VariableUnion`,
# `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible. # `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible.
conversation_variables=cast(list[VariableUnion], conversation_variables), conversation_variables=cast(list[VariableUnion], conversation_variables),
) )
# init graph # init graph
graph = self._init_graph(graph_config=workflow.graph_dict) graph = self._init_graph(graph_config=self._workflow.graph_dict)
db.session.close() db.session.close()
# RUN WORKFLOW # RUN WORKFLOW
workflow_entry = WorkflowEntry( workflow_entry = WorkflowEntry(
tenant_id=workflow.tenant_id, tenant_id=self._workflow.tenant_id,
app_id=workflow.app_id, app_id=self._workflow.app_id,
workflow_id=workflow.id, workflow_id=self._workflow.id,
workflow_type=WorkflowType.value_of(workflow.type), workflow_type=WorkflowType.value_of(self._workflow.type),
graph=graph, graph=graph,
graph_config=workflow.graph_dict, graph_config=self._workflow.graph_dict,
user_id=self.application_generate_entity.user_id, user_id=self.application_generate_entity.user_id,
user_from=( user_from=(
UserFrom.ACCOUNT UserFrom.ACCOUNT

@ -7,7 +7,8 @@ from typing import Any, Literal, Optional, Union, overload
from flask import Flask, current_app from flask import Flask, current_app
from pydantic import ValidationError from pydantic import ValidationError
from sqlalchemy.orm import sessionmaker from sqlalchemy import select
from sqlalchemy.orm import Session, sessionmaker
import contexts import contexts
from configs import dify_config from configs import dify_config
@ -445,15 +446,41 @@ class WorkflowAppGenerator(BaseAppGenerator):
""" """
with preserve_flask_contexts(flask_app, context_vars=context): with preserve_flask_contexts(flask_app, context_vars=context):
try: with Session(db.engine, expire_on_commit=False) as session:
# workflow app workflow = session.scalar(
runner = WorkflowAppRunner( select(Workflow).where(
application_generate_entity=application_generate_entity, Workflow.tenant_id == application_generate_entity.app_config.tenant_id,
queue_manager=queue_manager, Workflow.app_id == application_generate_entity.app_config.app_id,
workflow_thread_pool_id=workflow_thread_pool_id, Workflow.id == application_generate_entity.app_config.workflow_id,
variable_loader=variable_loader, )
) )
if workflow is None:
raise ValueError("Workflow not found")
# Determine system_user_id based on invocation source
is_external_api_call = application_generate_entity.invoke_from in {
InvokeFrom.WEB_APP,
InvokeFrom.SERVICE_API,
}
if is_external_api_call:
# For external API calls, use end user's session ID
end_user = session.scalar(select(EndUser).where(EndUser.id == application_generate_entity.user_id))
system_user_id = end_user.session_id if end_user else ""
else:
# For internal calls, use the original user ID
system_user_id = application_generate_entity.user_id
runner = WorkflowAppRunner(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
workflow_thread_pool_id=workflow_thread_pool_id,
variable_loader=variable_loader,
workflow=workflow,
system_user_id=system_user_id,
)
try:
runner.run() runner.run()
except GenerateTaskStoppedError: except GenerateTaskStoppedError:
pass pass
@ -471,8 +498,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
except Exception as e: except Exception as e:
logger.exception("Unknown Error when generating") logger.exception("Unknown Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
finally:
db.session.close()
def _handle_response( def _handle_response(
self, self,

@ -14,10 +14,8 @@ from core.workflow.entities.variable_pool import VariablePool
from core.workflow.system_variable import SystemVariable from core.workflow.system_variable import SystemVariable
from core.workflow.variable_loader import VariableLoader from core.workflow.variable_loader import VariableLoader
from core.workflow.workflow_entry import WorkflowEntry from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db
from models.enums import UserFrom from models.enums import UserFrom
from models.model import App, EndUser from models.workflow import Workflow, WorkflowType
from models.workflow import WorkflowType
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -29,22 +27,23 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
def __init__( def __init__(
self, self,
*,
application_generate_entity: WorkflowAppGenerateEntity, application_generate_entity: WorkflowAppGenerateEntity,
queue_manager: AppQueueManager, queue_manager: AppQueueManager,
variable_loader: VariableLoader, variable_loader: VariableLoader,
workflow_thread_pool_id: Optional[str] = None, workflow_thread_pool_id: Optional[str] = None,
workflow: Workflow,
system_user_id: str,
) -> None: ) -> None:
""" super().__init__(
:param application_generate_entity: application generate entity queue_manager=queue_manager,
:param queue_manager: application queue manager variable_loader=variable_loader,
:param workflow_thread_pool_id: workflow thread pool id app_id=application_generate_entity.app_config.app_id,
""" )
super().__init__(queue_manager, variable_loader)
self.application_generate_entity = application_generate_entity self.application_generate_entity = application_generate_entity
self.workflow_thread_pool_id = workflow_thread_pool_id self.workflow_thread_pool_id = workflow_thread_pool_id
self._workflow = workflow
def _get_app_id(self) -> str: self._sys_user_id = system_user_id
return self.application_generate_entity.app_config.app_id
def run(self) -> None: def run(self) -> None:
""" """
@ -53,24 +52,6 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
app_config = self.application_generate_entity.app_config app_config = self.application_generate_entity.app_config
app_config = cast(WorkflowAppConfig, app_config) app_config = cast(WorkflowAppConfig, app_config)
user_id = None
if self.application_generate_entity.invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}:
end_user = db.session.query(EndUser).filter(EndUser.id == self.application_generate_entity.user_id).first()
if end_user:
user_id = end_user.session_id
else:
user_id = self.application_generate_entity.user_id
app_record = db.session.query(App).filter(App.id == app_config.app_id).first()
if not app_record:
raise ValueError("App not found")
workflow = self.get_workflow(app_model=app_record, workflow_id=app_config.workflow_id)
if not workflow:
raise ValueError("Workflow not initialized")
db.session.close()
workflow_callbacks: list[WorkflowCallback] = [] workflow_callbacks: list[WorkflowCallback] = []
if dify_config.DEBUG: if dify_config.DEBUG:
workflow_callbacks.append(WorkflowLoggingCallback()) workflow_callbacks.append(WorkflowLoggingCallback())
@ -79,14 +60,14 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
if self.application_generate_entity.single_iteration_run: if self.application_generate_entity.single_iteration_run:
# if only single iteration run is requested # if only single iteration run is requested
graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration( graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
workflow=workflow, workflow=self._workflow,
node_id=self.application_generate_entity.single_iteration_run.node_id, node_id=self.application_generate_entity.single_iteration_run.node_id,
user_inputs=self.application_generate_entity.single_iteration_run.inputs, user_inputs=self.application_generate_entity.single_iteration_run.inputs,
) )
elif self.application_generate_entity.single_loop_run: elif self.application_generate_entity.single_loop_run:
# if only single loop run is requested # if only single loop run is requested
graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop( graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop(
workflow=workflow, workflow=self._workflow,
node_id=self.application_generate_entity.single_loop_run.node_id, node_id=self.application_generate_entity.single_loop_run.node_id,
user_inputs=self.application_generate_entity.single_loop_run.inputs, user_inputs=self.application_generate_entity.single_loop_run.inputs,
) )
@ -98,7 +79,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
system_inputs = SystemVariable( system_inputs = SystemVariable(
files=files, files=files,
user_id=user_id, user_id=self._sys_user_id,
app_id=app_config.app_id, app_id=app_config.app_id,
workflow_id=app_config.workflow_id, workflow_id=app_config.workflow_id,
workflow_execution_id=self.application_generate_entity.workflow_execution_id, workflow_execution_id=self.application_generate_entity.workflow_execution_id,
@ -107,21 +88,21 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
variable_pool = VariablePool( variable_pool = VariablePool(
system_variables=system_inputs, system_variables=system_inputs,
user_inputs=inputs, user_inputs=inputs,
environment_variables=workflow.environment_variables, environment_variables=self._workflow.environment_variables,
conversation_variables=[], conversation_variables=[],
) )
# init graph # init graph
graph = self._init_graph(graph_config=workflow.graph_dict) graph = self._init_graph(graph_config=self._workflow.graph_dict)
# RUN WORKFLOW # RUN WORKFLOW
workflow_entry = WorkflowEntry( workflow_entry = WorkflowEntry(
tenant_id=workflow.tenant_id, tenant_id=self._workflow.tenant_id,
app_id=workflow.app_id, app_id=self._workflow.app_id,
workflow_id=workflow.id, workflow_id=self._workflow.id,
workflow_type=WorkflowType.value_of(workflow.type), workflow_type=WorkflowType.value_of(self._workflow.type),
graph=graph, graph=graph,
graph_config=workflow.graph_dict, graph_config=self._workflow.graph_dict,
user_id=self.application_generate_entity.user_id, user_id=self.application_generate_entity.user_id,
user_from=( user_from=(
UserFrom.ACCOUNT UserFrom.ACCOUNT

@ -1,5 +1,7 @@
from collections.abc import Mapping from collections.abc import Mapping
from typing import Any, Optional, cast from typing import Any, cast
from sqlalchemy.orm import Session
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.entities.queue_entities import ( from core.app.entities.queue_entities import (
@ -65,17 +67,20 @@ from core.workflow.system_variable import SystemVariable
from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool
from core.workflow.workflow_entry import WorkflowEntry from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db from extensions.ext_database import db
from models.model import App
from models.workflow import Workflow from models.workflow import Workflow
class WorkflowBasedAppRunner: class WorkflowBasedAppRunner:
def __init__(self, queue_manager: AppQueueManager, variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER) -> None: def __init__(
self.queue_manager = queue_manager self,
*,
queue_manager: AppQueueManager,
variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
app_id: str,
) -> None:
self._queue_manager = queue_manager
self._variable_loader = variable_loader self._variable_loader = variable_loader
self._app_id = app_id
def _get_app_id(self) -> str:
raise NotImplementedError("not implemented")
def _init_graph(self, graph_config: Mapping[str, Any]) -> Graph: def _init_graph(self, graph_config: Mapping[str, Any]) -> Graph:
""" """
@ -692,21 +697,24 @@ class WorkflowBasedAppRunner:
) )
) )
def get_workflow(self, app_model: App, workflow_id: str) -> Optional[Workflow]:
"""
Get workflow
"""
# fetch workflow by workflow_id
workflow = (
db.session.query(Workflow)
.filter(
Workflow.tenant_id == app_model.tenant_id, Workflow.app_id == app_model.id, Workflow.id == workflow_id
)
.first()
)
# return workflow
return workflow
def _publish_event(self, event: AppQueueEvent) -> None: def _publish_event(self, event: AppQueueEvent) -> None:
self.queue_manager.publish(event, PublishFrom.APPLICATION_MANAGER) self._queue_manager.publish(event, PublishFrom.APPLICATION_MANAGER)
def _save_draft_var_for_event(self, event: BaseNodeEvent):
run_result = event.route_node_state.node_run_result
if run_result is None:
return
process_data = run_result.process_data
outputs = run_result.outputs
with Session(bind=db.engine) as session, session.begin():
draft_var_saver = DraftVariableSaver(
session=session,
app_id=self._app_id,
node_id=event.node_id,
node_type=event.node_type,
# FIXME(QuantumGhost): relying on the private state of queue_manager is not ideal.
invoke_from=self._queue_manager._invoke_from,
node_execution_id=event.id,
enclosing_node_id=event.in_loop_id or event.in_iteration_id or None,
)
draft_var_saver.save(process_data=process_data, outputs=outputs)

Loading…
Cancel
Save