diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py index c70bf84d2a..c4d1ef70d8 100644 --- a/api/controllers/console/workspace/tool_providers.py +++ b/api/controllers/console/workspace/tool_providers.py @@ -739,7 +739,7 @@ class ToolOAuthCallback(Resource): raise Forbidden("no oauth available client config found for this tool provider") redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider}/tool/callback" - credentials = oauth_handler.get_credentials( + credentials_response = oauth_handler.get_credentials( tenant_id=tenant_id, user_id=user_id, plugin_id=plugin_id, @@ -747,7 +747,10 @@ class ToolOAuthCallback(Resource): redirect_uri=redirect_uri, system_credentials=oauth_client_params, request=request, - ).credentials + ) + + credentials = credentials_response.credentials + expires_at = credentials_response.expires_at if not credentials: raise Exception("the plugin credentials failed") @@ -758,6 +761,7 @@ class ToolOAuthCallback(Resource): tenant_id=tenant_id, provider=provider, credentials=dict(credentials), + expires_at=expires_at, api_type=CredentialType.OAUTH2, ) return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback") diff --git a/api/core/app/apps/README.md b/api/core/app/apps/README.md deleted file mode 100644 index 7a57bb3658..0000000000 --- a/api/core/app/apps/README.md +++ /dev/null @@ -1,48 +0,0 @@ -## Guidelines for Database Connection Management in App Runner and Task Pipeline - -Due to the presence of tasks in App Runner that require long execution times, such as LLM generation and external requests, Flask-Sqlalchemy's strategy for database connection pooling is to allocate one connection (transaction) per request. This approach keeps a connection occupied even during non-DB tasks, leading to the inability to acquire new connections during high concurrency requests due to multiple long-running tasks. 
- -Therefore, the database operations in App Runner and Task Pipeline must ensure connections are closed immediately after use, and it's better to pass IDs rather than Model objects to avoid detach errors. - -Examples: - -1. Creating a new record: - - ```python - app = App(id=1) - db.session.add(app) - db.session.commit() - db.session.refresh(app) # Retrieve table default values, like created_at, cached in the app object, won't affect after close - - # Handle non-long-running tasks or store the content of the App instance in memory (via variable assignment). - - db.session.close() - - return app.id - ``` - -2. Fetching a record from the table: - - ```python - app = db.session.query(App).filter(App.id == app_id).first() - - created_at = app.created_at - - db.session.close() - - # Handle tasks (include long-running). - - ``` - -3. Updating a table field: - - ```python - app = db.session.query(App).filter(App.id == app_id).first() - - app.updated_at = time.utcnow() - db.session.commit() - db.session.close() - - return app_id - ``` - diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py index fc6556dfb5..610a5bb278 100644 --- a/api/core/app/apps/advanced_chat/app_generator.py +++ b/api/core/app/apps/advanced_chat/app_generator.py @@ -7,7 +7,8 @@ from typing import Any, Literal, Optional, Union, overload from flask import Flask, current_app from pydantic import ValidationError -from sqlalchemy.orm import sessionmaker +from sqlalchemy import select +from sqlalchemy.orm import Session, sessionmaker import contexts from configs import dify_config @@ -486,21 +487,52 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): """ with preserve_flask_contexts(flask_app, context_vars=context): - try: - # get conversation and message - conversation = self._get_conversation(conversation_id) - message = self._get_message(message_id) - - # chatbot app - runner = AdvancedChatAppRunner( - 
application_generate_entity=application_generate_entity, - queue_manager=queue_manager, - conversation=conversation, - message=message, - dialogue_count=self._dialogue_count, - variable_loader=variable_loader, + # get conversation and message + conversation = self._get_conversation(conversation_id) + message = self._get_message(message_id) + + with Session(db.engine, expire_on_commit=False) as session: + workflow = session.scalar( + select(Workflow).where( + Workflow.tenant_id == application_generate_entity.app_config.tenant_id, + Workflow.app_id == application_generate_entity.app_config.app_id, + Workflow.id == application_generate_entity.app_config.workflow_id, + ) ) + if workflow is None: + raise ValueError("Workflow not found") + + # Determine system_user_id based on invocation source + is_external_api_call = application_generate_entity.invoke_from in { + InvokeFrom.WEB_APP, + InvokeFrom.SERVICE_API, + } + + if is_external_api_call: + # For external API calls, use end user's session ID + end_user = session.scalar(select(EndUser).where(EndUser.id == application_generate_entity.user_id)) + system_user_id = end_user.session_id if end_user else "" + else: + # For internal calls, use the original user ID + system_user_id = application_generate_entity.user_id + + app = session.scalar(select(App).where(App.id == application_generate_entity.app_config.app_id)) + if app is None: + raise ValueError("App not found") + + runner = AdvancedChatAppRunner( + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + conversation=conversation, + message=message, + dialogue_count=self._dialogue_count, + variable_loader=variable_loader, + workflow=workflow, + system_user_id=system_user_id, + app=app, + ) + try: runner.run() except GenerateTaskStoppedError: pass diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py index af15324f46..80af9a3c60 100644 --- 
def __init__(
    self,
    *,
    application_generate_entity: AdvancedChatAppGenerateEntity,
    queue_manager: AppQueueManager,
    conversation: Conversation,
    message: Message,
    dialogue_count: int,
    variable_loader: VariableLoader,
    workflow: Workflow,
    system_user_id: str,
    app: App,
) -> None:
    """
    Build a runner for one advanced-chat invocation.

    All arguments are keyword-only. The workflow, app and system user id are
    supplied by the caller; this constructor performs no lookups of its own.

    :param application_generate_entity: generate entity; its app_config.app_id
        is forwarded to the base runner as ``app_id``
    :param queue_manager: application queue manager, forwarded to the base runner
    :param conversation: conversation the message belongs to
    :param message: message being answered
    :param dialogue_count: dialogue count stored for later use by ``run``
    :param variable_loader: variable loader, forwarded to the base runner
    :param workflow: workflow to execute
    :param system_user_id: value used as ``user_id`` in the system variables
    :param app: app record used for moderation and annotation reply
    """
    owning_app_id = application_generate_entity.app_config.app_id
    super().__init__(
        queue_manager=queue_manager,
        variable_loader=variable_loader,
        app_id=owning_app_id,
    )

    # Records resolved by the caller before the runner was created.
    self._workflow = workflow
    self._app = app
    self.system_user_id = system_user_id

    # Per-invocation state.
    self.application_generate_entity = application_generate_entity
    self.conversation = conversation
    self.message = message
    self._dialogue_count = dialogue_count
