diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py index fc6556dfb5..f4228fa704 100644 --- a/api/core/app/apps/advanced_chat/app_generator.py +++ b/api/core/app/apps/advanced_chat/app_generator.py @@ -7,7 +7,8 @@ from typing import Any, Literal, Optional, Union, overload from flask import Flask, current_app from pydantic import ValidationError -from sqlalchemy.orm import sessionmaker +from sqlalchemy import select +from sqlalchemy.orm import Session, sessionmaker import contexts from configs import dify_config @@ -486,21 +487,53 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): """ with preserve_flask_contexts(flask_app, context_vars=context): - try: - # get conversation and message - conversation = self._get_conversation(conversation_id) - message = self._get_message(message_id) - - # chatbot app - runner = AdvancedChatAppRunner( - application_generate_entity=application_generate_entity, - queue_manager=queue_manager, - conversation=conversation, - message=message, - dialogue_count=self._dialogue_count, - variable_loader=variable_loader, + # get conversation and message + conversation = self._get_conversation(conversation_id) + message = self._get_message(message_id) + + with Session(db.engine, expire_on_commit=False) as session: + workflow = session.scalar( + select(Workflow).where( + Workflow.tenant_id == application_generate_entity.app_config.tenant_id, + Workflow.app_id == application_generate_entity.app_config.app_id, + Workflow.id == application_generate_entity.app_config.workflow_id, + ) ) + if workflow is None: + raise ValueError("Workflow not found") + + # Determine system_user_id based on invocation source + is_external_api_call = application_generate_entity.invoke_from in { + InvokeFrom.WEB_APP, + InvokeFrom.SERVICE_API, + } + + if is_external_api_call: + # For external API calls, use end user's session ID + end_user = session.scalar(select(EndUser).where(EndUser.id == application_generate_entity.user_id)) + system_user_id = end_user.session_id if end_user else "" + else: + # For internal calls, use the original user ID + system_user_id = application_generate_entity.user_id + + app = session.scalar(select(App).where(App.id == application_generate_entity.app_config.app_id)) + if app is None: + raise ValueError("App not found") + + # chatbot app + runner = AdvancedChatAppRunner( + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + conversation=conversation, + message=message, + dialogue_count=self._dialogue_count, + variable_loader=variable_loader, + workflow=workflow, + system_user_id=system_user_id, + app=app, + ) + try: runner.run() except GenerateTaskStoppedError: pass diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py index 9a1c87f8b5..f6f06429d8 100644 --- a/api/core/app/apps/advanced_chat/app_runner.py +++ b/api/core/app/apps/advanced_chat/app_runner.py @@ -29,8 +29,9 @@ from core.workflow.system_variable import SystemVariable from core.workflow.variable_loader import VariableLoader from core.workflow.workflow_entry import WorkflowEntry from extensions.ext_database import db +from models import Workflow from models.enums import UserFrom -from models.model import App, Conversation, EndUser, Message, MessageAnnotation +from models.model import App, Conversation, Message, MessageAnnotation from models.workflow import ConversationVariable, WorkflowType logger = logging.getLogger(__name__) @@ -43,21 +44,29 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): def __init__( self, + *, application_generate_entity: AdvancedChatAppGenerateEntity, queue_manager: AppQueueManager, conversation: Conversation, message: Message, dialogue_count: int, variable_loader: VariableLoader, + workflow: Workflow, + system_user_id: str, + app: App, ) -> None: - super().__init__(queue_manager, variable_loader) + super().__init__( + queue_manager=queue_manager, + variable_loader=variable_loader, + app_id=application_generate_entity.app_config.app_id, + ) self.application_generate_entity = application_generate_entity self.conversation = conversation self.message = message self._dialogue_count = dialogue_count - - def _get_app_id(self) -> str: - return self.application_generate_entity.app_config.app_id + self._workflow = workflow + self.system_user_id = system_user_id + self._app = app def run(self) -> None: app_config = self.application_generate_entity.app_config @@ -86,14 +95,14 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): if self.application_generate_entity.single_iteration_run: # if only single iteration run is requested graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration( - workflow=workflow, + workflow=self._workflow, node_id=self.application_generate_entity.single_iteration_run.node_id, user_inputs=dict(self.application_generate_entity.single_iteration_run.inputs), ) elif self.application_generate_entity.single_loop_run: # if only single loop run is requested graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop( - workflow=workflow, + workflow=self._workflow, node_id=self.application_generate_entity.single_loop_run.node_id, user_inputs=dict(self.application_generate_entity.single_loop_run.inputs), ) @@ -104,7 +113,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): # moderation if self.handle_input_moderation( - app_record=app_record, + app_record=self._app, app_generate_entity=self.application_generate_entity, inputs=inputs, query=query, @@ -114,7 +123,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): # annotation reply if self.handle_annotation_reply( - app_record=app_record, + app_record=self._app, message=self.message, query=query, app_generate_entity=self.application_generate_entity, @@ -134,7 +143,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): ConversationVariable.from_variable( app_id=self.conversation.app_id, conversation_id=self.conversation.id, variable=variable ) - for variable in workflow.conversation_variables + for variable in self._workflow.conversation_variables ] session.add_all(db_conversation_variables) # Convert database entities to variables. @@ -147,7 +156,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): query=query, files=files, conversation_id=self.conversation.id, - user_id=user_id, + user_id=self.system_user_id, dialogue_count=self._dialogue_count, app_id=app_config.app_id, workflow_id=app_config.workflow_id, @@ -158,25 +167,25 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): variable_pool = VariablePool( system_variables=system_inputs, user_inputs=inputs, - environment_variables=workflow.environment_variables, + environment_variables=self._workflow.environment_variables, # Based on the definition of `VariableUnion`, # `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible. conversation_variables=cast(list[VariableUnion], conversation_variables), ) # init graph - graph = self._init_graph(graph_config=workflow.graph_dict) + graph = self._init_graph(graph_config=self._workflow.graph_dict) db.session.close() # RUN WORKFLOW workflow_entry = WorkflowEntry( - tenant_id=workflow.tenant_id, - app_id=workflow.app_id, - workflow_id=workflow.id, - workflow_type=WorkflowType.value_of(workflow.type), + tenant_id=self._workflow.tenant_id, + app_id=self._workflow.app_id, + workflow_id=self._workflow.id, + workflow_type=WorkflowType.value_of(self._workflow.type), graph=graph, - graph_config=workflow.graph_dict, + graph_config=self._workflow.graph_dict, user_id=self.application_generate_entity.user_id, user_from=( UserFrom.ACCOUNT diff --git a/api/core/app/apps/workflow/app_generator.py b/api/core/app/apps/workflow/app_generator.py index eeca9bb503..086ab25a9d 100644 --- a/api/core/app/apps/workflow/app_generator.py +++ b/api/core/app/apps/workflow/app_generator.py @@ -7,7 +7,8 @@ from typing import Any, Literal, Optional, Union, overload from flask import Flask, current_app from pydantic import ValidationError -from sqlalchemy.orm import sessionmaker +from sqlalchemy import select +from sqlalchemy.orm import Session, sessionmaker import contexts from configs import dify_config @@ -445,15 +446,41 @@ class WorkflowAppGenerator(BaseAppGenerator): """ with preserve_flask_contexts(flask_app, context_vars=context): - try: - # workflow app - runner = WorkflowAppRunner( - application_generate_entity=application_generate_entity, - queue_manager=queue_manager, - workflow_thread_pool_id=workflow_thread_pool_id, - variable_loader=variable_loader, + with Session(db.engine, expire_on_commit=False) as session: + workflow = session.scalar( + select(Workflow).where( + Workflow.tenant_id == application_generate_entity.app_config.tenant_id, + Workflow.app_id == application_generate_entity.app_config.app_id, + Workflow.id == application_generate_entity.app_config.workflow_id, + ) ) + if workflow is None: + raise ValueError("Workflow not found") + + # Determine system_user_id based on invocation source + is_external_api_call = application_generate_entity.invoke_from in { + InvokeFrom.WEB_APP, + InvokeFrom.SERVICE_API, + } + + if is_external_api_call: + # For external API calls, use end user's session ID + end_user = session.scalar(select(EndUser).where(EndUser.id == application_generate_entity.user_id)) + system_user_id = end_user.session_id if end_user else "" + else: + # For internal calls, use the original user ID + system_user_id = application_generate_entity.user_id + + runner = WorkflowAppRunner( + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + workflow_thread_pool_id=workflow_thread_pool_id, + variable_loader=variable_loader, + workflow=workflow, + system_user_id=system_user_id, + ) + try: runner.run() except GenerateTaskStoppedError: pass @@ -471,8 +498,6 @@ class WorkflowAppGenerator(BaseAppGenerator): except Exception as e: logger.exception("Unknown Error when generating") queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) - finally: - db.session.close() def _handle_response( self, diff --git a/api/core/app/apps/workflow/app_runner.py b/api/core/app/apps/workflow/app_runner.py index 3a66ffa578..4f4c1460ae 100644 --- a/api/core/app/apps/workflow/app_runner.py +++ b/api/core/app/apps/workflow/app_runner.py @@ -14,10 +14,8 @@ from core.workflow.entities.variable_pool import VariablePool from core.workflow.system_variable import SystemVariable from core.workflow.variable_loader import VariableLoader from core.workflow.workflow_entry import WorkflowEntry -from extensions.ext_database import db from models.enums import UserFrom -from models.model import App, EndUser -from models.workflow import WorkflowType +from models.workflow import Workflow, WorkflowType logger = logging.getLogger(__name__) @@ -29,22 +27,23 @@ class WorkflowAppRunner(WorkflowBasedAppRunner): def __init__( self, + *, application_generate_entity: WorkflowAppGenerateEntity, queue_manager: AppQueueManager, variable_loader: VariableLoader, workflow_thread_pool_id: Optional[str] = None, + workflow: Workflow, + system_user_id: str, ) -> None: - """ - :param application_generate_entity: application generate entity - :param queue_manager: application queue manager - :param workflow_thread_pool_id: workflow thread pool id - """ - super().__init__(queue_manager, variable_loader) + super().__init__( + queue_manager=queue_manager, + variable_loader=variable_loader, + app_id=application_generate_entity.app_config.app_id, + ) self.application_generate_entity = application_generate_entity self.workflow_thread_pool_id = workflow_thread_pool_id - - def _get_app_id(self) -> str: - return self.application_generate_entity.app_config.app_id + self._workflow = workflow + self._sys_user_id = system_user_id def run(self) -> None: """ @@ -53,24 +52,6 @@ class WorkflowAppRunner(WorkflowBasedAppRunner): app_config = self.application_generate_entity.app_config app_config = cast(WorkflowAppConfig, app_config) - user_id = None - if self.application_generate_entity.invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}: - end_user = db.session.query(EndUser).filter(EndUser.id == self.application_generate_entity.user_id).first() - if end_user: - user_id = end_user.session_id - else: - user_id = self.application_generate_entity.user_id - - app_record = db.session.query(App).filter(App.id == app_config.app_id).first() - if not app_record: - raise ValueError("App not found") - - workflow = self.get_workflow(app_model=app_record, workflow_id=app_config.workflow_id) - if not workflow: - raise ValueError("Workflow not initialized") - - db.session.close() - workflow_callbacks: list[WorkflowCallback] = [] if dify_config.DEBUG: workflow_callbacks.append(WorkflowLoggingCallback()) @@ -79,14 +60,14 @@ class WorkflowAppRunner(WorkflowBasedAppRunner): if self.application_generate_entity.single_iteration_run: # if only single iteration run is requested graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration( - workflow=workflow, + workflow=self._workflow, node_id=self.application_generate_entity.single_iteration_run.node_id, user_inputs=self.application_generate_entity.single_iteration_run.inputs, ) elif self.application_generate_entity.single_loop_run: # if only single loop run is requested graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop( - workflow=workflow, + workflow=self._workflow, node_id=self.application_generate_entity.single_loop_run.node_id, user_inputs=self.application_generate_entity.single_loop_run.inputs, ) @@ -98,7 +79,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner): system_inputs = SystemVariable( files=files, - user_id=user_id, + user_id=self._sys_user_id, app_id=app_config.app_id, workflow_id=app_config.workflow_id, workflow_execution_id=self.application_generate_entity.workflow_execution_id, @@ -107,21 +88,21 @@ class WorkflowAppRunner(WorkflowBasedAppRunner): variable_pool = VariablePool( system_variables=system_inputs, user_inputs=inputs, - environment_variables=workflow.environment_variables, + environment_variables=self._workflow.environment_variables, conversation_variables=[], ) # init graph - graph = self._init_graph(graph_config=workflow.graph_dict) + graph = self._init_graph(graph_config=self._workflow.graph_dict) # RUN WORKFLOW workflow_entry = WorkflowEntry( - tenant_id=workflow.tenant_id, - app_id=workflow.app_id, - workflow_id=workflow.id, - workflow_type=WorkflowType.value_of(workflow.type), + tenant_id=self._workflow.tenant_id, + app_id=self._workflow.app_id, + workflow_id=self._workflow.id, + workflow_type=WorkflowType.value_of(self._workflow.type), graph=graph, - graph_config=workflow.graph_dict, + graph_config=self._workflow.graph_dict, user_id=self.application_generate_entity.user_id, user_from=( UserFrom.ACCOUNT diff --git a/api/core/app/apps/workflow_app_runner.py b/api/core/app/apps/workflow_app_runner.py index db8c335f62..ec63aa1f5c 100644 --- a/api/core/app/apps/workflow_app_runner.py +++ b/api/core/app/apps/workflow_app_runner.py @@ -1,5 +1,7 @@ from collections.abc import Mapping -from typing import Any, Optional, cast +from typing import Any, cast + +from sqlalchemy.orm import Session from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.entities.queue_entities import ( @@ -65,17 +67,20 @@ from core.workflow.system_variable import SystemVariable from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool from core.workflow.workflow_entry import WorkflowEntry from extensions.ext_database import db -from models.model import App from models.workflow import Workflow class WorkflowBasedAppRunner: - def __init__(self, queue_manager: AppQueueManager, variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER) -> None: - self.queue_manager = queue_manager + def __init__( + self, + *, + queue_manager: AppQueueManager, + variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER, + app_id: str, + ) -> None: + self._queue_manager = queue_manager self._variable_loader = variable_loader - - def _get_app_id(self) -> str: - raise NotImplementedError("not implemented") + self._app_id = app_id def _init_graph(self, graph_config: Mapping[str, Any]) -> Graph: """ @@ -692,21 +697,24 @@ class WorkflowBasedAppRunner: ) ) - def get_workflow(self, app_model: App, workflow_id: str) -> Optional[Workflow]: - """ - Get workflow - """ - # fetch workflow by workflow_id - workflow = ( - db.session.query(Workflow) - .filter( - Workflow.tenant_id == app_model.tenant_id, Workflow.app_id == app_model.id, Workflow.id == workflow_id - ) - .first() - ) - - # return workflow - return workflow - def _publish_event(self, event: AppQueueEvent) -> None: - self.queue_manager.publish(event, PublishFrom.APPLICATION_MANAGER) + self._queue_manager.publish(event, PublishFrom.APPLICATION_MANAGER) + + def _save_draft_var_for_event(self, event: BaseNodeEvent): + run_result = event.route_node_state.node_run_result + if run_result is None: + return + process_data = run_result.process_data + outputs = run_result.outputs + with Session(bind=db.engine) as session, session.begin(): + draft_var_saver = DraftVariableSaver( + session=session, + app_id=self._app_id, + node_id=event.node_id, + node_type=event.node_type, + # FIXME(QuantumGhost): rely on private state of queue_manager is not ideal. + invoke_from=self._queue_manager._invoke_from, + node_execution_id=event.id, + enclosing_node_id=event.in_loop_id or event.in_iteration_id or None, + ) + draft_var_saver.save(process_data=process_data, outputs=outputs)