self.application_generate_entity.single_loop_run: # if only single loop run is requested graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop( - workflow=workflow, + workflow=self._workflow, node_id=self.application_generate_entity.single_loop_run.node_id, user_inputs=dict(self.application_generate_entity.single_loop_run.inputs), ) @@ -98,7 +101,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): # moderation if self.handle_input_moderation( - app_record=app_record, + app_record=self._app, app_generate_entity=self.application_generate_entity, inputs=inputs, query=query, @@ -108,7 +111,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): # annotation reply if self.handle_annotation_reply( - app_record=app_record, + app_record=self._app, message=self.message, query=query, app_generate_entity=self.application_generate_entity, @@ -128,7 +131,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): ConversationVariable.from_variable( app_id=self.conversation.app_id, conversation_id=self.conversation.id, variable=variable ) - for variable in workflow.conversation_variables + for variable in self._workflow.conversation_variables ] session.add_all(db_conversation_variables) # Convert database entities to variables. @@ -141,7 +144,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): query=query, files=files, conversation_id=self.conversation.id, - user_id=user_id, + user_id=self.system_user_id, dialogue_count=self._dialogue_count, app_id=app_config.app_id, workflow_id=app_config.workflow_id, @@ -152,25 +155,25 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): variable_pool = VariablePool( system_variables=system_inputs, user_inputs=inputs, - environment_variables=workflow.environment_variables, + environment_variables=self._workflow.environment_variables, # Based on the definition of `VariableUnion`, # `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible. 
def query_app_annotations_to_reply(
    self, app_record: App, message: Message, query: str, user_id: str, invoke_from: InvokeFrom
) -> Optional[MessageAnnotation]:
    """
    Look up an annotation reply for the given query.

    Thin wrapper: creates an ``AnnotationReplyFeature`` and returns the result
    of its ``query`` call unchanged.

    :param app_record: app record
    :param message: message
    :param query: query
    :param user_id: user id
    :param invoke_from: invoke from
    :return: whatever ``AnnotationReplyFeature.query`` returns
        (annotated as an optional ``MessageAnnotation``)
    """
    return AnnotationReplyFeature().query(
        app_record=app_record,
        message=message,
        query=query,
        user_id=user_id,
        invoke_from=invoke_from,
    )


def moderation_for_inputs(
    self,
    *,
    app_id: str,
    tenant_id: str,
    app_generate_entity: AppGenerateEntity,
    inputs: Mapping[str, Any],
    query: str | None = None,
    message_id: str,
) -> tuple[bool, Mapping[str, Any], str]:
    """
    Process sensitive_word_avoidance.

    Thin wrapper: creates an ``InputModeration`` and returns the result of its
    ``check`` call unchanged. ``inputs`` is copied into a plain ``dict`` and a
    missing/empty ``query`` is passed on as ``""``.

    :param app_id: app id
    :param tenant_id: tenant id
    :param app_generate_entity: app generate entity; supplies ``app_config``
        and ``trace_manager`` for the check
    :param inputs: inputs
    :param query: query
    :param message_id: message id
    :return: whatever ``InputModeration.check`` returns
    """
    return InputModeration().check(
        app_id=app_id,
        tenant_id=tenant_id,
        app_config=app_generate_entity.app_config,
        inputs=dict(inputs),
        query=query if query else "",
        message_id=message_id,
        trace_manager=app_generate_entity.trace_manager,
    )
def __init__(
    self,
    *,
    application_generate_entity: WorkflowAppGenerateEntity,
    queue_manager: AppQueueManager,
    variable_loader: VariableLoader,
    workflow_thread_pool_id: Optional[str] = None,
    workflow: Workflow,
    system_user_id: str,
) -> None:
    """
    Build a runner for one workflow-app invocation.

    All arguments are keyword-only. The workflow and system user id are
    supplied by the caller; this constructor performs no lookups of its own.

    :param application_generate_entity: generate entity; its app_config.app_id
        is forwarded to the base runner as ``app_id``
    :param queue_manager: application queue manager, forwarded to the base runner
    :param variable_loader: variable loader, forwarded to the base runner
    :param workflow_thread_pool_id: workflow thread pool id
    :param workflow: workflow to execute
    :param system_user_id: value used as ``user_id`` in the system variables
    """
    owning_app_id = application_generate_entity.app_config.app_id
    super().__init__(
        queue_manager=queue_manager,
        variable_loader=variable_loader,
        app_id=owning_app_id,
    )

    # Values resolved by the caller before the runner was created.
    self._workflow = workflow
    self._sys_user_id = system_user_id

    # Per-invocation state.
    self.application_generate_entity = application_generate_entity
    self.workflow_thread_pool_id = workflow_thread_pool_id
self._get_graph_and_variable_pool_of_single_iteration( - workflow=workflow, + workflow=self._workflow, node_id=self.application_generate_entity.single_iteration_run.node_id, user_inputs=self.application_generate_entity.single_iteration_run.inputs, ) elif self.application_generate_entity.single_loop_run: # if only single loop run is requested graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop( - workflow=workflow, + workflow=self._workflow, node_id=self.application_generate_entity.single_loop_run.node_id, user_inputs=self.application_generate_entity.single_loop_run.inputs, ) @@ -98,7 +79,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner): system_inputs = SystemVariable( files=files, - user_id=user_id, + user_id=self._sys_user_id, app_id=app_config.app_id, workflow_id=app_config.workflow_id, workflow_execution_id=self.application_generate_entity.workflow_execution_id, @@ -107,21 +88,21 @@ class WorkflowAppRunner(WorkflowBasedAppRunner): variable_pool = VariablePool( system_variables=system_inputs, user_inputs=inputs, - environment_variables=workflow.environment_variables, + environment_variables=self._workflow.environment_variables, conversation_variables=[], ) # init graph - graph = self._init_graph(graph_config=workflow.graph_dict) + graph = self._init_graph(graph_config=self._workflow.graph_dict) # RUN WORKFLOW workflow_entry = WorkflowEntry( - tenant_id=workflow.tenant_id, - app_id=workflow.app_id, - workflow_id=workflow.id, - workflow_type=WorkflowType.value_of(workflow.type), + tenant_id=self._workflow.tenant_id, + app_id=self._workflow.app_id, + workflow_id=self._workflow.id, + workflow_type=WorkflowType.value_of(self._workflow.type), graph=graph, - graph_config=workflow.graph_dict, + graph_config=self._workflow.graph_dict, user_id=self.application_generate_entity.user_id, user_from=( UserFrom.ACCOUNT diff --git a/api/core/app/apps/workflow_app_runner.py b/api/core/app/apps/workflow_app_runner.py index 2f4d234ecd..948ea95e63 100644 --- 
def __init__(
    self,
    *,
    queue_manager: AppQueueManager,
    variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
    app_id: str,
) -> None:
    """
    Store the collaborators shared by all workflow-based runners.

    All arguments are keyword-only.

    :param queue_manager: queue manager that ``_publish_event`` publishes to
    :param variable_loader: variable loader; defaults to ``DUMMY_VARIABLE_LOADER``
    :param app_id: id of the app this runner works for
    """
    self._app_id = app_id
    self._variable_loader = variable_loader
    self._queue_manager = queue_manager
def get_text_content(self) -> str:
    """
    Get text content from prompt message.

    A string content is returned as-is. For a list content, the ``data`` of
    every ``TextPromptMessageContent`` item is concatenated in order and all
    other items are ignored. Any other content yields an empty string.

    :return: Text content as string, empty string if no text content
    """
    content = self.content

    if isinstance(content, str):
        return content
    if not isinstance(content, list):
        return ""

    return "".join(part.data for part in content if isinstance(part, TextPromptMessageContent))
def refresh_credentials(
    self,
    tenant_id: str,
    user_id: str,
    plugin_id: str,
    provider: str,
    redirect_uri: str,
    system_credentials: Mapping[str, Any],
    credentials: Mapping[str, Any],
) -> PluginOAuthCredentialsResponse:
    """
    Ask the plugin daemon to refresh a set of OAuth credentials.

    Sends a POST to ``plugin/{tenant_id}/dispatch/oauth/refresh_credentials``
    and returns the first item of the response stream.

    :param tenant_id: tenant id, used in the request path
    :param user_id: user id, sent in the request body
    :param plugin_id: plugin id, sent as the ``X-Plugin-ID`` header
    :param provider: provider name
    :param redirect_uri: OAuth redirect URI
    :param system_credentials: OAuth client (system) credentials
    :param credentials: the current credentials to be refreshed
    :return: the first response yielded by the plugin daemon
    :raises ValueError: if the request fails or the stream yields nothing;
        the original exception is attached as ``__cause__``
    """
    try:
        response = self._request_with_plugin_daemon_response_stream(
            "POST",
            f"plugin/{tenant_id}/dispatch/oauth/refresh_credentials",
            PluginOAuthCredentialsResponse,
            data={
                "user_id": user_id,
                "data": {
                    "provider": provider,
                    "redirect_uri": redirect_uri,
                    "system_credentials": system_credentials,
                    "credentials": credentials,
                },
            },
            headers={
                "X-Plugin-ID": plugin_id,
                "Content-Type": "application/json",
            },
        )
        # Only the first streamed item is used.
        for resp in response:
            return resp
        raise ValueError("No response received from plugin daemon for refresh credentials request.")
    except Exception as e:
        # Keep the same exception type and message callers already see, but
        # chain the cause explicitly so the underlying failure is not lost.
        raise ValueError(f"Error refreshing credentials: {e}") from e
diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index 7822bc389c..abbdf8de3f 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -1,16 +1,19 @@ import json import logging import mimetypes -from collections.abc import Generator +import time +from collections.abc import Generator, Mapping from os import listdir, path from threading import Lock from typing import TYPE_CHECKING, Any, Literal, Optional, Union, cast +from pydantic import TypeAdapter from yarl import URL import contexts from core.helper.provider_cache import ToolProviderCredentialsCache from core.plugin.entities.plugin import ToolProviderID +from core.plugin.impl.oauth import OAuthHandler from core.plugin.impl.tool import PluginToolManager from core.tools.__base.tool_provider import ToolProviderController from core.tools.__base.tool_runtime import ToolRuntime @@ -244,12 +247,47 @@ class ToolManager: tenant_id=tenant_id, provider=provider_id, credential_id=builtin_provider.id ), ) + + # decrypt the credentials + decrypted_credentials: Mapping[str, Any] = encrypter.decrypt(builtin_provider.credentials) + + # check if the credentials is expired + if builtin_provider.expires_at != -1 and (builtin_provider.expires_at - 60) < int(time.time()): + # TODO: circular import + from services.tools.builtin_tools_manage_service import BuiltinToolManageService + + # refresh the credentials + tool_provider = ToolProviderID(provider_id) + provider_name = tool_provider.provider_name + redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider_id}/tool/callback" + system_credentials = BuiltinToolManageService.get_oauth_client(tenant_id, provider_id) + oauth_handler = OAuthHandler() + # refresh the credentials + refreshed_credentials = oauth_handler.refresh_credentials( + tenant_id=tenant_id, + user_id=builtin_provider.user_id, + plugin_id=tool_provider.plugin_id, + provider=provider_name, + redirect_uri=redirect_uri, + 
system_credentials=system_credentials or {}, + credentials=decrypted_credentials, + ) + # update the credentials + builtin_provider.encrypted_credentials = ( + TypeAdapter(dict[str, Any]) + .dump_json(encrypter.encrypt(dict(refreshed_credentials.credentials))) + .decode("utf-8") + ) + builtin_provider.expires_at = refreshed_credentials.expires_at + db.session.commit() + decrypted_credentials = refreshed_credentials.credentials + return cast( BuiltinTool, builtin_tool.fork_tool_runtime( runtime=ToolRuntime( tenant_id=tenant_id, - credentials=encrypter.decrypt(builtin_provider.credentials), + credentials=dict(decrypted_credentials), credential_type=CredentialType.of(builtin_provider.credential_type), runtime_parameters={}, invoke_from=invoke_from, diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index 86d36f474d..f437ac841d 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -317,7 +317,13 @@ class ToolNode(BaseNode): elif message.type == ToolInvokeMessage.MessageType.FILE: assert message.meta is not None assert isinstance(message.meta, dict) - assert "file" in message.meta and isinstance(message.meta["file"], File) + # Validate that meta contains a 'file' key + if "file" not in message.meta: + raise ToolNodeError("File message is missing 'file' key in meta") + + # Validate that the file is an instance of File + if not isinstance(message.meta["file"], File): + raise ToolNodeError(f"Expected File object but got {type(message.meta['file']).__name__}") files.append(message.meta["file"]) elif message.type == ToolInvokeMessage.MessageType.LOG: assert isinstance(message.message, ToolInvokeMessage.LogMessage) diff --git a/api/migrations/versions/2025_07_22_0019-375fe79ead14_oauth_refresh_token.py b/api/migrations/versions/2025_07_22_0019-375fe79ead14_oauth_refresh_token.py new file mode 100644 index 0000000000..76d0cb2940 --- /dev/null +++ 
def upgrade():
    """Add the non-null ``expires_at`` BigInteger column (server default -1) to ``tool_builtin_providers``."""
    expires_at_column = sa.Column(
        'expires_at',
        sa.BigInteger(),
        server_default=sa.text('-1'),
        nullable=False,
    )
    with op.batch_alter_table('tool_builtin_providers', schema=None) as batch_op:
        batch_op.add_column(expires_at_column)


def downgrade():
    """Drop the ``expires_at`` column from ``tool_builtin_providers``."""
    with op.batch_alter_table('tool_builtin_providers', schema=None) as batch_op:
        batch_op.drop_column('expires_at')
str, provider: str): @@ -212,6 +213,7 @@ class BuiltinToolManageService: tenant_id: str, provider: str, credentials: dict, + expires_at: int = -1, name: str | None = None, ): """ @@ -269,6 +271,9 @@ class BuiltinToolManageService: encrypted_credentials=json.dumps(encrypter.encrypt(credentials)), credential_type=api_type.value, name=name, + expires_at=expires_at + if expires_at is not None + else BuiltinToolManageService.__DEFAULT_EXPIRES_AT__, ) session.add(db_provider) diff --git a/api/services/tools/mcp_tools_manage_service.py b/api/services/tools/mcp_tools_manage_service.py index e0e256912e..c0126a0f4f 100644 --- a/api/services/tools/mcp_tools_manage_service.py +++ b/api/services/tools/mcp_tools_manage_service.py @@ -112,19 +112,27 @@ class MCPToolManageService: @classmethod def list_mcp_tool_from_remote_server(cls, tenant_id: str, provider_id: str) -> ToolProviderApiEntity: mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id) + server_url = mcp_provider.decrypted_server_url + authed = mcp_provider.authed + try: - with MCPClient( - mcp_provider.decrypted_server_url, provider_id, tenant_id, authed=mcp_provider.authed, for_list=True - ) as mcp_client: + with MCPClient(server_url, provider_id, tenant_id, authed=authed, for_list=True) as mcp_client: tools = mcp_client.list_tools() except MCPAuthError: raise ValueError("Please auth the tool first") except MCPError as e: raise ValueError(f"Failed to connect to MCP server: {e}") - mcp_provider.tools = json.dumps([tool.model_dump() for tool in tools]) - mcp_provider.authed = True - mcp_provider.updated_at = datetime.now() - db.session.commit() + + try: + mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id) + mcp_provider.tools = json.dumps([tool.model_dump() for tool in tools]) + mcp_provider.authed = True + mcp_provider.updated_at = datetime.now() + db.session.commit() + except Exception: + db.session.rollback() + raise + user = mcp_provider.load_user() return 
ToolProviderApiEntity( id=mcp_provider.id, @@ -160,22 +168,35 @@ class MCPToolManageService: server_identifier: str, ): mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id) - mcp_provider.updated_at = datetime.now() - mcp_provider.name = name - mcp_provider.icon = ( - json.dumps({"content": icon, "background": icon_background}) if icon_type == "emoji" else icon - ) - mcp_provider.server_identifier = server_identifier + + reconnect_result = None + encrypted_server_url = None + server_url_hash = None if UNCHANGED_SERVER_URL_PLACEHOLDER not in server_url: encrypted_server_url = encrypter.encrypt_token(tenant_id, server_url) - mcp_provider.server_url = encrypted_server_url server_url_hash = hashlib.sha256(server_url.encode()).hexdigest() if server_url_hash != mcp_provider.server_url_hash: - cls._re_connect_mcp_provider(mcp_provider, provider_id, tenant_id) - mcp_provider.server_url_hash = server_url_hash + reconnect_result = cls._re_connect_mcp_provider(server_url, provider_id, tenant_id) + try: + mcp_provider.updated_at = datetime.now() + mcp_provider.name = name + mcp_provider.icon = ( + json.dumps({"content": icon, "background": icon_background}) if icon_type == "emoji" else icon + ) + mcp_provider.server_identifier = server_identifier + + if encrypted_server_url is not None and server_url_hash is not None: + mcp_provider.server_url = encrypted_server_url + mcp_provider.server_url_hash = server_url_hash + + if reconnect_result: + mcp_provider.authed = reconnect_result["authed"] + mcp_provider.tools = reconnect_result["tools"] + mcp_provider.encrypted_credentials = reconnect_result["encrypted_credentials"] + db.session.commit() except IntegrityError as e: db.session.rollback() @@ -187,6 +208,9 @@ class MCPToolManageService: if "unique_mcp_provider_server_identifier" in error_msg: raise ValueError(f"MCP tool {server_identifier} already exists") raise + except Exception: + db.session.rollback() + raise @classmethod def 
update_mcp_provider_credentials( @@ -207,23 +231,22 @@ class MCPToolManageService: db.session.commit() @classmethod - def _re_connect_mcp_provider(cls, mcp_provider: MCPToolProvider, provider_id: str, tenant_id: str): - """re-connect mcp provider""" + def _re_connect_mcp_provider(cls, server_url: str, provider_id: str, tenant_id: str): try: with MCPClient( - mcp_provider.decrypted_server_url, + server_url, provider_id, tenant_id, authed=False, for_list=True, ) as mcp_client: tools = mcp_client.list_tools() - mcp_provider.authed = True - mcp_provider.tools = json.dumps([tool.model_dump() for tool in tools]) + return { + "authed": True, + "tools": json.dumps([tool.model_dump() for tool in tools]), + "encrypted_credentials": "{}", + } except MCPAuthError: - mcp_provider.authed = False - mcp_provider.tools = "[]" + return {"authed": False, "tools": "[]", "encrypted_credentials": "{}"} except MCPError as e: raise ValueError(f"Failed to re-connect MCP server: {e}") from e - # reset credentials - mcp_provider.encrypted_credentials = "{}" diff --git a/tests/unit_tests/events/test_provider_update_deadlock_prevention.py b/tests/unit_tests/events/test_provider_update_deadlock_prevention.py deleted file mode 100644 index 47c175acd7..0000000000 --- a/tests/unit_tests/events/test_provider_update_deadlock_prevention.py +++ /dev/null @@ -1,248 +0,0 @@ -import threading -from unittest.mock import Mock, patch - -from core.app.entities.app_invoke_entities import ChatAppGenerateEntity -from core.entities.provider_entities import QuotaUnit -from events.event_handlers.update_provider_when_message_created import ( - handle, - get_update_stats, -) -from models.provider import ProviderType -from sqlalchemy.exc import OperationalError - - -class TestProviderUpdateDeadlockPrevention: - """Test suite for deadlock prevention in Provider updates.""" - - def setup_method(self): - """Setup test fixtures.""" - self.mock_message = Mock() - self.mock_message.answer_tokens = 100 - - 
self.mock_app_config = Mock() - self.mock_app_config.tenant_id = "test-tenant-123" - - self.mock_model_conf = Mock() - self.mock_model_conf.provider = "openai" - - self.mock_system_config = Mock() - self.mock_system_config.current_quota_type = QuotaUnit.TOKENS - - self.mock_provider_config = Mock() - self.mock_provider_config.using_provider_type = ProviderType.SYSTEM - self.mock_provider_config.system_configuration = self.mock_system_config - - self.mock_provider_bundle = Mock() - self.mock_provider_bundle.configuration = self.mock_provider_config - - self.mock_model_conf.provider_model_bundle = self.mock_provider_bundle - - self.mock_generate_entity = Mock(spec=ChatAppGenerateEntity) - self.mock_generate_entity.app_config = self.mock_app_config - self.mock_generate_entity.model_conf = self.mock_model_conf - - @patch("events.event_handlers.update_provider_when_message_created.db") - def test_consolidated_handler_basic_functionality(self, mock_db): - """Test that the consolidated handler performs both updates correctly.""" - # Setup mock query chain - mock_query = Mock() - mock_db.session.query.return_value = mock_query - mock_query.filter.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.update.return_value = 1 # 1 row affected - - # Call the handler - handle(self.mock_message, application_generate_entity=self.mock_generate_entity) - - # Verify db.session.query was called - assert mock_db.session.query.called - - # Verify commit was called - mock_db.session.commit.assert_called_once() - - # Verify no rollback was called - assert not mock_db.session.rollback.called - - @patch("events.event_handlers.update_provider_when_message_created.db") - def test_deadlock_retry_mechanism(self, mock_db): - """Test that deadlock errors trigger retry logic.""" - # Setup mock to raise deadlock error on first attempt, succeed on second - mock_query = Mock() - mock_db.session.query.return_value = mock_query - mock_query.filter.return_value = 
mock_query - mock_query.order_by.return_value = mock_query - mock_query.update.return_value = 1 - - # First call raises deadlock, second succeeds - mock_db.session.commit.side_effect = [ - OperationalError("deadlock detected", None, None), - None, # Success on retry - ] - - # Call the handler - handle(self.mock_message, application_generate_entity=self.mock_generate_entity) - - # Verify commit was called twice (original + retry) - assert mock_db.session.commit.call_count == 2 - - # Verify rollback was called once (after first failure) - mock_db.session.rollback.assert_called_once() - - @patch("events.event_handlers.update_provider_when_message_created.db") - @patch("events.event_handlers.update_provider_when_message_created.time.sleep") - def test_exponential_backoff_timing(self, mock_sleep, mock_db): - """Test that retry delays follow exponential backoff pattern.""" - # Setup mock to fail twice, succeed on third attempt - mock_query = Mock() - mock_db.session.query.return_value = mock_query - mock_query.filter.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.update.return_value = 1 - - mock_db.session.commit.side_effect = [ - OperationalError("deadlock detected", None, None), - OperationalError("deadlock detected", None, None), - None, # Success on third attempt - ] - - # Call the handler - handle(self.mock_message, application_generate_entity=self.mock_generate_entity) - - # Verify sleep was called twice with increasing delays - assert mock_sleep.call_count == 2 - - # First delay should be around 0.1s + jitter - first_delay = mock_sleep.call_args_list[0][0][0] - assert 0.1 <= first_delay <= 0.3 - - # Second delay should be around 0.2s + jitter - second_delay = mock_sleep.call_args_list[1][0][0] - assert 0.2 <= second_delay <= 0.4 - - def test_concurrent_handler_execution(self): - """Test that multiple handlers can run concurrently without deadlock.""" - results = [] - errors = [] - - def run_handler(): - try: - with patch( - 
"events.event_handlers.update_provider_when_message_created.db" - ) as mock_db: - mock_query = Mock() - mock_db.session.query.return_value = mock_query - mock_query.filter.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.update.return_value = 1 - - handle( - self.mock_message, - application_generate_entity=self.mock_generate_entity, - ) - results.append("success") - except Exception as e: - errors.append(str(e)) - - # Run multiple handlers concurrently - threads = [] - for _ in range(5): - thread = threading.Thread(target=run_handler) - threads.append(thread) - thread.start() - - # Wait for all threads to complete - for thread in threads: - thread.join(timeout=5) - - # Verify all handlers completed successfully - assert len(results) == 5 - assert len(errors) == 0 - - def test_performance_stats_tracking(self): - """Test that performance statistics are tracked correctly.""" - # Reset stats - stats = get_update_stats() - initial_total = stats["total_updates"] - - with patch( - "events.event_handlers.update_provider_when_message_created.db" - ) as mock_db: - mock_query = Mock() - mock_db.session.query.return_value = mock_query - mock_query.filter.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.update.return_value = 1 - - # Call handler - handle( - self.mock_message, application_generate_entity=self.mock_generate_entity - ) - - # Check that stats were updated - updated_stats = get_update_stats() - assert updated_stats["total_updates"] == initial_total + 1 - assert updated_stats["successful_updates"] >= initial_total + 1 - - def test_non_chat_entity_ignored(self): - """Test that non-chat entities are ignored by the handler.""" - # Create a non-chat entity - mock_non_chat_entity = Mock() - mock_non_chat_entity.__class__.__name__ = "NonChatEntity" - - with patch( - "events.event_handlers.update_provider_when_message_created.db" - ) as mock_db: - # Call handler with non-chat entity - 
handle(self.mock_message, application_generate_entity=mock_non_chat_entity) - - # Verify no database operations were performed - assert not mock_db.session.query.called - assert not mock_db.session.commit.called - - @patch("events.event_handlers.update_provider_when_message_created.db") - def test_quota_calculation_tokens(self, mock_db): - """Test quota calculation for token-based quotas.""" - # Setup token-based quota - self.mock_system_config.current_quota_type = QuotaUnit.TOKENS - self.mock_message.answer_tokens = 150 - - mock_query = Mock() - mock_db.session.query.return_value = mock_query - mock_query.filter.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.update.return_value = 1 - - # Call handler - handle(self.mock_message, application_generate_entity=self.mock_generate_entity) - - # Verify update was called with token count - update_calls = mock_query.update.call_args_list - - # Should have at least one call with quota_used update - quota_update_found = False - for call in update_calls: - values = call[0][0] # First argument to update() - if "quota_used" in values: - quota_update_found = True - break - - assert quota_update_found - - @patch("events.event_handlers.update_provider_when_message_created.db") - def test_quota_calculation_times(self, mock_db): - """Test quota calculation for times-based quotas.""" - # Setup times-based quota - self.mock_system_config.current_quota_type = QuotaUnit.TIMES - - mock_query = Mock() - mock_db.session.query.return_value = mock_query - mock_query.filter.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.update.return_value = 1 - - # Call handler - handle(self.mock_message, application_generate_entity=self.mock_generate_entity) - - # Verify update was called - assert mock_query.update.called - assert mock_db.session.commit.called