From 12c20ec7f6bc8882785f272804eefab48332ef36 Mon Sep 17 00:00:00 2001 From: Harry Date: Fri, 20 Jun 2025 10:34:57 +0800 Subject: [PATCH 01/15] feat: plugin OAuth with stateful --- api/app.py | 2 +- .../console/workspace/tool_providers.py | 156 +++++++++++-- api/core/plugin/impl/oauth.py | 18 +- .../python/examples/github/provider/github.py | 67 ++++++ api/extensions/ext_celery.py | 6 + ...99310d2c25a6_add_tool_oauth_credentials.py | 66 ++++++ ...9_1133-222376193a49_multiple_credential.py | 39 ++++ ...9_1353-a9306e69af07_multiple_credential.py | 33 +++ ...9_1359-6835b906335f_multiple_credential.py | 33 +++ ...9_1359-e315d2a83984_multiple_credential.py | 33 +++ ...9_1511-110e30078dd3_multiple_credential.py | 53 +++++ api/models/tools.py | 69 +++++- api/services/plugin/oauth_service.py | 63 ++++- .../tools/builtin_tools_manage_service.py | 216 ++++++++++++++---- api/tool_oauth.http | 27 +++ 15 files changed, 809 insertions(+), 72 deletions(-) create mode 100644 api/dify-plugin-sdks/python/examples/github/provider/github.py create mode 100644 api/migrations/versions/2025_06_18_1506-99310d2c25a6_add_tool_oauth_credentials.py create mode 100644 api/migrations/versions/2025_06_19_1133-222376193a49_multiple_credential.py create mode 100644 api/migrations/versions/2025_06_19_1353-a9306e69af07_multiple_credential.py create mode 100644 api/migrations/versions/2025_06_19_1359-6835b906335f_multiple_credential.py create mode 100644 api/migrations/versions/2025_06_19_1359-e315d2a83984_multiple_credential.py create mode 100644 api/migrations/versions/2025_06_19_1511-110e30078dd3_multiple_credential.py create mode 100644 api/tool_oauth.http diff --git a/api/app.py b/api/app.py index 4f393f6c20..11decffe96 100644 --- a/api/app.py +++ b/api/app.py @@ -38,4 +38,4 @@ else: celery = app.extensions["celery"] if __name__ == "__main__": - app.run(host="0.0.0.0", port=5001) + app.run(host="0.0.0.0", port=5001,debug=True) diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py index 2b1379bfb2..e3285a16c7 100644 --- a/api/controllers/console/workspace/tool_providers.py +++ b/api/controllers/console/workspace/tool_providers.py @@ -1,18 +1,27 @@ import io -from flask import send_file +from flask import redirect, request, send_file from flask_login import current_user -from flask_restful import Resource, reqparse +from flask_restful import ( + Resource, + reqparse, +) from sqlalchemy.orm import Session from werkzeug.exceptions import Forbidden from configs import dify_config from controllers.console import api -from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required +from controllers.console.wraps import ( + account_initialization_required, + enterprise_license_required, + setup_required, +) from core.model_runtime.utils.encoders import jsonable_encoder +from core.plugin.impl.oauth import OAuthHandler from extensions.ext_database import db from libs.helper import alphanumeric, uuid_value from libs.login import login_required +from services.plugin.oauth_service import OAuthProxyService from services.tools.api_tools_manage_service import ApiToolManageService from services.tools.builtin_tools_manage_service import BuiltinToolManageService from services.tools.tool_labels_service import ToolLabelsService @@ -108,17 +117,19 @@ class ToolBuiltinProviderUpdateApi(Resource): tenant_id = user.current_tenant_id parser = reqparse.RequestParser() + parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json") parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") + parser.add_argument("name", type=str, required=True, nullable=False, location="json") args = parser.parse_args() with Session(db.engine) as session: result = BuiltinToolManageService.update_builtin_tool_provider( - session=session, user_id=user_id, tenant_id=tenant_id, - provider_name=provider, credentials=args["credentials"], + credential_id=args["credential_id"], + name=args["name"] ) session.commit() return result @@ -555,9 +566,9 @@ class ToolBuiltinListApi(Resource): [ provider.to_dict() for provider in BuiltinToolManageService.list_builtin_tools( - user_id, - tenant_id, - ) + user_id, + tenant_id, + ) ] ) @@ -576,9 +587,9 @@ class ToolApiListApi(Resource): [ provider.to_dict() for provider in ApiToolManageService.list_api_tools( - user_id, - tenant_id, - ) + user_id, + tenant_id, + ) ] ) @@ -597,9 +608,9 @@ class ToolWorkflowListApi(Resource): [ provider.to_dict() for provider in WorkflowToolManageService.list_tenant_workflow_tools( - user_id, - tenant_id, - ) + user_id, + tenant_id, + ) ] ) @@ -613,6 +624,121 @@ class ToolLabelsApi(Resource): return jsonable_encoder(ToolLabelsService.list_tool_labels()) +class ToolPluginOAuthApi(Resource): + @setup_required + @login_required + @account_initialization_required + def get(self): + parser = reqparse.RequestParser() + parser.add_argument("provider", type=str, required=True, nullable=False, location="args") + parser.add_argument("plugin_id", type=str, required=True, nullable=False, location="args") + args = parser.parse_args() + provider = args["provider"] + plugin_id = args["plugin_id"] + + # todo check permission + user = current_user + + if not user.is_admin_or_owner: + raise Forbidden() + + # check if user client is configured and enabled then using user client + # if user client is not configured then using system client + tenant_id = user.current_tenant_id + user_id = user.id + + plugin_oauth_config = BuiltinToolManageService.get_builtin_tool_provider( + tenant_id=tenant_id, + user_id=user_id, + provider=provider, + plugin_id=plugin_id, + ) + + oauth_handler = OAuthHandler() + context_id = OAuthProxyService.create_proxy_context(user_id=current_user.id, + tenant_id=tenant_id, + plugin_id=plugin_id, + provider=provider) + # todo decrypt oauth params + oauth_params = plugin_oauth_config.oauth_params + oauth_params[ + 'redirect_uri'] = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/tool/callback?context_id={context_id}" + + response = oauth_handler.get_authorization_url( + tenant_id, + user.id, + plugin_id, + provider, + system_credentials=oauth_params, + ) + return response.model_dump() + + +class ToolOAuthCallback(Resource): + + @setup_required + def get(self): + args = (reqparse + .RequestParser() + .add_argument("context_id", type=str, required=True, nullable=False, location="args") + .parse_args() + ) + context_id = args["context_id"] + context = OAuthProxyService.use_proxy_context(context_id) + if context is None: + raise Forbidden("Invalid context_id") + + user_id, tenant_id, plugin_id, provider = ( + context.get("user_id"), + context.get("tenant_id"), + context.get("plugin_id"), + context.get("provider"), + ) + oauth_handler = OAuthHandler() + plugin_oauth_config = BuiltinToolManageService.get_builtin_tool_provider( + tenant_id=tenant_id, + user_id=user_id, + provider=provider, + plugin_id=plugin_id, + ) + oauth_params = plugin_oauth_config.oauth_params + oauth_params['redirect_uri'] = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/tool/callback?context_id={context_id}" + + credentials = oauth_handler.get_credentials( + tenant_id, + user_id, + plugin_id, + provider, + system_credentials=oauth_params, + request=request, + ) + + if not credentials: + raise Exception("no credentials found for this plugin") + + #TODO add credentials to database + return redirect(f"{dify_config.CONSOLE_WEB_URL}") + + +class ToolBuiltinProviderSetDefaultApi(Resource): + @setup_required + @login_required + @account_initialization_required + def post(self, provider): + parser = reqparse.RequestParser() + parser.add_argument("id", type=str, required=True, nullable=False, location="json") + args = parser.parse_args() + return BuiltinToolManageService.set_default_provider( + tenant_id=current_user.current_tenant_id, + user_id=current_user.id, + provider=provider, + id=args["id"]) + + +# tool oauth +api.add_resource(ToolPluginOAuthApi, "/oauth/plugin/tool") +api.add_resource(ToolOAuthCallback, "/oauth/plugin/tool/callback") + # tool provider api.add_resource(ToolProviderListApi, "/workspaces/current/tool-providers") @@ -621,6 +747,8 @@ api.add_resource(ToolBuiltinProviderListToolsApi, "/workspaces/current/tool-prov api.add_resource(ToolBuiltinProviderInfoApi, "/workspaces/current/tool-provider/builtin//info") api.add_resource(ToolBuiltinProviderDeleteApi, "/workspaces/current/tool-provider/builtin//delete") api.add_resource(ToolBuiltinProviderUpdateApi, "/workspaces/current/tool-provider/builtin//update") +api.add_resource(ToolBuiltinProviderSetDefaultApi, + "/workspaces/current/tool-provider/builtin//set-default") api.add_resource( ToolBuiltinProviderGetCredentialsApi, "/workspaces/current/tool-provider/builtin//credentials" ) diff --git a/api/core/plugin/impl/oauth.py b/api/core/plugin/impl/oauth.py index 91774984c8..13873b6ba8 100644 --- a/api/core/plugin/impl/oauth.py +++ b/api/core/plugin/impl/oauth.py @@ -1,3 +1,4 @@ +import binascii from collections.abc import Mapping from typing import Any @@ -16,7 +17,7 @@ class OAuthHandler(BasePluginClient): provider: str, system_credentials: Mapping[str, Any], ) -> PluginOAuthAuthorizationUrlResponse: - return self._request_with_plugin_daemon_response( + response = self._request_with_plugin_daemon_response_stream( "POST", f"plugin/{tenant_id}/dispatch/oauth/get_authorization_url", PluginOAuthAuthorizationUrlResponse, @@ -32,6 +33,10 @@ class OAuthHandler(BasePluginClient): "Content-Type": "application/json", }, ) + for resp in response: + return resp + raise ValueError("No response received from plugin daemon for authorization URL request.") + def get_credentials( self, @@ -49,7 +54,7 @@ class OAuthHandler(BasePluginClient): # encode request to raw http request raw_request_bytes = self._convert_request_to_raw_data(request) - return self._request_with_plugin_daemon_response( + response = self._request_with_plugin_daemon_response_stream( "POST", f"plugin/{tenant_id}/dispatch/oauth/get_credentials", PluginOAuthCredentialsResponse, @@ -58,7 +63,8 @@ class OAuthHandler(BasePluginClient): "data": { "provider": provider, "system_credentials": system_credentials, - "raw_request_bytes": raw_request_bytes, + # for json serialization + "raw_http_request": binascii.hexlify(raw_request_bytes).decode(), }, }, headers={ @@ -66,6 +72,10 @@ class OAuthHandler(BasePluginClient): "Content-Type": "application/json", }, ) + for resp in response: + return resp + raise ValueError("No response received from plugin daemon for authorization URL request.") + def _convert_request_to_raw_data(self, request: Request) -> bytes: """ @@ -79,7 +89,7 @@ class OAuthHandler(BasePluginClient): """ # Start with the request line method = request.method - path = request.path + path = request.full_path protocol = request.headers.get("HTTP_VERSION", "HTTP/1.1") raw_data = f"{method} {path} {protocol}\r\n".encode() diff --git a/api/dify-plugin-sdks/python/examples/github/provider/github.py b/api/dify-plugin-sdks/python/examples/github/provider/github.py new file mode 100644 index 0000000000..36f2f85910 --- /dev/null +++ b/api/dify-plugin-sdks/python/examples/github/provider/github.py @@ -0,0 +1,67 @@ +import secrets +import urllib.parse +from collections.abc import Mapping +from typing import Any + +import requests +from dify_plugin import ToolProvider +from dify_plugin.errors.tool import ToolProviderCredentialValidationError +from werkzeug import Request + + +class GithubProvider(ToolProvider): + _AUTH_URL = "https://github.com/login/oauth/authorize" + _TOKEN_URL = "https://github.com/login/oauth/access_token" + _API_USER_URL = "https://api.github.com/user" + + def _oauth_get_authorization_url(self, system_credentials: Mapping[str, Any]) -> str: + """ + Generate the authorization URL for the Github OAuth. + """ + state = secrets.token_urlsafe(16) + params = { + "client_id": system_credentials["client_id"], + "redirect_uri": system_credentials["redirect_uri"], + "scope": system_credentials.get("scope", "read:user"), + "state": state, + # Optionally: allow_signup, login, etc. + } + return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}" + + def _oauth_get_credentials(self, system_credentials: Mapping[str, Any], request: Request) -> Mapping[str, Any]: + """ + Exchange code for access_token. + """ + code = request.args.get("code") + state = request.args.get("state") + if not code: + raise ValueError("No code provided") + # Optionally: validate state here + + data = { + "client_id": system_credentials["client_id"], + "client_secret": system_credentials["client_secret"], + "code": code, + "redirect_uri": system_credentials["redirect_uri"], + } + headers = {"Accept": "application/json"} + response = requests.post(self._TOKEN_URL, data=data, headers=headers, timeout=10) + response_json = response.json() + access_token = response_json.get("access_token") + if not access_token: + raise ValueError(f"Error in GitHub OAuth: {response_json}") + return {"access_token": access_token} + + def _validate_credentials(self, credentials: dict) -> None: + try: + if "access_token" not in credentials or not credentials.get("access_token"): + raise ToolProviderCredentialValidationError("GitHub API Access Token is required.") + headers = { + "Authorization": f"Bearer {credentials['access_token']}", + "Accept": "application/vnd.github+json", + } + response = requests.get(self._API_USER_URL, headers=headers, timeout=10) + if response.status_code != 200: + raise ToolProviderCredentialValidationError(response.json().get("message")) + except Exception as e: + raise ToolProviderCredentialValidationError(str(e)) \ No newline at end of file diff --git a/api/extensions/ext_celery.py b/api/extensions/ext_celery.py index a837552007..14ec4ebae0 100644 --- a/api/extensions/ext_celery.py +++ b/api/extensions/ext_celery.py @@ -1,3 +1,4 @@ +import os from datetime import timedelta import pytz @@ -24,12 +25,17 @@ def init_app(app: DifyApp) -> Celery: }, } + + flask_debugging = os.environ.get("FLASK_DEBUG", "0").lower() in {"true", "1", "yes"} + celery_app = Celery( app.name, task_cls=FlaskTask, broker=dify_config.CELERY_BROKER_URL, backend=dify_config.CELERY_BACKEND, task_ignore_result=True, + task_always_eager=flask_debugging, + task_eager_propagates=flask_debugging, ) # Add SSL options to the Celery configuration diff --git a/api/migrations/versions/2025_06_18_1506-99310d2c25a6_add_tool_oauth_credentials.py b/api/migrations/versions/2025_06_18_1506-99310d2c25a6_add_tool_oauth_credentials.py new file mode 100644 index 0000000000..95e74571d5 --- /dev/null +++ b/api/migrations/versions/2025_06_18_1506-99310d2c25a6_add_tool_oauth_credentials.py @@ -0,0 +1,66 @@ +"""add tool oauth credentials + +Revision ID: 99310d2c25a6 +Revises: 4474872b0ee6 +Create Date: 2025-06-18 15:06:15.261915 + +""" +from alembic import op +import models as models +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '99310d2c25a6' +down_revision = '4474872b0ee6' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('tool_oauth_system_clients', + sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('plugin_id', models.types.StringUUID(), nullable=False), + sa.Column('provider', sa.String(length=255), nullable=False), + sa.Column('encrypted_oauth_params', sa.Text(), nullable=False), + sa.PrimaryKeyConstraint('id', name='tool_oauth_system_client_pkey'), + sa.UniqueConstraint('plugin_id', 'provider', name='tool_oauth_system_client_plugin_id_provider_idx') + ) + op.create_table('tool_oauth_user_clients', + sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('tenant_id', models.types.StringUUID(), nullable=False), + sa.Column('plugin_id', models.types.StringUUID(), nullable=False), + sa.Column('provider', sa.String(length=255), nullable=False), + sa.Column('encrypted_oauth_params', sa.Text(), nullable=False), + sa.PrimaryKeyConstraint('id', name='tool_oauth_user_client_pkey'), + sa.UniqueConstraint('tenant_id', 'plugin_id', 'provider', name='unique_tool_oauth_user_client') + ) + with op.batch_alter_table('tool_builtin_providers', schema=None) as batch_op: + batch_op.add_column(sa.Column('default', sa.Boolean(), server_default=sa.text('false'), nullable=False)) + batch_op.alter_column('credential_type', + existing_type=sa.VARCHAR(length=255), + type_=sa.String(length=32), + existing_nullable=False, + existing_server_default=sa.text("'api_key'::character varying")) + batch_op.drop_constraint(batch_op.f('unique_builtin_tool_provider'), type_='unique') + batch_op.create_unique_constraint('unique_builtin_tool_provider', ['tenant_id', 'provider', 'credential_type']) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('tool_builtin_providers', schema=None) as batch_op: + batch_op.drop_constraint('unique_builtin_tool_provider', type_='unique') + batch_op.create_unique_constraint(batch_op.f('unique_builtin_tool_provider'), ['tenant_id', 'provider']) + batch_op.alter_column('credential_type', + existing_type=sa.String(length=32), + type_=sa.VARCHAR(length=255), + existing_nullable=False, + existing_server_default=sa.text("'api_key'::character varying")) + batch_op.drop_column('default') + + op.drop_table('tool_oauth_user_clients') + op.drop_table('tool_oauth_system_clients') + # ### end Alembic commands ### diff --git a/api/migrations/versions/2025_06_19_1133-222376193a49_multiple_credential.py b/api/migrations/versions/2025_06_19_1133-222376193a49_multiple_credential.py new file mode 100644 index 0000000000..82e812cb3d --- /dev/null +++ b/api/migrations/versions/2025_06_19_1133-222376193a49_multiple_credential.py @@ -0,0 +1,39 @@ +"""multiple credential + +Revision ID: 222376193a49 +Revises: 99310d2c25a6 +Create Date: 2025-06-19 11:33:46.400455 + +""" +from alembic import op +import models as models +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '222376193a49' +down_revision = '99310d2c25a6' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('tool_builtin_providers', schema=None) as batch_op: + batch_op.drop_constraint(batch_op.f('unique_builtin_tool_provider'), type_='unique') + + with op.batch_alter_table('tool_oauth_user_clients', schema=None) as batch_op: + batch_op.add_column(sa.Column('owner_type', sa.Text(), nullable=False)) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('tool_oauth_user_clients', schema=None) as batch_op: + batch_op.drop_column('owner_type') + + with op.batch_alter_table('tool_builtin_providers', schema=None) as batch_op: + batch_op.create_unique_constraint(batch_op.f('unique_builtin_tool_provider'), ['tenant_id', 'provider', 'credential_type']) + + # ### end Alembic commands ### diff --git a/api/migrations/versions/2025_06_19_1353-a9306e69af07_multiple_credential.py b/api/migrations/versions/2025_06_19_1353-a9306e69af07_multiple_credential.py new file mode 100644 index 0000000000..216661550a --- /dev/null +++ b/api/migrations/versions/2025_06_19_1353-a9306e69af07_multiple_credential.py @@ -0,0 +1,33 @@ +"""multiple credential + +Revision ID: a9306e69af07 +Revises: 222376193a49 +Create Date: 2025-06-19 13:53:41.554159 + +""" +from alembic import op +import models as models +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = 'a9306e69af07' +down_revision = '222376193a49' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('tool_builtin_providers', schema=None) as batch_op: + batch_op.create_unique_constraint('unique_builtin_tool_provider', ['provider', 'tenant_id', 'default']) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('tool_builtin_providers', schema=None) as batch_op: + batch_op.drop_constraint('unique_builtin_tool_provider', type_='unique') + + # ### end Alembic commands ### diff --git a/api/migrations/versions/2025_06_19_1359-6835b906335f_multiple_credential.py b/api/migrations/versions/2025_06_19_1359-6835b906335f_multiple_credential.py new file mode 100644 index 0000000000..d90e0d178e --- /dev/null +++ b/api/migrations/versions/2025_06_19_1359-6835b906335f_multiple_credential.py @@ -0,0 +1,33 @@ +"""multiple credential + +Revision ID: 6835b906335f +Revises: e315d2a83984 +Create Date: 2025-06-19 13:59:58.107955 + +""" +from alembic import op +import models as models +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '6835b906335f' +down_revision = 'e315d2a83984' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('tool_builtin_providers', schema=None) as batch_op: + batch_op.drop_constraint(batch_op.f('unique_builtin_tool_provider'), type_='unique') + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('tool_builtin_providers', schema=None) as batch_op: + batch_op.create_unique_constraint(batch_op.f('unique_builtin_tool_provider'), ['provider', 'tenant_id', 'default']) + + # ### end Alembic commands ### diff --git a/api/migrations/versions/2025_06_19_1359-e315d2a83984_multiple_credential.py b/api/migrations/versions/2025_06_19_1359-e315d2a83984_multiple_credential.py new file mode 100644 index 0000000000..2f0caeaf0d --- /dev/null +++ b/api/migrations/versions/2025_06_19_1359-e315d2a83984_multiple_credential.py @@ -0,0 +1,33 @@ +"""multiple credential + +Revision ID: e315d2a83984 +Revises: a9306e69af07 +Create Date: 2025-06-19 13:59:13.860523 + +""" +from alembic import op +import models as models +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = 'e315d2a83984' +down_revision = 'a9306e69af07' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('tool_api_providers', schema=None) as batch_op: + batch_op.drop_constraint(batch_op.f('unique_api_tool_provider'), type_='unique') + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('tool_api_providers', schema=None) as batch_op: + batch_op.create_unique_constraint(batch_op.f('unique_api_tool_provider'), ['name', 'tenant_id']) + + # ### end Alembic commands ### diff --git a/api/migrations/versions/2025_06_19_1511-110e30078dd3_multiple_credential.py b/api/migrations/versions/2025_06_19_1511-110e30078dd3_multiple_credential.py new file mode 100644 index 0000000000..84a5461a4d --- /dev/null +++ b/api/migrations/versions/2025_06_19_1511-110e30078dd3_multiple_credential.py @@ -0,0 +1,53 @@ +"""multiple credential + +Revision ID: 110e30078dd3 +Revises: 6835b906335f +Create Date: 2025-06-19 15:11:42.688478 + +""" +from alembic import op +import models as models +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '110e30078dd3' +down_revision = '6835b906335f' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('tool_oauth_system_clients', schema=None) as batch_op: + batch_op.alter_column('plugin_id', + existing_type=sa.UUID(), + type_=sa.String(length=512), + existing_nullable=False) + + with op.batch_alter_table('tool_oauth_user_clients', schema=None) as batch_op: + batch_op.add_column(sa.Column('enabled', sa.Boolean(), server_default=sa.text('true'), nullable=False)) + batch_op.alter_column('plugin_id', + existing_type=sa.UUID(), + type_=sa.String(length=512), + existing_nullable=False) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('tool_oauth_user_clients', schema=None) as batch_op: + batch_op.alter_column('plugin_id', + existing_type=sa.String(length=512), + type_=sa.UUID(), + existing_nullable=False) + batch_op.drop_column('enabled') + + with op.batch_alter_table('tool_oauth_system_clients', schema=None) as batch_op: + batch_op.alter_column('plugin_id', + existing_type=sa.String(length=512), + type_=sa.UUID(), + existing_nullable=False) + + # ### end Alembic commands ### diff --git a/api/models/tools.py b/api/models/tools.py index 03fbc3acb1..4b493e7596 100644 --- a/api/models/tools.py +++ b/api/models/tools.py @@ -1,3 +1,4 @@ +import enum import json from datetime import datetime from typing import Any, cast @@ -17,6 +18,65 @@ from .model import Account, App, Tenant from .types import StringUUID +class ToolProviderCredentialType(enum.StrEnum): + API_KEY = "api_key", + OAUTH2 = "oauth2", + + def is_editable(self): + return self == ToolProviderCredentialType.API_KEY + + @classmethod + def get_credential_type(cls, credential_type: str) -> "ToolProviderCredentialType": + if credential_type == "api_key": + return cls.API_KEY + elif credential_type == "oauth2": + return cls.OAUTH2 + else: + raise ValueError(f"Invalid credential type: {credential_type}") + +# system level tool oauth client params (client_id, client_secret, etc.) +class ToolOAuthSystemClient(Base): + __tablename__ = "tool_oauth_system_clients" + __table_args__ = ( + db.PrimaryKeyConstraint("id", name="tool_oauth_system_client_pkey"), + db.UniqueConstraint("plugin_id", "provider", name="tool_oauth_system_client_plugin_id_provider_idx"), + ) + + id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + plugin_id: Mapped[str] = mapped_column(db.String(512), nullable=False) + provider: Mapped[str] = mapped_column(db.String(255), nullable=False) + # owner type, e.g., "system", "user" + + # oauth params of the tool provider + encrypted_oauth_params: Mapped[str] = mapped_column(db.Text, nullable=False) + + @property + def oauth_params(self) -> dict: + return cast(dict, json.loads(self.encrypted_oauth_params)) + + +# user level tool oauth client params (client_id, client_secret, etc.) +class ToolOAuthUserClient(Base): + __tablename__ = "tool_oauth_user_clients" + __table_args__ = ( + db.PrimaryKeyConstraint("id", name="tool_oauth_user_client_pkey"), + db.UniqueConstraint("tenant_id", "plugin_id", "provider", name="unique_tool_oauth_user_client"), + ) + + id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + # tenant id + tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + plugin_id: Mapped[str] = mapped_column(db.String(512), nullable=False) + provider: Mapped[str] = mapped_column(db.String(255), nullable=False) + owner_type: Mapped[str] = mapped_column(db.Text, nullable=False) + enabled: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("true")) + # oauth params of the tool provider + encrypted_oauth_params: Mapped[str] = mapped_column(db.Text, nullable=False) + + @property + def oauth_params(self) -> dict: + return cast(dict, json.loads(self.encrypted_oauth_params)) + class BuiltinToolProvider(Base): """ This table stores the tool provider information for built-in tools for each tenant. @@ -25,12 +85,11 @@ class BuiltinToolProvider(Base): __tablename__ = "tool_builtin_providers" __table_args__ = ( db.PrimaryKeyConstraint("id", name="tool_builtin_provider_pkey"), - # one tenant can only have one tool provider with the same name - db.UniqueConstraint("tenant_id", "provider", name="unique_builtin_tool_provider"), ) # id of the tool provider id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + name: Mapped[str] = mapped_column(db.String(256), nullable=False) # id of the tenant tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=True) # who created this tool provider @@ -45,6 +104,11 @@ class BuiltinToolProvider(Base): updated_at: Mapped[datetime] = mapped_column( db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)") ) + default: Mapped[bool] = mapped_column( + db.Boolean, nullable=False, server_default=db.text("false") + ) + # credential type, e.g., "api_key", "oauth2" + credential_type: Mapped[str] = mapped_column(db.String(32), nullable=False, server_default=db.text("'api_key'::character varying")) @property def credentials(self) -> dict: @@ -59,7 +123,6 @@ class ApiToolProvider(Base): __tablename__ = "tool_api_providers" __table_args__ = ( db.PrimaryKeyConstraint("id", name="tool_api_provider_pkey"), - db.UniqueConstraint("name", "tenant_id", name="unique_api_tool_provider"), ) id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) diff --git a/api/services/plugin/oauth_service.py b/api/services/plugin/oauth_service.py index 461247419b..dcc14a8fad 100644 --- a/api/services/plugin/oauth_service.py +++ b/api/services/plugin/oauth_service.py @@ -1,7 +1,62 @@ +import json +import uuid + from core.plugin.impl.base import BasePluginClient +from extensions.ext_redis import redis_client + + +class OAuthProxyService(BasePluginClient): + # Default max age for proxy context parameter in seconds + __MAX_AGE__ = 5 * 60 # 5 minutes + + @staticmethod + def create_proxy_context(user_id, tenant_id, plugin_id, provider): + """ + Create a proxy context for an OAuth 2.0 authorization request. + + This parameter is a crucial security measure to prevent Cross-Site Request + Forgery (CSRF) attacks. It works by generating a unique nonce and storing it + in a distributed cache (Redis) along with the user's session context. + + The returned nonce should be included as the 'proxy_context' parameter in the + authorization URL. Upon callback, the `retrieve_proxy_context` method + is used to verify the state, ensuring the request's integrity and authenticity, + and mitigating replay attacks. + """ + seconds, microseconds = redis_client.time() + context_id = str(uuid.uuid4()) + data = { + "user_id": user_id, + "plugin_id": plugin_id, + "tenant_id": tenant_id, + "provider": provider, + # encode redis time to avoid distribution time skew + "timestamp": seconds, + } + # ignore nonce collision + redis_client.setex( + f"oauth_proxy_context:{context_id}", + OAuthProxyService.__MAX_AGE__, + json.dumps(data), + ) + return context_id -class OAuthService(BasePluginClient): - @classmethod - def get_authorization_url(cls, tenant_id: str, user_id: str, provider_name: str) -> str: - return "1234567890" + @staticmethod + def use_proxy_context(context_id, max_age=__MAX_AGE__): + """ + Validate the proxy context parameter. + This checks if the context_id is valid and not expired. + """ + if not context_id: + raise ValueError("context_id is required") + # get data from redis + data = redis_client.getdel(f"oauth_proxy_context:{context_id}") + if not data: + raise ValueError("context_id is invalid") + # check if data is expired + seconds, microseconds = redis_client.time() + state = json.loads(data) + if state.get("timestamp") < seconds - max_age: + raise ValueError("context_id is expired") + return state diff --git a/api/services/tools/builtin_tools_manage_service.py b/api/services/tools/builtin_tools_manage_service.py index 3ccd14415d..25d927f9f9 100644 --- a/api/services/tools/builtin_tools_manage_service.py +++ b/api/services/tools/builtin_tools_manage_service.py @@ -2,8 +2,6 @@ import json import logging from pathlib import Path -from sqlalchemy.orm import Session - from configs import dify_config from core.helper.position_helper import is_filtered from core.model_runtime.utils.encoders import jsonable_encoder @@ -16,7 +14,7 @@ from core.tools.tool_label_manager import ToolLabelManager from core.tools.tool_manager import ToolManager from core.tools.utils.configuration import ProviderConfigEncrypter from extensions.ext_database import db -from models.tools import BuiltinToolProvider +from models.tools import BuiltinToolProvider, ToolOAuthSystemClient, ToolOAuthUserClient, ToolProviderCredentialType from services.tools.tools_transform_service import ToolTransformService logger = logging.getLogger(__name__) @@ -109,63 +107,69 @@ class BuiltinToolManageService: @staticmethod def update_builtin_tool_provider( - session: Session, user_id: str, tenant_id: str, provider_name: str, credentials: dict + user_id: str, tenant_id: str, provider_name:str, credentials: dict, credential_id: str, name: str | None = None ): """ update builtin tool provider """ # get if the provider exists - provider = BuiltinToolManageService._fetch_builtin_provider(provider_name, tenant_id) + provider = BuiltinToolManageService._fetch_builtin_provider_by_id(tenant_id, credential_id) + + if provider is None: + raise ValueError(f"you have not added provider {provider_name}") + + if not ToolProviderCredentialType.get_credential_type(provider.credential_type).is_editable(): + raise ValueError(f"you cannot update oauth2 provider {provider_name} credentials") try: - # get provider - provider_controller = ToolManager.get_builtin_provider(provider_name, tenant_id) - if not provider_controller.need_credentials: - raise ValueError(f"provider {provider_name} does not need credentials") - tool_configuration = ProviderConfigEncrypter( - tenant_id=tenant_id, - config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()], - provider_type=provider_controller.provider_type.value, - provider_identity=provider_controller.entity.identity.name, - ) + # exclude oauth2 provider + if provider.credential_type != ToolProviderCredentialType.OAUTH2.value: + provider_controller = ToolManager.get_builtin_provider(provider_name, tenant_id) + if not provider_controller.need_credentials: + raise ValueError(f"provider {provider_name} does not need credentials") - # get original credentials if exists - if provider is not None: - original_credentials = tool_configuration.decrypt(provider.credentials) - masked_credentials = tool_configuration.mask_tool_credentials(original_credentials) - # check if the credential has changed, save the original credential - for name, value in credentials.items(): - if name in masked_credentials and value == masked_credentials[name]: - credentials[name] = original_credentials[name] - # validate credentials - provider_controller.validate_credentials(user_id, credentials) - # encrypt credentials - credentials = tool_configuration.encrypt(credentials) + tool_configuration = ProviderConfigEncrypter( + tenant_id=tenant_id, + config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()], + provider_type=provider_controller.provider_type.value, + provider_identity=provider_controller.entity.identity.name, + ) + + # Decrypt and restore original credentials for masked values + credentials = BuiltinToolManageService._dec + rypt_and_restore_credentials( + provider_controller, tool_configuration, provider, credentials + ) + + # Encrypt and save the credentials + BuiltinToolManageService._encrypt_and_save_credentials( + provider_controller, tool_configuration, provider, credentials, user_id + ) + + # update name if provided + if name is not None and provider.name != name: + provider.name = name + + db.session.commit() except ( - PluginDaemonClientSideError, - ToolProviderNotFoundError, - ToolNotFoundError, - ToolProviderCredentialValidationError, + PluginDaemonClientSideError, + ToolProviderNotFoundError, + ToolNotFoundError, + ToolProviderCredentialValidationError, ) as e: raise ValueError(str(e)) - if provider is None: - # create provider - provider = BuiltinToolProvider( - tenant_id=tenant_id, - user_id=user_id, - provider=provider_name, - encrypted_credentials=json.dumps(credentials), - ) - - db.session.add(provider) - else: - provider.encrypted_credentials = json.dumps(credentials) + return {"result": "success"} - # delete cache - tool_configuration.delete_tool_credentials_cache() + @staticmethod + def add_builtin_tool_provider( + user_id: str, tenant_id: str, provider_name: str, credentials: dict, name: str | None = None + ): + """ + add builtin tool provider + """ + - db.session.commit() return {"result": "success"} @staticmethod @@ -214,6 +218,78 @@ class BuiltinToolManageService: return {"result": "success"} + @staticmethod + def set_default_provider(tenant_id: str, user_id: str, provider: str, id: str): + """ + set default provider + """ + # get provider + target_provider = db.session.query(BuiltinToolProvider).filter_by(id=id).first() + if target_provider is None: + raise ValueError("provider not found") + + # clear default provider + db.session.query(BuiltinToolProvider).filter_by( + tenant_id=tenant_id, + user_id=user_id, + provider=provider, + default=True + ).update({"default": False}) + + # set new default provider + target_provider.default = True + db.session.commit() + return {"result": "success"} + + @staticmethod + def fetch_default_provider(tenant_id: str, user_id: str, provider_name: str): + """ + fetch default provider + if there is no explicitly set default provider, return the oldest provider as default + """ + # 1. check if default provider exists + default_provider = db.session.query(BuiltinToolProvider).filter_by( + tenant_id=tenant_id, + user_id=user_id, + provider=provider_name, + default=True + ).first() + if default_provider: + return default_provider + + # 2. if no default provider, set the oldest provider as default + oldest_provider = (db.session.query(BuiltinToolProvider) + .filter_by(tenant_id=tenant_id, user_id=user_id, provider=provider_name) + .order_by(BuiltinToolProvider.created_at) + .first() + ) + if oldest_provider: + return oldest_provider + + raise ValueError(f"no default provider found for {provider_name}") + + @staticmethod + def get_builtin_tool_provider(tenant_id: str, user_id: str, provider: str, plugin_id: str): + """ + get builtin tool provider + """ + user_client = db.session.query(ToolOAuthUserClient).filter_by( + tenant_id=tenant_id, + provider=provider, + plugin_id=plugin_id, + enabled=True, + ).first() + + if user_client: + plugin_oauth_config = user_client + else: + plugin_oauth_config = db.session.query(ToolOAuthSystemClient).filter_by(provider=provider).first() + + if plugin_oauth_config: + return plugin_oauth_config + + raise ValueError("no oauth available config found for this plugin") + @staticmethod def get_builtin_tool_provider_icon(provider: str): """ @@ -286,6 +362,15 @@ class BuiltinToolManageService: return BuiltinToolProviderSort.sort(result) + @staticmethod + def _fetch_builtin_provider_by_id(tenant_id: str, credential_id: str) -> BuiltinToolProvider | None: + provider = (db.session.query(BuiltinToolProvider) + .filter(BuiltinToolProvider.tenant_id == tenant_id, + BuiltinToolProvider.id == credential_id, + ) + .first()) + return provider + @staticmethod def _fetch_builtin_provider(provider_name: str, tenant_id: str) -> BuiltinToolProvider | None: try: @@ -327,3 +412,42 @@ class BuiltinToolManageService: ) .first() ) + + @staticmethod + def _decrypt_and_restore_credentials(provider_controller, tool_configuration, provider, credentials): + """ + Decrypt original credentials and restore masked values from the input credentials + + :param provider_controller: the provider controller + :param tool_configuration: the tool configuration encrypter + :param provider: the provider object from database + :param credentials: the input credentials from user + :return: the processed credentials with original values restored + """ + original_credentials = tool_configuration.decrypt(provider.credentials) + masked_credentials = tool_configuration.mask_tool_credentials(original_credentials) + + # check if the credential has changed, save the original credential + for name, value in credentials.items(): + if name in masked_credentials and value == masked_credentials[name]: # type: ignore + credentials[name] = original_credentials[name] # type: ignore + + return credentials + + @staticmethod + def _encrypt_and_save_credentials(provider_controller, tool_configuration, provider, credentials, user_id): + """ + Validate and encrypt credentials, then save to database + + :param provider_controller: the provider controller + :param tool_configuration: the tool configuration encrypter + :param provider: the provider object from database + :param credentials: the credentials to encrypt and save + :param user_id: the user id for validation + """ + # validate credentials + provider_controller.validate_credentials(user_id, credentials) + # encrypt credentials + encrypted_credentials = tool_configuration.encrypt(credentials) + provider.encrypted_credentials = json.dumps(encrypted_credentials) + tool_configuration.delete_tool_credentials_cache() diff --git a/api/tool_oauth.http b/api/tool_oauth.http new file mode 100644 index 0000000000..9915472d03 --- /dev/null +++ b/api/tool_oauth.http @@ -0,0 +1,27 @@ + +@accessToken=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJ1c2VyX2lkIjoiYjM4Y2Y5N2MtODNiYS00MWI3LWEyZjMtMzZlOTgzZjE4YmQ5IiwiZXhwIjoxNzUwNDE3NDI0LCJpc3MiOiJTRUxGX0hPU1RFRCIsInN1YiI6IkNvbnNvbGUgQVBJIFBhc3Nwb3J0In0.pPCkISnSmnu3hOCyEVTIJoNeWxtx7E9LNy0cDQUy__Q + + + +# set default credential +POST /console/api/workspaces/current/tool-provider/builtin/langgenius/github/github/set-default +Host: 127.0.0.1:5001 +Content-Type: application/json +Authorization: Bearer {{accessToken}} + +{ + "id": "55fb78d2-0ce6-4496-9488-3b8d9f40818f" +} +### + +# get oauth url +GET /console/api/oauth/plugin/tool?plugin_id=c58a1845-f3a4-4d93-b749-a71e9998b702/github&provider=github +Host: 127.0.0.1:5001 +Authorization: Bearer {{accessToken}} + +### + +# get oauth token +GET /console/api/oauth/plugin/tool/callback?state=734072c2-d8ed-4b0b-8ed8-4efd69d15a4f&code=e2d68a6216a3b7d70d2f&state=NQCjFkMKtf32XCMHc8KBdw +Host: 127.0.0.1:5001 +Authorization: Bearer {{accessToken}} From b3a8dbe2f5853d233e693cbb435897f076dd87f5 Mon Sep 17 00:00:00 2001 From: Harry Date: Mon, 23 Jun 2025 11:20:54 +0800 Subject: [PATCH 02/15] fix: typo --- .../tools/builtin_tools_manage_service.py | 3 +-- api/tool_oauth.http | 27 ------------------- 2 files changed, 1 insertion(+), 29 deletions(-) delete mode 100644 api/tool_oauth.http diff --git a/api/services/tools/builtin_tools_manage_service.py b/api/services/tools/builtin_tools_manage_service.py index 25d927f9f9..31bc2e650d 100644 --- a/api/services/tools/builtin_tools_manage_service.py +++ b/api/services/tools/builtin_tools_manage_service.py @@ -136,8 +136,7 @@ class BuiltinToolManageService: ) # Decrypt and restore original credentials for masked values - credentials = BuiltinToolManageService._dec - rypt_and_restore_credentials( + credentials = BuiltinToolManageService._decrypt_and_restore_credentials( provider_controller, tool_configuration, provider, credentials ) diff --git a/api/tool_oauth.http b/api/tool_oauth.http deleted file mode 100644 index 9915472d03..0000000000 --- a/api/tool_oauth.http +++ /dev/null @@ -1,27 +0,0 @@ - -@accessToken=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJ1c2VyX2lkIjoiYjM4Y2Y5N2MtODNiYS00MWI3LWEyZjMtMzZlOTgzZjE4YmQ5IiwiZXhwIjoxNzUwNDE3NDI0LCJpc3MiOiJTRUxGX0hPU1RFRCIsInN1YiI6IkNvbnNvbGUgQVBJIFBhc3Nwb3J0In0.pPCkISnSmnu3hOCyEVTIJoNeWxtx7E9LNy0cDQUy__Q - - - -# set default credential -POST /console/api/workspaces/current/tool-provider/builtin/langgenius/github/github/set-default -Host: 127.0.0.1:5001 -Content-Type: application/json -Authorization: Bearer {{accessToken}} - -{ - "id": "55fb78d2-0ce6-4496-9488-3b8d9f40818f" -} -### - -# get oauth url -GET /console/api/oauth/plugin/tool?plugin_id=c58a1845-f3a4-4d93-b749-a71e9998b702/github&provider=github -Host: 127.0.0.1:5001 -Authorization: Bearer {{accessToken}} - -### - -# get oauth token -GET /console/api/oauth/plugin/tool/callback?state=734072c2-d8ed-4b0b-8ed8-4efd69d15a4f&code=e2d68a6216a3b7d70d2f&state=NQCjFkMKtf32XCMHc8KBdw -Host: 127.0.0.1:5001 -Authorization: Bearer {{accessToken}} From 7f292dc261d865523d0df4e9454fb421f2baab33 Mon Sep 17 00:00:00 2001 From: Harry Date: Mon, 23 Jun 2025 12:49:18 +0800 Subject: [PATCH 03/15] fix: remove debugging flags --- api/app.py | 2 +- .../console/workspace/tool_providers.py | 3 +- api/extensions/ext_celery.py | 6 -- api/models/tools.py | 7 +- .../tools/builtin_tools_manage_service.py | 69 +++++++++++++++---- 5 files changed, 65 insertions(+), 22 deletions(-) diff --git a/api/app.py b/api/app.py index 11decffe96..4f393f6c20 100644 --- a/api/app.py +++ b/api/app.py @@ -38,4 +38,4 @@ else: celery = app.extensions["celery"] if __name__ == "__main__": - app.run(host="0.0.0.0", port=5001,debug=True) + app.run(host="0.0.0.0", port=5001) diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py index e3285a16c7..a46071059f 100644 --- a/api/controllers/console/workspace/tool_providers.py +++ b/api/controllers/console/workspace/tool_providers.py @@ -127,9 +127,10 @@ class ToolBuiltinProviderUpdateApi(Resource): result = BuiltinToolManageService.update_builtin_tool_provider( user_id=user_id, tenant_id=tenant_id, + provider_name=provider, credentials=args["credentials"], credential_id=args["credential_id"], - name=args["name"] + name=args["name"], ) session.commit() return result diff --git a/api/extensions/ext_celery.py b/api/extensions/ext_celery.py index 14ec4ebae0..a837552007 100644 --- a/api/extensions/ext_celery.py +++ b/api/extensions/ext_celery.py @@ -1,4 +1,3 @@ -import os from datetime import timedelta import pytz @@ -25,17 +24,12 @@ def init_app(app: DifyApp) -> Celery: }, } - - flask_debugging = os.environ.get("FLASK_DEBUG", "0").lower() in {"true", "1", "yes"} - celery_app = Celery( app.name, task_cls=FlaskTask, broker=dify_config.CELERY_BROKER_URL, backend=dify_config.CELERY_BACKEND, task_ignore_result=True, - task_always_eager=flask_debugging, - task_eager_propagates=flask_debugging, ) # Add SSL options to the Celery configuration diff --git a/api/models/tools.py b/api/models/tools.py index 4b493e7596..9e50cec52f 100644 --- a/api/models/tools.py +++ b/api/models/tools.py @@ -19,8 +19,11 @@ from .types import StringUUID class ToolProviderCredentialType(enum.StrEnum): - API_KEY = "api_key", - OAUTH2 = "oauth2", + API_KEY = "api_key" + OAUTH2 = "oauth2" + + def get_name(self): + return self.value.replace("_", " ").upper() def is_editable(self): return self == ToolProviderCredentialType.API_KEY diff --git a/api/services/tools/builtin_tools_manage_service.py b/api/services/tools/builtin_tools_manage_service.py index 31bc2e650d..7dc3e4c0f8 100644 --- a/api/services/tools/builtin_tools_manage_service.py +++ b/api/services/tools/builtin_tools_manage_service.py @@ -118,12 +118,8 @@ class BuiltinToolManageService: if provider is None: raise ValueError(f"you have not added provider {provider_name}") - if not ToolProviderCredentialType.get_credential_type(provider.credential_type).is_editable(): - raise ValueError(f"you cannot update oauth2 provider {provider_name} credentials") - try: - # exclude oauth2 provider - if provider.credential_type != ToolProviderCredentialType.OAUTH2.value: + if ToolProviderCredentialType.get_credential_type(provider.credential_type).is_editable(): provider_controller = ToolManager.get_builtin_provider(provider_name, tenant_id) if not provider_controller.need_credentials: raise ValueError(f"provider {provider_name} does not need credentials") @@ -139,11 +135,15 @@ class BuiltinToolManageService: credentials = BuiltinToolManageService._decrypt_and_restore_credentials( provider_controller, tool_configuration, provider, credentials ) - + # Encrypt and save the credentials BuiltinToolManageService._encrypt_and_save_credentials( provider_controller, tool_configuration, provider, credentials, user_id ) + else: + raise ValueError( + f"provider {provider_name} is not editable, you can only delete it and add a new one" + ) # update name if provided if name is not None and provider.name != name: @@ -162,15 +162,60 @@ class BuiltinToolManageService: @staticmethod def add_builtin_tool_provider( - user_id: str, tenant_id: str, provider_name: str, credentials: dict, name: str | None = None + user_id: str, type: ToolProviderCredentialType, tenant_id: str, provider_name:str, credentials: dict, name: str | None = None ): """ add builtin tool provider """ - + if name is None: + name = BuiltinToolManageService.get_next_builtin_tool_provider_name(tenant_id, type) + + provider = BuiltinToolProvider( + tenant_id=tenant_id, + user_id=user_id, + provider=provider_name, + credential_type=type.value, + credentials=json.dumps(credentials), + name=name, + ) + + provider_controller = ToolManager.get_builtin_provider(provider_name, tenant_id) + if not provider_controller.need_credentials: + raise ValueError(f"provider {provider_name} does not need credentials") + tool_configuration = ProviderConfigEncrypter( + tenant_id=tenant_id, + config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()], + provider_type=provider_controller.provider_type.value, + provider_identity=provider_controller.entity.identity.name, + ) + + # Encrypt and save the credentials + BuiltinToolManageService._encrypt_and_save_credentials( + provider_controller, tool_configuration, provider, credentials, user_id + ) + db.session.add(provider) return {"result": "success"} + @staticmethod + def get_next_builtin_tool_provider_name(tenant_id: str, type: ToolProviderCredentialType) -> str: + """ + next name = max(provider_names) + 1 + """ + provider_names = db.session.query(BuiltinToolProvider).filter_by( + tenant_id=tenant_id, + credential_type=type.value, + ).all() + if not provider_names: + return f"{type.value} 1" + # OAuth 1 then OAuth 2, if don't have OAuth 1, then return OAuth 1 + # if dont have number, then get name and add 1 + for provider_name in provider_names: + if provider_name.provider.startswith(type.value): + return f"{type.value} {int(provider_name.provider.split(' ')[1]) + 1}" + return f"{type.value} 1" + + @staticmethod def get_builtin_tool_provider_credentials(tenant_id: str, provider_name: str): """ @@ -416,7 +461,7 @@ class BuiltinToolManageService: def _decrypt_and_restore_credentials(provider_controller, tool_configuration, provider, credentials): """ Decrypt original credentials and restore masked values from the input credentials - + :param provider_controller: the provider controller :param tool_configuration: the tool configuration encrypter :param provider: the provider object from database @@ -425,19 +470,19 @@ class BuiltinToolManageService: """ original_credentials = tool_configuration.decrypt(provider.credentials) masked_credentials = tool_configuration.mask_tool_credentials(original_credentials) - + # check if the credential has changed, save the original credential for name, value in credentials.items(): if name in masked_credentials and value == masked_credentials[name]: # type: ignore credentials[name] = original_credentials[name] # type: ignore - + return credentials @staticmethod def _encrypt_and_save_credentials(provider_controller, tool_configuration, provider, credentials, user_id): """ Validate and encrypt credentials, then save to database - + :param provider_controller: the provider controller :param tool_configuration: the tool configuration encrypter :param provider: the provider object from database From 5e7c5863ef13ecb03eae8f8f5516182b363572ce Mon Sep 17 00:00:00 2001 From: Harry Date: Mon, 23 Jun 2025 16:51:28 +0800 Subject: [PATCH 04/15] refactor(tool oauth): update api implementation --- README.md | 259 ------------- .../console/workspace/model_providers.py | 1 - .../console/workspace/tool_providers.py | 114 ++++-- api/core/tools/entities/api_entities.py | 13 +- api/core/tools/entities/tool_entities.py | 33 ++ api/core/tools/tool_manager.py | 25 +- ...9_1133-222376193a49_multiple_credential.py | 39 -- ...9_1353-a9306e69af07_multiple_credential.py | 33 -- ...9_1359-6835b906335f_multiple_credential.py | 33 -- ...9_1359-e315d2a83984_multiple_credential.py | 33 -- ...9_1511-110e30078dd3_multiple_credential.py | 53 --- ...025_06_24_1705-71f5020c6470_tool_oauth.py} | 47 ++- api/models/tools.py | 56 +-- api/services/plugin/oauth_service.py | 4 +- .../tools/builtin_tools_manage_service.py | 358 ++++++++++-------- api/services/tools/tools_transform_service.py | 16 +- 16 files changed, 386 insertions(+), 731 deletions(-) delete mode 100644 README.md delete mode 100644 api/migrations/versions/2025_06_19_1133-222376193a49_multiple_credential.py delete mode 100644 api/migrations/versions/2025_06_19_1353-a9306e69af07_multiple_credential.py delete mode 100644 api/migrations/versions/2025_06_19_1359-6835b906335f_multiple_credential.py delete mode 100644 api/migrations/versions/2025_06_19_1359-e315d2a83984_multiple_credential.py delete mode 100644 api/migrations/versions/2025_06_19_1511-110e30078dd3_multiple_credential.py rename api/migrations/versions/{2025_06_18_1506-99310d2c25a6_add_tool_oauth_credentials.py => 2025_06_24_1705-71f5020c6470_tool_oauth.py} (54%) diff --git a/README.md b/README.md deleted file mode 100644 index ca09adec08..0000000000 --- a/README.md +++ /dev/null @@ -1,259 +0,0 @@ -![cover-v5-optimized](./images/GitHub_README_if.png) - -

- 📌 Introducing Dify Workflow File Upload: Recreate Google NotebookLM Podcast -

- -

- Dify Cloud · - Self-hosting · - Documentation · - Dify edition overview -

- -

- - Static Badge - - Static Badge - - chat on Discord - - join Reddit - - follow on X(Twitter) - - follow on LinkedIn - - Docker Pulls - - Commits last month - - Issues closed - - Discussion posts -

- -

- README in English - 繁體中文文件 - 简体中文版自述文件 - 日本語のREADME - README en Español - README en Français - README tlhIngan Hol - README in Korean - README بالعربية - Türkçe README - README Tiếng Việt - README in Deutsch - README in বাংলা -

- -Dify is an open-source LLM app development platform. Its intuitive interface combines agentic AI workflow, RAG pipeline, agent capabilities, model management, observability features, and more, allowing you to quickly move from prototype to production. - -## Quick start - -> Before installing Dify, make sure your machine meets the following minimum system requirements: -> -> - CPU >= 2 Core -> - RAM >= 4 GiB - -
- -The easiest way to start the Dify server is through [docker compose](docker/docker-compose.yaml). Before running Dify with the following commands, make sure that [Docker](https://docs.docker.com/get-docker/) and [Docker Compose](https://docs.docker.com/compose/install/) are installed on your machine: - -```bash -cd dify -cd docker -cp .env.example .env -docker compose up -d -``` - -After running, you can access the Dify dashboard in your browser at [http://localhost/install](http://localhost/install) and start the initialization process. - -#### Seeking help - -Please refer to our [FAQ](https://docs.dify.ai/getting-started/install-self-hosted/faqs) if you encounter problems setting up Dify. Reach out to [the community and us](#community--contact) if you are still having issues. - -> If you'd like to contribute to Dify or do additional development, refer to our [guide to deploying from source code](https://docs.dify.ai/getting-started/install-self-hosted/local-source-code) - -## Key features - -**1. Workflow**: -Build and test powerful AI workflows on a visual canvas, leveraging all the following features and beyond. - -**2. Comprehensive model support**: -Seamless integration with hundreds of proprietary / open-source LLMs from dozens of inference providers and self-hosted solutions, covering GPT, Mistral, Llama3, and any OpenAI API-compatible models. A full list of supported model providers can be found [here](https://docs.dify.ai/getting-started/readme/model-providers). - -![providers-v5](https://github.com/langgenius/dify/assets/13230914/5a17bdbe-097a-4100-8363-40255b70f6e3) - -**3. Prompt IDE**: -Intuitive interface for crafting prompts, comparing model performance, and adding additional features such as text-to-speech to a chat-based app. - -**4. RAG Pipeline**: -Extensive RAG capabilities that cover everything from document ingestion to retrieval, with out-of-box support for text extraction from PDFs, PPTs, and other common document formats. - -**5. Agent capabilities**: -You can define agents based on LLM Function Calling or ReAct, and add pre-built or custom tools for the agent. Dify provides 50+ built-in tools for AI agents, such as Google Search, DALL·E, Stable Diffusion and WolframAlpha. - -**6. LLMOps**: -Monitor and analyze application logs and performance over time. You could continuously improve prompts, datasets, and models based on production data and annotations. - -**7. Backend-as-a-Service**: -All of Dify's offerings come with corresponding APIs, so you could effortlessly integrate Dify into your own business logic. - -## Feature Comparison - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
FeatureDify.AILangChainFlowiseOpenAI Assistants API
Programming ApproachAPI + App-orientedPython CodeApp-orientedAPI-oriented
Supported LLMsRich VarietyRich VarietyRich VarietyOpenAI-only
RAG Engine
Agent
Workflow
Observability
Enterprise Feature (SSO/Access control)
Local Deployment
- -## Using Dify - -- **Cloud
** - We host a [Dify Cloud](https://dify.ai) service for anyone to try with zero setup. It provides all the capabilities of the self-deployed version, and includes 200 free GPT-4 calls in the sandbox plan. - -- **Self-hosting Dify Community Edition
** - Quickly get Dify running in your environment with this [starter guide](#quick-start). - Use our [documentation](https://docs.dify.ai) for further references and more in-depth instructions. - -- **Dify for enterprise / organizations
** - We provide additional enterprise-centric features. [Log your questions for us through this chatbot](https://udify.app/chat/22L1zSxg6yW1cWQg) or [send us an email](mailto:business@dify.ai?subject=[GitHub]Business%20License%20Inquiry) to discuss enterprise needs.
- > For startups and small businesses using AWS, check out [Dify Premium on AWS Marketplace](https://aws.amazon.com/marketplace/pp/prodview-t22mebxzwjhu6) and deploy it to your own AWS VPC with one click. It's an affordable AMI offering with the option to create apps with custom logo and branding. - -## Staying ahead - -Star Dify on GitHub and be instantly notified of new releases. - -![star-us](https://github.com/langgenius/dify/assets/13230914/b823edc1-6388-4e25-ad45-2f6b187adbb4) - -## Advanced Setup - -If you need to customize the configuration, please refer to the comments in our [.env.example](docker/.env.example) file and update the corresponding values in your `.env` file. Additionally, you might need to make adjustments to the `docker-compose.yaml` file itself, such as changing image versions, port mappings, or volume mounts, based on your specific deployment environment and requirements. After making any changes, please re-run `docker-compose up -d`. You can find the full list of available environment variables [here](https://docs.dify.ai/getting-started/install-self-hosted/environments). - -If you'd like to configure a highly-available setup, there are community-contributed [Helm Charts](https://helm.sh/) and YAML files which allow Dify to be deployed on Kubernetes. - -- [Helm Chart by @LeoQuote](https://github.com/douban/charts/tree/master/charts/dify) -- [Helm Chart by @BorisPolonsky](https://github.com/BorisPolonsky/dify-helm) -- [Helm Chart by @magicsong](https://github.com/magicsong/ai-charts) -- [YAML file by @Winson-030](https://github.com/Winson-030/dify-kubernetes) -- [YAML file by @wyy-holding](https://github.com/wyy-holding/dify-k8s) - -#### Using Terraform for Deployment - -Deploy Dify to Cloud Platform with a single click using [terraform](https://www.terraform.io/) - -##### Azure Global - -- [Azure Terraform by @nikawang](https://github.com/nikawang/dify-azure-terraform) - -##### Google Cloud - -- [Google Cloud Terraform by @sotazum](https://github.com/DeNA/dify-google-cloud-terraform) - -#### Using AWS CDK for Deployment - -Deploy Dify to AWS with [CDK](https://aws.amazon.com/cdk/) - -##### AWS - -- [AWS CDK by @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws) - -## Contributing - -For those who'd like to contribute code, see our [Contribution Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md). -At the same time, please consider supporting Dify by sharing it on social media and at events and conferences. - -> We are looking for contributors to help translate Dify into languages other than Mandarin or English. If you are interested in helping, please see the [i18n README](https://github.com/langgenius/dify/blob/main/web/i18n/README.md) for more information, and leave us a comment in the `global-users` channel of our [Discord Community Server](https://discord.gg/8Tpq4AcN9c). - -## Community & contact - -- [GitHub Discussion](https://github.com/langgenius/dify/discussions). Best for: sharing feedback and asking questions. -- [GitHub Issues](https://github.com/langgenius/dify/issues). Best for: bugs you encounter using Dify.AI, and feature proposals. See our [Contribution Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md). -- [Discord](https://discord.gg/FngNHpbcY7). Best for: sharing your applications and hanging out with the community. -- [X(Twitter)](https://twitter.com/dify_ai). Best for: sharing your applications and hanging out with the community. - -**Contributors** - - - - - -## Star history - -[![Star History Chart](https://api.star-history.com/svg?repos=langgenius/dify&type=Date)](https://star-history.com/#langgenius/dify&Date) - -## Security disclosure - -To protect your privacy, please avoid posting security issues on GitHub. Instead, send your questions to security@dify.ai and we will provide you with a more detailed answer. - -## License - -This repository is available under the [Dify Open Source License](LICENSE), which is essentially Apache 2.0 with a few additional restrictions. diff --git a/api/controllers/console/workspace/model_providers.py b/api/controllers/console/workspace/model_providers.py index ff0fcbda6e..32139781b0 100644 --- a/api/controllers/console/workspace/model_providers.py +++ b/api/controllers/console/workspace/model_providers.py @@ -35,7 +35,6 @@ class ModelProviderListApi(Resource): model_provider_service = ModelProviderService() provider_list = model_provider_service.get_provider_list(tenant_id=tenant_id, model_type=args.get("model_type")) - return jsonable_encoder({"data": provider_list}) diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py index a46071059f..a4839fe8a1 100644 --- a/api/controllers/console/workspace/tool_providers.py +++ b/api/controllers/console/workspace/tool_providers.py @@ -18,6 +18,7 @@ from controllers.console.wraps import ( ) from core.model_runtime.utils.encoders import jsonable_encoder from core.plugin.impl.oauth import OAuthHandler +from core.tools.entities.tool_entities import ToolProviderCredentialType from extensions.ext_database import db from libs.helper import alphanumeric, uuid_value from libs.login import login_required @@ -89,17 +90,47 @@ class ToolBuiltinProviderDeleteApi(Resource): @account_initialization_required def post(self, provider): user = current_user - if not user.is_admin_or_owner: raise Forbidden() - user_id = user.id tenant_id = user.current_tenant_id + req = reqparse.RequestParser() + req.add_argument("credential_id", type=str, required=True, nullable=False, location="json") + args = req.parse_args() return BuiltinToolManageService.delete_builtin_tool_provider( - user_id, tenant_id, provider, + args["credential_id"], + ) + + +class ToolBuiltinProviderAddApi(Resource): + @setup_required + @login_required + @account_initialization_required + def post(self, provider): + user = current_user + + user_id = user.id + tenant_id = user.current_tenant_id + + parser = reqparse.RequestParser() + parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") + parser.add_argument("name", type=str, required=False, nullable=False, location="json") + parser.add_argument("type", type=str, required=True, nullable=False, location="json") + args = parser.parse_args() + + if args["type"] not in ToolProviderCredentialType.values(): + raise ValueError(f"Invalid credential type: {args['type']}") + + return BuiltinToolManageService.add_builtin_tool_provider( + user_id=user_id, + tenant_id=tenant_id, + provider_name=provider, + credentials=args["credentials"], + name=args["name"], + api_type=ToolProviderCredentialType.of(args["type"]), ) @@ -143,9 +174,11 @@ class ToolBuiltinProviderGetCredentialsApi(Resource): def get(self, provider): tenant_id = current_user.current_tenant_id - return BuiltinToolManageService.get_builtin_tool_provider_credentials( - tenant_id=tenant_id, - provider_name=provider, + return jsonable_encoder( + BuiltinToolManageService.get_builtin_tool_provider_credentials( + tenant_id=tenant_id, + provider_name=provider, + ) ) @@ -567,9 +600,9 @@ class ToolBuiltinListApi(Resource): [ provider.to_dict() for provider in BuiltinToolManageService.list_builtin_tools( - user_id, - tenant_id, - ) + user_id, + tenant_id, + ) ] ) @@ -588,9 +621,9 @@ class ToolApiListApi(Resource): [ provider.to_dict() for provider in ApiToolManageService.list_api_tools( - user_id, - tenant_id, - ) + user_id, + tenant_id, + ) ] ) @@ -609,9 +642,9 @@ class ToolWorkflowListApi(Resource): [ provider.to_dict() for provider in WorkflowToolManageService.list_tenant_workflow_tools( - user_id, - tenant_id, - ) + user_id, + tenant_id, + ) ] ) @@ -656,14 +689,13 @@ class ToolPluginOAuthApi(Resource): ) oauth_handler = OAuthHandler() - context_id = OAuthProxyService.create_proxy_context(user_id=current_user.id, - tenant_id=tenant_id, - plugin_id=plugin_id, - provider=provider) + context_id = OAuthProxyService.create_proxy_context( + user_id=current_user.id, tenant_id=tenant_id, plugin_id=plugin_id, provider=provider + ) # todo decrypt oauth params oauth_params = plugin_oauth_config.oauth_params - oauth_params[ - 'redirect_uri'] = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/tool/callback?context_id={context_id}" + redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/tool/callback?context_id={context_id}" + oauth_params["redirect_uri"] = redirect_uri response = oauth_handler.get_authorization_url( tenant_id, @@ -676,14 +708,13 @@ class ToolPluginOAuthApi(Resource): class ToolOAuthCallback(Resource): - @setup_required def get(self): - args = (reqparse - .RequestParser() - .add_argument("context_id", type=str, required=True, nullable=False, location="args") - .parse_args() - ) + args = ( + reqparse.RequestParser() + .add_argument("context_id", type=str, required=True, nullable=False, location="args") + .parse_args() + ) context_id = args["context_id"] context = OAuthProxyService.use_proxy_context(context_id) if context is None: @@ -703,7 +734,8 @@ class ToolOAuthCallback(Resource): plugin_id=plugin_id, ) oauth_params = plugin_oauth_config.oauth_params - oauth_params['redirect_uri'] = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/tool/callback?context_id={context_id}" + redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/tool/callback?context_id={context_id}" + oauth_params["redirect_uri"] = redirect_uri credentials = oauth_handler.get_credentials( tenant_id, @@ -712,12 +744,20 @@ class ToolOAuthCallback(Resource): provider, system_credentials=oauth_params, request=request, - ) + ).credentials if not credentials: - raise Exception("no credentials found for this plugin") + raise Exception("the plugin credentials failed") - #TODO add credentials to database + # add credentials to database + BuiltinToolManageService.add_builtin_tool_provider( + user_id=user_id, + tenant_id=tenant_id, + provider_name=provider, + credentials=dict(credentials), + name=provider, + api_type=ToolProviderCredentialType.OAUTH2, + ) return redirect(f"{dify_config.CONSOLE_WEB_URL}") @@ -730,10 +770,8 @@ class ToolBuiltinProviderSetDefaultApi(Resource): parser.add_argument("id", type=str, required=True, nullable=False, location="json") args = parser.parse_args() return BuiltinToolManageService.set_default_provider( - tenant_id=current_user.current_tenant_id, - user_id=current_user.id, - provider=provider, - id=args["id"]) + tenant_id=current_user.current_tenant_id, user_id=current_user.id, provider=provider, id=args["id"] + ) # tool oauth @@ -746,10 +784,12 @@ api.add_resource(ToolProviderListApi, "/workspaces/current/tool-providers") # builtin tool provider api.add_resource(ToolBuiltinProviderListToolsApi, "/workspaces/current/tool-provider/builtin//tools") api.add_resource(ToolBuiltinProviderInfoApi, "/workspaces/current/tool-provider/builtin//info") +api.add_resource(ToolBuiltinProviderAddApi, "/workspaces/current/tool-provider/builtin//add") api.add_resource(ToolBuiltinProviderDeleteApi, "/workspaces/current/tool-provider/builtin//delete") api.add_resource(ToolBuiltinProviderUpdateApi, "/workspaces/current/tool-provider/builtin//update") -api.add_resource(ToolBuiltinProviderSetDefaultApi, - "/workspaces/current/tool-provider/builtin//set-default") +api.add_resource( + ToolBuiltinProviderSetDefaultApi, "/workspaces/current/tool-provider/builtin//set-default" +) api.add_resource( ToolBuiltinProviderGetCredentialsApi, "/workspaces/current/tool-provider/builtin//credentials" ) diff --git a/api/core/tools/entities/api_entities.py b/api/core/tools/entities/api_entities.py index b96c994cff..eaadd4d214 100644 --- a/api/core/tools/entities/api_entities.py +++ b/api/core/tools/entities/api_entities.py @@ -5,7 +5,7 @@ from pydantic import BaseModel, Field, field_validator from core.model_runtime.utils.encoders import jsonable_encoder from core.tools.__base.tool import ToolParameter from core.tools.entities.common_entities import I18nObject -from core.tools.entities.tool_entities import ToolProviderType +from core.tools.entities.tool_entities import ToolProviderCredentialType, ToolProviderType class ToolApiEntity(BaseModel): @@ -70,3 +70,14 @@ class ToolProviderApiEntity(BaseModel): "tools": tools, "labels": self.labels, } + + +class ToolProviderCredentialApiEntity(BaseModel): + id: str = Field(description="The unique id of the credential") + name: str = Field(description="The name of the credential") + provider: str = Field(description="The provider of the credential") + credential_type: ToolProviderCredentialType = Field(description="The type of the credential") + is_default: bool = Field( + default=False, description="Whether the credential is the default credential for the provider in the workspace" + ) + credentials: dict = Field(description="The credentials of the provider") diff --git a/api/core/tools/entities/tool_entities.py b/api/core/tools/entities/tool_entities.py index 03047c0545..5094519b6f 100644 --- a/api/core/tools/entities/tool_entities.py +++ b/api/core/tools/entities/tool_entities.py @@ -434,3 +434,36 @@ class ToolSelector(BaseModel): def to_plugin_parameter(self) -> dict[str, Any]: return self.model_dump() + + +class ToolProviderCredentialType(enum.StrEnum): + API_KEY = "api_key" + OAUTH2 = "oauth2" + + def get_name(self): + if self == ToolProviderCredentialType.API_KEY: + return "API KEY" + elif self == ToolProviderCredentialType.OAUTH2: + return "AUTH" + else: + return self.value.replace("_", " ").upper() + + def is_editable(self): + return self == ToolProviderCredentialType.API_KEY + + def is_validate_allowed(self): + return self == ToolProviderCredentialType.API_KEY + + @classmethod + def values(cls): + return [item.value for item in cls] + + @classmethod + def of(cls, credential_type: str) -> "ToolProviderCredentialType": + type_name = credential_type.lower() + if type_name == "api_key": + return cls.API_KEY + elif type_name == "oauth2": + return cls.OAUTH2 + else: + raise ValueError(f"Invalid credential type: {credential_type}") diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index 0bfe6329b1..f25267dbf6 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -200,6 +200,7 @@ class ToolManager: (BuiltinToolProvider.provider == str(provider_id_entity)) | (BuiltinToolProvider.provider == provider_id_entity.provider_name), ) + .order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc()) .first() ) @@ -209,6 +210,7 @@ class ToolManager: builtin_provider = ( db.session.query(BuiltinToolProvider) .filter(BuiltinToolProvider.tenant_id == tenant_id, (BuiltinToolProvider.provider == provider_id)) + .order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc()) .first() ) @@ -575,18 +577,27 @@ class ToolManager: with db.session.no_autoflush: if "builtin" in filters: - # get builtin providers + + def get_builtin_providers(tenant_id): + # according to multi credentials, select the one with is_default=True first, then created_at oldest + # for compatibility with old version + sql = """ + SELECT DISTINCT ON (tenant_id, provider) id + FROM tool_builtin_providers + WHERE tenant_id = :tenant_id + ORDER BY tenant_id, provider, is_default DESC, created_at DESC + """ + ids = [row.id for row in db.session.execute(db.text(sql), {"tenant_id": tenant_id}).all()] + return db.session.query(BuiltinToolProvider).filter(BuiltinToolProvider.id.in_(ids)).all() + builtin_providers = cls.list_builtin_providers(tenant_id) - # get db builtin providers - db_builtin_providers: list[BuiltinToolProvider] = ( - db.session.query(BuiltinToolProvider).filter(BuiltinToolProvider.tenant_id == tenant_id).all() - ) + # get builtin providers + db_builtin_providers = get_builtin_providers(tenant_id) # rewrite db_builtin_providers for db_provider in db_builtin_providers: - tool_provider_id = str(ToolProviderID(db_provider.provider)) - db_provider.provider = tool_provider_id + db_provider.provider = str(ToolProviderID(db_provider.provider)) def find_db_builtin_provider(provider): return next((x for x in db_builtin_providers if x.provider == provider), None) diff --git a/api/migrations/versions/2025_06_19_1133-222376193a49_multiple_credential.py b/api/migrations/versions/2025_06_19_1133-222376193a49_multiple_credential.py deleted file mode 100644 index 82e812cb3d..0000000000 --- a/api/migrations/versions/2025_06_19_1133-222376193a49_multiple_credential.py +++ /dev/null @@ -1,39 +0,0 @@ -"""multiple credential - -Revision ID: 222376193a49 -Revises: 99310d2c25a6 -Create Date: 2025-06-19 11:33:46.400455 - -""" -from alembic import op -import models as models -import sqlalchemy as sa - - -# revision identifiers, used by Alembic. -revision = '222376193a49' -down_revision = '99310d2c25a6' -branch_labels = None -depends_on = None - - -def upgrade(): - # ### commands auto generated by Alembic - please adjust! ### - with op.batch_alter_table('tool_builtin_providers', schema=None) as batch_op: - batch_op.drop_constraint(batch_op.f('unique_builtin_tool_provider'), type_='unique') - - with op.batch_alter_table('tool_oauth_user_clients', schema=None) as batch_op: - batch_op.add_column(sa.Column('owner_type', sa.Text(), nullable=False)) - - # ### end Alembic commands ### - - -def downgrade(): - # ### commands auto generated by Alembic - please adjust! ### - with op.batch_alter_table('tool_oauth_user_clients', schema=None) as batch_op: - batch_op.drop_column('owner_type') - - with op.batch_alter_table('tool_builtin_providers', schema=None) as batch_op: - batch_op.create_unique_constraint(batch_op.f('unique_builtin_tool_provider'), ['tenant_id', 'provider', 'credential_type']) - - # ### end Alembic commands ### diff --git a/api/migrations/versions/2025_06_19_1353-a9306e69af07_multiple_credential.py b/api/migrations/versions/2025_06_19_1353-a9306e69af07_multiple_credential.py deleted file mode 100644 index 216661550a..0000000000 --- a/api/migrations/versions/2025_06_19_1353-a9306e69af07_multiple_credential.py +++ /dev/null @@ -1,33 +0,0 @@ -"""multiple credential - -Revision ID: a9306e69af07 -Revises: 222376193a49 -Create Date: 2025-06-19 13:53:41.554159 - -""" -from alembic import op -import models as models -import sqlalchemy as sa - - -# revision identifiers, used by Alembic. -revision = 'a9306e69af07' -down_revision = '222376193a49' -branch_labels = None -depends_on = None - - -def upgrade(): - # ### commands auto generated by Alembic - please adjust! ### - with op.batch_alter_table('tool_builtin_providers', schema=None) as batch_op: - batch_op.create_unique_constraint('unique_builtin_tool_provider', ['provider', 'tenant_id', 'default']) - - # ### end Alembic commands ### - - -def downgrade(): - # ### commands auto generated by Alembic - please adjust! ### - with op.batch_alter_table('tool_builtin_providers', schema=None) as batch_op: - batch_op.drop_constraint('unique_builtin_tool_provider', type_='unique') - - # ### end Alembic commands ### diff --git a/api/migrations/versions/2025_06_19_1359-6835b906335f_multiple_credential.py b/api/migrations/versions/2025_06_19_1359-6835b906335f_multiple_credential.py deleted file mode 100644 index d90e0d178e..0000000000 --- a/api/migrations/versions/2025_06_19_1359-6835b906335f_multiple_credential.py +++ /dev/null @@ -1,33 +0,0 @@ -"""multiple credential - -Revision ID: 6835b906335f -Revises: e315d2a83984 -Create Date: 2025-06-19 13:59:58.107955 - -""" -from alembic import op -import models as models -import sqlalchemy as sa - - -# revision identifiers, used by Alembic. -revision = '6835b906335f' -down_revision = 'e315d2a83984' -branch_labels = None -depends_on = None - - -def upgrade(): - # ### commands auto generated by Alembic - please adjust! ### - with op.batch_alter_table('tool_builtin_providers', schema=None) as batch_op: - batch_op.drop_constraint(batch_op.f('unique_builtin_tool_provider'), type_='unique') - - # ### end Alembic commands ### - - -def downgrade(): - # ### commands auto generated by Alembic - please adjust! ### - with op.batch_alter_table('tool_builtin_providers', schema=None) as batch_op: - batch_op.create_unique_constraint(batch_op.f('unique_builtin_tool_provider'), ['provider', 'tenant_id', 'default']) - - # ### end Alembic commands ### diff --git a/api/migrations/versions/2025_06_19_1359-e315d2a83984_multiple_credential.py b/api/migrations/versions/2025_06_19_1359-e315d2a83984_multiple_credential.py deleted file mode 100644 index 2f0caeaf0d..0000000000 --- a/api/migrations/versions/2025_06_19_1359-e315d2a83984_multiple_credential.py +++ /dev/null @@ -1,33 +0,0 @@ -"""multiple credential - -Revision ID: e315d2a83984 -Revises: a9306e69af07 -Create Date: 2025-06-19 13:59:13.860523 - -""" -from alembic import op -import models as models -import sqlalchemy as sa - - -# revision identifiers, used by Alembic. -revision = 'e315d2a83984' -down_revision = 'a9306e69af07' -branch_labels = None -depends_on = None - - -def upgrade(): - # ### commands auto generated by Alembic - please adjust! ### - with op.batch_alter_table('tool_api_providers', schema=None) as batch_op: - batch_op.drop_constraint(batch_op.f('unique_api_tool_provider'), type_='unique') - - # ### end Alembic commands ### - - -def downgrade(): - # ### commands auto generated by Alembic - please adjust! ### - with op.batch_alter_table('tool_api_providers', schema=None) as batch_op: - batch_op.create_unique_constraint(batch_op.f('unique_api_tool_provider'), ['name', 'tenant_id']) - - # ### end Alembic commands ### diff --git a/api/migrations/versions/2025_06_19_1511-110e30078dd3_multiple_credential.py b/api/migrations/versions/2025_06_19_1511-110e30078dd3_multiple_credential.py deleted file mode 100644 index 84a5461a4d..0000000000 --- a/api/migrations/versions/2025_06_19_1511-110e30078dd3_multiple_credential.py +++ /dev/null @@ -1,53 +0,0 @@ -"""multiple credential - -Revision ID: 110e30078dd3 -Revises: 6835b906335f -Create Date: 2025-06-19 15:11:42.688478 - -""" -from alembic import op -import models as models -import sqlalchemy as sa - - -# revision identifiers, used by Alembic. -revision = '110e30078dd3' -down_revision = '6835b906335f' -branch_labels = None -depends_on = None - - -def upgrade(): - # ### commands auto generated by Alembic - please adjust! ### - with op.batch_alter_table('tool_oauth_system_clients', schema=None) as batch_op: - batch_op.alter_column('plugin_id', - existing_type=sa.UUID(), - type_=sa.String(length=512), - existing_nullable=False) - - with op.batch_alter_table('tool_oauth_user_clients', schema=None) as batch_op: - batch_op.add_column(sa.Column('enabled', sa.Boolean(), server_default=sa.text('true'), nullable=False)) - batch_op.alter_column('plugin_id', - existing_type=sa.UUID(), - type_=sa.String(length=512), - existing_nullable=False) - - # ### end Alembic commands ### - - -def downgrade(): - # ### commands auto generated by Alembic - please adjust! ### - with op.batch_alter_table('tool_oauth_user_clients', schema=None) as batch_op: - batch_op.alter_column('plugin_id', - existing_type=sa.String(length=512), - type_=sa.UUID(), - existing_nullable=False) - batch_op.drop_column('enabled') - - with op.batch_alter_table('tool_oauth_system_clients', schema=None) as batch_op: - batch_op.alter_column('plugin_id', - existing_type=sa.String(length=512), - type_=sa.UUID(), - existing_nullable=False) - - # ### end Alembic commands ### diff --git a/api/migrations/versions/2025_06_18_1506-99310d2c25a6_add_tool_oauth_credentials.py b/api/migrations/versions/2025_06_24_1705-71f5020c6470_tool_oauth.py similarity index 54% rename from api/migrations/versions/2025_06_18_1506-99310d2c25a6_add_tool_oauth_credentials.py rename to api/migrations/versions/2025_06_24_1705-71f5020c6470_tool_oauth.py index 95e74571d5..ffb4fffe56 100644 --- a/api/migrations/versions/2025_06_18_1506-99310d2c25a6_add_tool_oauth_credentials.py +++ b/api/migrations/versions/2025_06_24_1705-71f5020c6470_tool_oauth.py @@ -1,8 +1,8 @@ -"""add tool oauth credentials +"""tool oauth -Revision ID: 99310d2c25a6 +Revision ID: 71f5020c6470 Revises: 4474872b0ee6 -Create Date: 2025-06-18 15:06:15.261915 +Create Date: 2025-06-24 17:05:43.118647 """ from alembic import op @@ -11,7 +11,7 @@ import sqlalchemy as sa # revision identifiers, used by Alembic. -revision = '99310d2c25a6' +revision = '71f5020c6470' down_revision = '4474872b0ee6' branch_labels = None depends_on = None @@ -21,30 +21,30 @@ def upgrade(): # ### commands auto generated by Alembic - please adjust! ### op.create_table('tool_oauth_system_clients', sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), - sa.Column('plugin_id', models.types.StringUUID(), nullable=False), + sa.Column('plugin_id', sa.String(length=512), nullable=False), sa.Column('provider', sa.String(length=255), nullable=False), sa.Column('encrypted_oauth_params', sa.Text(), nullable=False), sa.PrimaryKeyConstraint('id', name='tool_oauth_system_client_pkey'), sa.UniqueConstraint('plugin_id', 'provider', name='tool_oauth_system_client_plugin_id_provider_idx') ) - op.create_table('tool_oauth_user_clients', + op.create_table('tool_oauth_tenant_clients', sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), sa.Column('tenant_id', models.types.StringUUID(), nullable=False), - sa.Column('plugin_id', models.types.StringUUID(), nullable=False), + sa.Column('plugin_id', sa.String(length=512), nullable=False), sa.Column('provider', sa.String(length=255), nullable=False), + sa.Column('enabled', sa.Boolean(), server_default=sa.text('true'), nullable=False), sa.Column('encrypted_oauth_params', sa.Text(), nullable=False), - sa.PrimaryKeyConstraint('id', name='tool_oauth_user_client_pkey'), - sa.UniqueConstraint('tenant_id', 'plugin_id', 'provider', name='unique_tool_oauth_user_client') + sa.PrimaryKeyConstraint('id', name='tool_oauth_tenant_client_pkey'), + sa.UniqueConstraint('tenant_id', 'plugin_id', 'provider', name='unique_tool_oauth_tenant_client') ) + with op.batch_alter_table('tool_api_providers', schema=None) as batch_op: + batch_op.drop_constraint(batch_op.f('unique_api_tool_provider'), type_='unique') + with op.batch_alter_table('tool_builtin_providers', schema=None) as batch_op: - batch_op.add_column(sa.Column('default', sa.Boolean(), server_default=sa.text('false'), nullable=False)) - batch_op.alter_column('credential_type', - existing_type=sa.VARCHAR(length=255), - type_=sa.String(length=32), - existing_nullable=False, - existing_server_default=sa.text("'api_key'::character varying")) + batch_op.add_column(sa.Column('name', sa.String(length=256), server_default=sa.text("'API KEY 1'::character varying"), nullable=False)) + batch_op.add_column(sa.Column('is_default', sa.Boolean(), server_default=sa.text('false'), nullable=False)) + batch_op.add_column(sa.Column('credential_type', sa.String(length=32), server_default=sa.text("'api_key'::character varying"), nullable=False)) batch_op.drop_constraint(batch_op.f('unique_builtin_tool_provider'), type_='unique') - batch_op.create_unique_constraint('unique_builtin_tool_provider', ['tenant_id', 'provider', 'credential_type']) # ### end Alembic commands ### @@ -52,15 +52,14 @@ def upgrade(): def downgrade(): # ### commands auto generated by Alembic - please adjust! ### with op.batch_alter_table('tool_builtin_providers', schema=None) as batch_op: - batch_op.drop_constraint('unique_builtin_tool_provider', type_='unique') batch_op.create_unique_constraint(batch_op.f('unique_builtin_tool_provider'), ['tenant_id', 'provider']) - batch_op.alter_column('credential_type', - existing_type=sa.String(length=32), - type_=sa.VARCHAR(length=255), - existing_nullable=False, - existing_server_default=sa.text("'api_key'::character varying")) - batch_op.drop_column('default') + batch_op.drop_column('credential_type') + batch_op.drop_column('is_default') + batch_op.drop_column('name') + + with op.batch_alter_table('tool_api_providers', schema=None) as batch_op: + batch_op.create_unique_constraint(batch_op.f('unique_api_tool_provider'), ['name', 'tenant_id']) - op.drop_table('tool_oauth_user_clients') + op.drop_table('tool_oauth_tenant_clients') op.drop_table('tool_oauth_system_clients') # ### end Alembic commands ### diff --git a/api/models/tools.py b/api/models/tools.py index 9e50cec52f..b2979a69dc 100644 --- a/api/models/tools.py +++ b/api/models/tools.py @@ -1,4 +1,3 @@ -import enum import json from datetime import datetime from typing import Any, cast @@ -18,25 +17,6 @@ from .model import Account, App, Tenant from .types import StringUUID -class ToolProviderCredentialType(enum.StrEnum): - API_KEY = "api_key" - OAUTH2 = "oauth2" - - def get_name(self): - return self.value.replace("_", " ").upper() - - def is_editable(self): - return self == ToolProviderCredentialType.API_KEY - - @classmethod - def get_credential_type(cls, credential_type: str) -> "ToolProviderCredentialType": - if credential_type == "api_key": - return cls.API_KEY - elif credential_type == "oauth2": - return cls.OAUTH2 - else: - raise ValueError(f"Invalid credential type: {credential_type}") - # system level tool oauth client params (client_id, client_secret, etc.) class ToolOAuthSystemClient(Base): __tablename__ = "tool_oauth_system_clients" @@ -48,8 +28,6 @@ class ToolOAuthSystemClient(Base): id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) plugin_id: Mapped[str] = mapped_column(db.String(512), nullable=False) provider: Mapped[str] = mapped_column(db.String(255), nullable=False) - # owner type, e.g., "system", "user" - # oauth params of the tool provider encrypted_oauth_params: Mapped[str] = mapped_column(db.Text, nullable=False) @@ -58,12 +36,12 @@ class ToolOAuthSystemClient(Base): return cast(dict, json.loads(self.encrypted_oauth_params)) -# user level tool oauth client params (client_id, client_secret, etc.) -class ToolOAuthUserClient(Base): - __tablename__ = "tool_oauth_user_clients" +# tenant level tool oauth client params (client_id, client_secret, etc.) +class ToolOAuthTenantClient(Base): + __tablename__ = "tool_oauth_tenant_clients" __table_args__ = ( - db.PrimaryKeyConstraint("id", name="tool_oauth_user_client_pkey"), - db.UniqueConstraint("tenant_id", "plugin_id", "provider", name="unique_tool_oauth_user_client"), + db.PrimaryKeyConstraint("id", name="tool_oauth_tenant_client_pkey"), + db.UniqueConstraint("tenant_id", "plugin_id", "provider", name="unique_tool_oauth_tenant_client"), ) id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) @@ -71,7 +49,6 @@ class ToolOAuthUserClient(Base): tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) plugin_id: Mapped[str] = mapped_column(db.String(512), nullable=False) provider: Mapped[str] = mapped_column(db.String(255), nullable=False) - owner_type: Mapped[str] = mapped_column(db.Text, nullable=False) enabled: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("true")) # oauth params of the tool provider encrypted_oauth_params: Mapped[str] = mapped_column(db.Text, nullable=False) @@ -80,19 +57,20 @@ class ToolOAuthUserClient(Base): def oauth_params(self) -> dict: return cast(dict, json.loads(self.encrypted_oauth_params)) + class BuiltinToolProvider(Base): """ This table stores the tool provider information for built-in tools for each tenant. """ __tablename__ = "tool_builtin_providers" - __table_args__ = ( - db.PrimaryKeyConstraint("id", name="tool_builtin_provider_pkey"), - ) + __table_args__ = (db.PrimaryKeyConstraint("id", name="tool_builtin_provider_pkey"),) # id of the tool provider id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) - name: Mapped[str] = mapped_column(db.String(256), nullable=False) + name: Mapped[str] = mapped_column( + db.String(256), nullable=False, server_default=db.text("'API KEY 1'::character varying") + ) # id of the tenant tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=True) # who created this tool provider @@ -107,11 +85,11 @@ class BuiltinToolProvider(Base): updated_at: Mapped[datetime] = mapped_column( db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)") ) - default: Mapped[bool] = mapped_column( - db.Boolean, nullable=False, server_default=db.text("false") - ) + is_default: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("false")) # credential type, e.g., "api_key", "oauth2" - credential_type: Mapped[str] = mapped_column(db.String(32), nullable=False, server_default=db.text("'api_key'::character varying")) + credential_type: Mapped[str] = mapped_column( + db.String(32), nullable=False, server_default=db.text("'api_key'::character varying") + ) @property def credentials(self) -> dict: @@ -124,13 +102,11 @@ class ApiToolProvider(Base): """ __tablename__ = "tool_api_providers" - __table_args__ = ( - db.PrimaryKeyConstraint("id", name="tool_api_provider_pkey"), - ) + __table_args__ = (db.PrimaryKeyConstraint("id", name="tool_api_provider_pkey"),) id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) # name of the api provider - name = db.Column(db.String(255), nullable=False) + name = db.Column(db.String(255), nullable=False, server_default=db.text("'API KEY 1'::character varying")) # icon icon = db.Column(db.String(255), nullable=False) # original schema diff --git a/api/services/plugin/oauth_service.py b/api/services/plugin/oauth_service.py index dcc14a8fad..4d340e2396 100644 --- a/api/services/plugin/oauth_service.py +++ b/api/services/plugin/oauth_service.py @@ -23,7 +23,7 @@ class OAuthProxyService(BasePluginClient): is used to verify the state, ensuring the request's integrity and authenticity, and mitigating replay attacks. """ - seconds, microseconds = redis_client.time() + seconds, _ = redis_client.time() context_id = str(uuid.uuid4()) data = { "user_id": user_id, @@ -55,7 +55,7 @@ class OAuthProxyService(BasePluginClient): if not data: raise ValueError("context_id is invalid") # check if data is expired - seconds, microseconds = redis_client.time() + seconds, _ = redis_client.time() state = json.loads(data) if state.get("timestamp") < seconds - max_age: raise ValueError("context_id is expired") diff --git a/api/services/tools/builtin_tools_manage_service.py b/api/services/tools/builtin_tools_manage_service.py index 7dc3e4c0f8..6728a19391 100644 --- a/api/services/tools/builtin_tools_manage_service.py +++ b/api/services/tools/builtin_tools_manage_service.py @@ -1,20 +1,26 @@ import json import logging +import re from pathlib import Path +from sqlalchemy import ColumnExpressionArgument +from sqlalchemy.orm import Session + from configs import dify_config from core.helper.position_helper import is_filtered from core.model_runtime.utils.encoders import jsonable_encoder -from core.plugin.entities.plugin import GenericProviderID, ToolProviderID +from core.plugin.entities.plugin import ToolProviderID from core.plugin.impl.exc import PluginDaemonClientSideError from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort -from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity +from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity, ToolProviderCredentialApiEntity +from core.tools.entities.tool_entities import ToolProviderCredentialType from core.tools.errors import ToolNotFoundError, ToolProviderCredentialValidationError, ToolProviderNotFoundError from core.tools.tool_label_manager import ToolLabelManager from core.tools.tool_manager import ToolManager from core.tools.utils.configuration import ProviderConfigEncrypter from extensions.ext_database import db -from models.tools import BuiltinToolProvider, ToolOAuthSystemClient, ToolOAuthUserClient, ToolProviderCredentialType +from extensions.ext_redis import redis_client +from models.tools import BuiltinToolProvider, ToolOAuthSystemClient, ToolOAuthTenantClient from services.tools.tools_transform_service import ToolTransformService logger = logging.getLogger(__name__) @@ -107,7 +113,7 @@ class BuiltinToolManageService: @staticmethod def update_builtin_tool_provider( - user_id: str, tenant_id: str, provider_name:str, credentials: dict, credential_id: str, name: str | None = None + user_id: str, tenant_id: str, provider_name: str, credentials: dict, credential_id: str, name: str | None = None ): """ update builtin tool provider @@ -119,7 +125,7 @@ class BuiltinToolManageService: raise ValueError(f"you have not added provider {provider_name}") try: - if ToolProviderCredentialType.get_credential_type(provider.credential_type).is_editable(): + if ToolProviderCredentialType.of(provider.credential_type).is_editable(): provider_controller = ToolManager.get_builtin_provider(provider_name, tenant_id) if not provider_controller.need_credentials: raise ValueError(f"provider {provider_name} does not need credentials") @@ -132,18 +138,20 @@ class BuiltinToolManageService: ) # Decrypt and restore original credentials for masked values - credentials = BuiltinToolManageService._decrypt_and_restore_credentials( - provider_controller, tool_configuration, provider, credentials - ) + original_credentials = tool_configuration.decrypt(provider.credentials) + masked_credentials = tool_configuration.mask_tool_credentials(original_credentials) + + # check if the credential has changed, save the original credential + for name, value in credentials.items(): + if name in masked_credentials and value == masked_credentials[name]: # type: ignore + credentials[name] = original_credentials[name] # type: ignore # Encrypt and save the credentials BuiltinToolManageService._encrypt_and_save_credentials( provider_controller, tool_configuration, provider, credentials, user_id ) else: - raise ValueError( - f"provider {provider_name} is not editable, you can only delete it and add a new one" - ) + raise ValueError(f"provider {provider_name} is not editable, you can only delete it and add a new one") # update name if provided if name is not None and provider.name != name: @@ -151,10 +159,10 @@ class BuiltinToolManageService: db.session.commit() except ( - PluginDaemonClientSideError, - ToolProviderNotFoundError, - ToolNotFoundError, - ToolProviderCredentialValidationError, + PluginDaemonClientSideError, + ToolProviderNotFoundError, + ToolNotFoundError, + ToolProviderCredentialValidationError, ) as e: raise ValueError(str(e)) @@ -162,94 +170,136 @@ class BuiltinToolManageService: @staticmethod def add_builtin_tool_provider( - user_id: str, type: ToolProviderCredentialType, tenant_id: str, provider_name:str, credentials: dict, name: str | None = None + user_id: str, + api_type: ToolProviderCredentialType, + tenant_id: str, + provider_name: str, + credentials: dict, + name: str | None = None, ): """ add builtin tool provider """ - if name is None: - name = BuiltinToolManageService.get_next_builtin_tool_provider_name(tenant_id, type) - - provider = BuiltinToolProvider( - tenant_id=tenant_id, - user_id=user_id, - provider=provider_name, - credential_type=type.value, - credentials=json.dumps(credentials), - name=name, - ) - - provider_controller = ToolManager.get_builtin_provider(provider_name, tenant_id) - if not provider_controller.need_credentials: - raise ValueError(f"provider {provider_name} does not need credentials") + lock_name = f"builtin_tool_provider_credential_lock_{tenant_id}_{provider_name}_{api_type.value}" + with redis_client.lock(lock_name, timeout=20): + if name is None: + name = BuiltinToolManageService.get_next_builtin_tool_provider_name(tenant_id, provider_name, api_type) + + provider = BuiltinToolProvider( + tenant_id=tenant_id, + user_id=user_id, + provider=provider_name, + encrypted_credentials=json.dumps(credentials), + credential_type=api_type.value, + name=name, + ) - tool_configuration = ProviderConfigEncrypter( - tenant_id=tenant_id, - config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()], - provider_type=provider_controller.provider_type.value, - provider_identity=provider_controller.entity.identity.name, - ) - - # Encrypt and save the credentials - BuiltinToolManageService._encrypt_and_save_credentials( - provider_controller, tool_configuration, provider, credentials, user_id - ) - db.session.add(provider) + provider_controller = ToolManager.get_builtin_provider(provider_name, tenant_id) + if not provider_controller.need_credentials: + raise ValueError(f"provider {provider_name} does not need credentials") + + tool_configuration = ProviderConfigEncrypter( + tenant_id=tenant_id, + config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()], + provider_type=provider_controller.provider_type.value, + provider_identity=provider_controller.entity.identity.name, + ) + + # Encrypt and save the credentials + BuiltinToolManageService._encrypt_and_save_credentials( + provider_controller, tool_configuration, provider, credentials, user_id + ) + db.session.add(provider) + db.session.commit() return {"result": "success"} @staticmethod - def get_next_builtin_tool_provider_name(tenant_id: str, type: ToolProviderCredentialType) -> str: - """ - next name = max(provider_names) + 1 - """ - provider_names = db.session.query(BuiltinToolProvider).filter_by( - tenant_id=tenant_id, - credential_type=type.value, - ).all() - if not provider_names: - return f"{type.value} 1" - # OAuth 1 then OAuth 2, if don't have OAuth 1, then return OAuth 1 - # if dont have number, then get name and add 1 - for provider_name in provider_names: - if provider_name.provider.startswith(type.value): - return f"{type.value} {int(provider_name.provider.split(' ')[1]) + 1}" - return f"{type.value} 1" + def get_next_builtin_tool_provider_name( + tenant_id: str, provider_name: str, type: ToolProviderCredentialType + ) -> str: + try: + providers = ( + db.session.query(BuiltinToolProvider) + .filter_by( + tenant_id=tenant_id, + provider=provider_name, + credential_type=type.value, + ) + .order_by(BuiltinToolProvider.created_at.desc()) + .limit(10) + .all() + ) + # Get the default name pattern + default_pattern = type.get_name() + + # Find all names that match the default pattern: "{default_pattern} {number}" + pattern = rf"^{re.escape(default_pattern)}\s+(\d+)$" + numbers = [] + + for provider in providers: + if provider.name: + match = re.match(pattern, provider.name.strip()) + if match: + numbers.append(int(match.group(1))) + + # If no default pattern names found, start with 1 + if not numbers: + return f"{default_pattern} 1" + + # Find the next number + max_number = max(numbers) + return f"{default_pattern} {max_number + 1}" + except Exception as e: + logger.warning(f"Error generating next provider name for {provider_name}: {str(e)}") + # fallback + return f"{type.get_name()} 1" @staticmethod - def get_builtin_tool_provider_credentials(tenant_id: str, provider_name: str): + def get_builtin_tool_provider_credentials( + tenant_id: str, provider_name: str + ) -> list[ToolProviderCredentialApiEntity]: """ get builtin tool provider credentials """ - provider_obj = BuiltinToolManageService._fetch_builtin_provider(provider_name, tenant_id) + providers = db.session.query(BuiltinToolProvider).filter_by(tenant_id=tenant_id, provider=provider_name).all() - if provider_obj is None: - return {} + if len(providers) == 0: + return [] - provider_controller = ToolManager.get_builtin_provider(provider_obj.provider, tenant_id) + provider_controller = ToolManager.get_builtin_provider(providers[0].provider, tenant_id) tool_configuration = ProviderConfigEncrypter( tenant_id=tenant_id, config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()], provider_type=provider_controller.provider_type.value, provider_identity=provider_controller.entity.identity.name, ) - credentials = tool_configuration.decrypt(provider_obj.credentials) - credentials = tool_configuration.mask_tool_credentials(credentials) + credentials: list[ToolProviderCredentialApiEntity] = [] + for provider in providers: + decrypt_credential = tool_configuration.mask_tool_credentials( + tool_configuration.decrypt(provider.credentials) + ) + credentials.append( + ToolTransformService.convert_builtin_provider_to_credential_api_entity( + provider=provider, + credentials=decrypt_credential, + ) + ) return credentials @staticmethod - def delete_builtin_tool_provider(user_id: str, tenant_id: str, provider_name: str): + def delete_builtin_tool_provider(tenant_id: str, provider_name: str, credential_id: str): """ delete tool provider """ - provider_obj = BuiltinToolManageService._fetch_builtin_provider(provider_name, tenant_id) + provider_obj = BuiltinToolManageService._fetch_builtin_provider_by_id(tenant_id, credential_id) if provider_obj is None: raise ValueError(f"you have not added provider {provider_name}") db.session.delete(provider_obj) db.session.commit() - + # delete cache provider_controller = ToolManager.get_builtin_provider(provider_name, tenant_id) tool_configuration = ProviderConfigEncrypter( @@ -267,70 +317,45 @@ class BuiltinToolManageService: """ set default provider """ - # get provider - target_provider = db.session.query(BuiltinToolProvider).filter_by(id=id).first() - if target_provider is None: - raise ValueError("provider not found") - - # clear default provider - db.session.query(BuiltinToolProvider).filter_by( - tenant_id=tenant_id, - user_id=user_id, - provider=provider, - default=True - ).update({"default": False}) - - # set new default provider - target_provider.default = True - db.session.commit() + with Session(db.engine) as session: + # get provider + target_provider = session.query(BuiltinToolProvider).filter_by(id=id).first() + if target_provider is None: + raise ValueError("provider not found") + + # clear default provider + session.query(BuiltinToolProvider).filter_by( + tenant_id=tenant_id, user_id=user_id, provider=provider, default=True + ).update({"default": False}) + + # set new default provider + target_provider.is_default = True + session.commit() return {"result": "success"} - @staticmethod - def fetch_default_provider(tenant_id: str, user_id: str, provider_name: str): - """ - fetch default provider - if there is no explicitly set default provider, return the oldest provider as default - """ - # 1. check if default provider exists - default_provider = db.session.query(BuiltinToolProvider).filter_by( - tenant_id=tenant_id, - user_id=user_id, - provider=provider_name, - default=True - ).first() - if default_provider: - return default_provider - - # 2. if no default provider, set the oldest provider as default - oldest_provider = (db.session.query(BuiltinToolProvider) - .filter_by(tenant_id=tenant_id, user_id=user_id, provider=provider_name) - .order_by(BuiltinToolProvider.created_at) - .first() - ) - if oldest_provider: - return oldest_provider - - raise ValueError(f"no default provider found for {provider_name}") - @staticmethod def get_builtin_tool_provider(tenant_id: str, user_id: str, provider: str, plugin_id: str): """ get builtin tool provider """ - user_client = db.session.query(ToolOAuthUserClient).filter_by( - tenant_id=tenant_id, - provider=provider, - plugin_id=plugin_id, - enabled=True, - ).first() - - if user_client: - plugin_oauth_config = user_client - else: - plugin_oauth_config = db.session.query(ToolOAuthSystemClient).filter_by(provider=provider).first() + with Session(db.engine) as session: + user_client = ( + session.query(ToolOAuthTenantClient) + .filter_by( + tenant_id=tenant_id, + provider=provider, + plugin_id=plugin_id, + enabled=True, + ) + .first() + ) + if user_client: + plugin_oauth_config = user_client + else: + plugin_oauth_config = session.query(ToolOAuthSystemClient).filter_by(provider=provider).first() - if plugin_oauth_config: - return plugin_oauth_config + if plugin_oauth_config: + return plugin_oauth_config raise ValueError("no oauth available config found for this plugin") @@ -408,73 +433,69 @@ class BuiltinToolManageService: @staticmethod def _fetch_builtin_provider_by_id(tenant_id: str, credential_id: str) -> BuiltinToolProvider | None: - provider = (db.session.query(BuiltinToolProvider) - .filter(BuiltinToolProvider.tenant_id == tenant_id, - BuiltinToolProvider.id == credential_id, - ) - .first()) + provider = ( + db.session.query(BuiltinToolProvider) + .filter( + BuiltinToolProvider.tenant_id == tenant_id, + BuiltinToolProvider.id == credential_id, + ) + .first() + ) return provider @staticmethod def _fetch_builtin_provider(provider_name: str, tenant_id: str) -> BuiltinToolProvider | None: + """ + This method is used to fetch the builtin provider from the database + 1.if the default provider exists, return the default provider + 2.if the default provider does not exist, return the oldest provider + """ + def _query(provider_filters: list[ColumnExpressionArgument[bool]]): + return ( + db.session.query(BuiltinToolProvider) + .filter(BuiltinToolProvider.tenant_id == tenant_id, *provider_filters) + .order_by( + BuiltinToolProvider.is_default.desc(), # default=True first + BuiltinToolProvider.created_at.asc(), # oldest first + ) + .first() + ) + try: full_provider_name = provider_name - provider_id_entity = GenericProviderID(provider_name) + provider_id_entity = ToolProviderID(provider_name) provider_name = provider_id_entity.provider_name + if provider_id_entity.organization != "langgenius": - provider_obj = ( - db.session.query(BuiltinToolProvider) - .filter( - BuiltinToolProvider.tenant_id == tenant_id, - BuiltinToolProvider.provider == full_provider_name, - ) - .first() - ) + provider = _query([BuiltinToolProvider.provider == full_provider_name]) else: - provider_obj = ( - db.session.query(BuiltinToolProvider) - .filter( - BuiltinToolProvider.tenant_id == tenant_id, + provider = _query( + [ (BuiltinToolProvider.provider == provider_name) - | (BuiltinToolProvider.provider == full_provider_name), - ) - .first() + | (BuiltinToolProvider.provider == full_provider_name) + ] ) - if provider_obj is None: + if provider is None: return None - provider_obj.provider = GenericProviderID(provider_obj.provider).to_string() - return provider_obj + provider.provider = ToolProviderID(provider.provider).to_string() + return provider except Exception: # it's an old provider without organization - return ( - db.session.query(BuiltinToolProvider) - .filter( - BuiltinToolProvider.tenant_id == tenant_id, - (BuiltinToolProvider.provider == provider_name), - ) - .first() - ) + provider_obj = _query([BuiltinToolProvider.provider == provider_name]) + return provider_obj @staticmethod - def _decrypt_and_restore_credentials(provider_controller, tool_configuration, provider, credentials): + def _decrypt_and_restore_credentials(tool_configuration, provider, credentials): """ Decrypt original credentials and restore masked values from the input credentials - :param provider_controller: the provider controller :param tool_configuration: the tool configuration encrypter :param provider: the provider object from database :param credentials: the input credentials from user :return: the processed credentials with original values restored """ - original_credentials = tool_configuration.decrypt(provider.credentials) - masked_credentials = tool_configuration.mask_tool_credentials(original_credentials) - - # check if the credential has changed, save the original credential - for name, value in credentials.items(): - if name in masked_credentials and value == masked_credentials[name]: # type: ignore - credentials[name] = original_credentials[name] # type: ignore return credentials @@ -489,8 +510,9 @@ class BuiltinToolManageService: :param credentials: the credentials to encrypt and save :param user_id: the user id for validation """ - # validate credentials - provider_controller.validate_credentials(user_id, credentials) + if ToolProviderCredentialType.of(provider.credential_type).is_validate_allowed(): + provider_controller.validate_credentials(user_id, credentials) + # encrypt credentials encrypted_credentials = tool_configuration.encrypt(credentials) provider.encrypted_credentials = json.dumps(encrypted_credentials) diff --git a/api/services/tools/tools_transform_service.py b/api/services/tools/tools_transform_service.py index 367121125b..b896f6c88f 100644 --- a/api/services/tools/tools_transform_service.py +++ b/api/services/tools/tools_transform_service.py @@ -9,12 +9,13 @@ from core.tools.__base.tool import Tool from core.tools.__base.tool_runtime import ToolRuntime from core.tools.builtin_tool.provider import BuiltinToolProviderController from core.tools.custom_tool.provider import ApiToolProviderController -from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity +from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity, ToolProviderCredentialApiEntity from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_bundle import ApiToolBundle from core.tools.entities.tool_entities import ( ApiProviderAuthType, ToolParameter, + ToolProviderCredentialType, ToolProviderType, ) from core.tools.plugin_tool.provider import PluginToolProviderController @@ -304,3 +305,16 @@ class ToolTransformService: parameters=tool.parameters, labels=labels or [], ) + + @staticmethod + def convert_builtin_provider_to_credential_api_entity( + provider: BuiltinToolProvider, credentials: dict + ) -> ToolProviderCredentialApiEntity: + return ToolProviderCredentialApiEntity( + id=provider.id, + name=provider.name, + provider=provider.provider, + credential_type=ToolProviderCredentialType.of(provider.credential_type), + is_default=provider.is_default, + credentials=credentials, + ) From fcfaa7ce13552660085ee3ee81780b30d5bac166 Mon Sep 17 00:00:00 2001 From: Harry Date: Wed, 25 Jun 2025 10:13:41 +0800 Subject: [PATCH 05/15] feat(oauth): plugin oauth service --- api/core/plugin/impl/oauth.py | 18 ++++++-- api/services/plugin/oauth_service.py | 63 ++++++++++++++++++++++++++-- 2 files changed, 73 insertions(+), 8 deletions(-) diff --git a/api/core/plugin/impl/oauth.py b/api/core/plugin/impl/oauth.py index 91774984c8..13873b6ba8 100644 --- a/api/core/plugin/impl/oauth.py +++ b/api/core/plugin/impl/oauth.py @@ -1,3 +1,4 @@ +import binascii from collections.abc import Mapping from typing import Any @@ -16,7 +17,7 @@ class OAuthHandler(BasePluginClient): provider: str, system_credentials: Mapping[str, Any], ) -> PluginOAuthAuthorizationUrlResponse: - return self._request_with_plugin_daemon_response( + response = self._request_with_plugin_daemon_response_stream( "POST", f"plugin/{tenant_id}/dispatch/oauth/get_authorization_url", PluginOAuthAuthorizationUrlResponse, @@ -32,6 +33,10 @@ class OAuthHandler(BasePluginClient): "Content-Type": "application/json", }, ) + for resp in response: + return resp + raise ValueError("No response received from plugin daemon for authorization URL request.") + def get_credentials( self, @@ -49,7 +54,7 @@ class OAuthHandler(BasePluginClient): # encode request to raw http request raw_request_bytes = self._convert_request_to_raw_data(request) - return self._request_with_plugin_daemon_response( + response = self._request_with_plugin_daemon_response_stream( "POST", f"plugin/{tenant_id}/dispatch/oauth/get_credentials", PluginOAuthCredentialsResponse, @@ -58,7 +63,8 @@ class OAuthHandler(BasePluginClient): "data": { "provider": provider, "system_credentials": system_credentials, - "raw_request_bytes": raw_request_bytes, + # for json serialization + "raw_http_request": binascii.hexlify(raw_request_bytes).decode(), }, }, headers={ @@ -66,6 +72,10 @@ class OAuthHandler(BasePluginClient): "Content-Type": "application/json", }, ) + for resp in response: + return resp + raise ValueError("No response received from plugin daemon for authorization URL request.") + def _convert_request_to_raw_data(self, request: Request) -> bytes: """ @@ -79,7 +89,7 @@ class OAuthHandler(BasePluginClient): """ # Start with the request line method = request.method - path = request.path + path = request.full_path protocol = request.headers.get("HTTP_VERSION", "HTTP/1.1") raw_data = f"{method} {path} {protocol}\r\n".encode() diff --git a/api/services/plugin/oauth_service.py b/api/services/plugin/oauth_service.py index 461247419b..28b955a3d5 100644 --- a/api/services/plugin/oauth_service.py +++ b/api/services/plugin/oauth_service.py @@ -1,7 +1,62 @@ +import json +import uuid + from core.plugin.impl.base import BasePluginClient +from extensions.ext_redis import redis_client + + +class OAuthProxyService(BasePluginClient): + # Default max age for proxy context parameter in seconds + __MAX_AGE__ = 5 * 60 # 5 minutes + + @staticmethod + def create_proxy_context(user_id, tenant_id, plugin_id, provider): + """ + Create a proxy context for an OAuth 2.0 authorization request. + + This parameter is a crucial security measure to prevent Cross-Site Request + Forgery (CSRF) attacks. It works by generating a unique nonce and storing it + in a distributed cache (Redis) along with the user's session context. + + The returned nonce should be included as the 'proxy_context' parameter in the + authorization URL. Upon callback, the `use_proxy_context` method + is used to verify the state, ensuring the request's integrity and authenticity, + and mitigating replay attacks. + """ + seconds, _ = redis_client.time() + context_id = str(uuid.uuid4()) + data = { + "user_id": user_id, + "plugin_id": plugin_id, + "tenant_id": tenant_id, + "provider": provider, + # encode redis time to avoid distribution time skew + "timestamp": seconds, + } + # ignore nonce collision + redis_client.setex( + f"oauth_proxy_context:{context_id}", + OAuthProxyService.__MAX_AGE__, + json.dumps(data), + ) + return context_id -class OAuthService(BasePluginClient): - @classmethod - def get_authorization_url(cls, tenant_id: str, user_id: str, provider_name: str) -> str: - return "1234567890" + @staticmethod + def use_proxy_context(context_id, max_age=__MAX_AGE__): + """ + Validate the proxy context parameter. + This checks if the context_id is valid and not expired. + """ + if not context_id: + raise ValueError("context_id is required") + # get data from redis + data = redis_client.getdel(f"oauth_proxy_context:{context_id}") + if not data: + raise ValueError("context_id is invalid") + # check if data is expired + seconds, _ = redis_client.time() + state = json.loads(data) + if state.get("timestamp") < seconds - max_age: + raise ValueError("context_id is expired") + return state From ce4cc54cc9af434dd925c2b1ce1669df0dc237d1 Mon Sep 17 00:00:00 2001 From: Harry Date: Wed, 25 Jun 2025 14:51:55 +0800 Subject: [PATCH 06/15] feat(oauth): merge tool oauth and remove sequence number branches --- .../console/workspace/tool_providers.py | 10 ++------ api/core/tools/tool_manager.py | 2 +- .../python/examples/github/provider/github.py | 2 +- ...c_merge_tool_oauth_and_remove_sequence_.py | 25 +++++++++++++++++++ .../tools/builtin_tools_manage_service.py | 5 ++-- 5 files changed, 32 insertions(+), 12 deletions(-) create mode 100644 api/migrations/versions/2025_06_25_1101-46d46b3f389c_merge_tool_oauth_and_remove_sequence_.py diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py index a4839fe8a1..c581a39200 100644 --- a/api/controllers/console/workspace/tool_providers.py +++ b/api/controllers/console/workspace/tool_providers.py @@ -676,14 +676,9 @@ class ToolPluginOAuthApi(Resource): if not user.is_admin_or_owner: raise Forbidden() - # check if user client is configured and enabled then using user client - # if user client is not configured then using system client tenant_id = user.current_tenant_id - user_id = user.id - - plugin_oauth_config = BuiltinToolManageService.get_builtin_tool_provider( + plugin_oauth_config = BuiltinToolManageService.get_builtin_tool_oauth_client( tenant_id=tenant_id, - user_id=user_id, provider=provider, plugin_id=plugin_id, ) @@ -727,9 +722,8 @@ class ToolOAuthCallback(Resource): context.get("provider"), ) oauth_handler = OAuthHandler() - plugin_oauth_config = BuiltinToolManageService.get_builtin_tool_provider( + plugin_oauth_config = BuiltinToolManageService.get_builtin_tool_oauth_client( tenant_id=tenant_id, - user_id=user_id, provider=provider, plugin_id=plugin_id, ) diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index f25267dbf6..86ffa01667 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -579,7 +579,7 @@ class ToolManager: if "builtin" in filters: def get_builtin_providers(tenant_id): - # according to multi credentials, select the one with is_default=True first, then created_at oldest + # according to multi credentials, select the one with is_default=True first, then created_at oldest # for compatibility with old version sql = """ SELECT DISTINCT ON (tenant_id, provider) id diff --git a/api/dify-plugin-sdks/python/examples/github/provider/github.py b/api/dify-plugin-sdks/python/examples/github/provider/github.py index 36f2f85910..7fb7bd33df 100644 --- a/api/dify-plugin-sdks/python/examples/github/provider/github.py +++ b/api/dify-plugin-sdks/python/examples/github/provider/github.py @@ -64,4 +64,4 @@ class GithubProvider(ToolProvider): if response.status_code != 200: raise ToolProviderCredentialValidationError(response.json().get("message")) except Exception as e: - raise ToolProviderCredentialValidationError(str(e)) \ No newline at end of file + raise ToolProviderCredentialValidationError(str(e)) diff --git a/api/migrations/versions/2025_06_25_1101-46d46b3f389c_merge_tool_oauth_and_remove_sequence_.py b/api/migrations/versions/2025_06_25_1101-46d46b3f389c_merge_tool_oauth_and_remove_sequence_.py new file mode 100644 index 0000000000..a3c51e7e75 --- /dev/null +++ b/api/migrations/versions/2025_06_25_1101-46d46b3f389c_merge_tool_oauth_and_remove_sequence_.py @@ -0,0 +1,25 @@ +"""merge tool oauth and remove sequence number branches + +Revision ID: 46d46b3f389c +Revises: 0ab65e1cc7fa, 71f5020c6470 +Create Date: 2025-06-25 11:01:55.215896 + +""" +from alembic import op +import models as models +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '46d46b3f389c' +down_revision = ('0ab65e1cc7fa', '71f5020c6470') +branch_labels = None +depends_on = None + + +def upgrade(): + pass + + +def downgrade(): + pass diff --git a/api/services/tools/builtin_tools_manage_service.py b/api/services/tools/builtin_tools_manage_service.py index 6728a19391..b4f043c647 100644 --- a/api/services/tools/builtin_tools_manage_service.py +++ b/api/services/tools/builtin_tools_manage_service.py @@ -299,7 +299,7 @@ class BuiltinToolManageService: db.session.delete(provider_obj) db.session.commit() - + # delete cache provider_controller = ToolManager.get_builtin_provider(provider_name, tenant_id) tool_configuration = ProviderConfigEncrypter( @@ -334,7 +334,7 @@ class BuiltinToolManageService: return {"result": "success"} @staticmethod - def get_builtin_tool_provider(tenant_id: str, user_id: str, provider: str, plugin_id: str): + def get_builtin_tool_oauth_client(tenant_id: str, provider: str, plugin_id: str): """ get builtin tool provider """ @@ -450,6 +450,7 @@ class BuiltinToolManageService: 1.if the default provider exists, return the default provider 2.if the default provider does not exist, return the oldest provider """ + def _query(provider_filters: list[ColumnExpressionArgument[bool]]): return ( db.session.query(BuiltinToolProvider) From ba843c26911954b08908bcd46e5f981d3be9e9ba Mon Sep 17 00:00:00 2001 From: Harry Date: Thu, 26 Jun 2025 11:44:00 +0800 Subject: [PATCH 07/15] feat(oauth): update api --- .../console/workspace/model_providers.py | 1 + .../console/workspace/tool_providers.py | 6 +- api/core/tools/tool_manager.py | 90 +++++------- .../python/examples/github/provider/github.py | 67 --------- .../tools/builtin_tools_manage_service.py | 128 ++++++------------ api/services/tools/tools_transform_service.py | 2 +- 6 files changed, 84 insertions(+), 210 deletions(-) delete mode 100644 api/dify-plugin-sdks/python/examples/github/provider/github.py diff --git a/api/controllers/console/workspace/model_providers.py b/api/controllers/console/workspace/model_providers.py index 32139781b0..ff0fcbda6e 100644 --- a/api/controllers/console/workspace/model_providers.py +++ b/api/controllers/console/workspace/model_providers.py @@ -35,6 +35,7 @@ class ModelProviderListApi(Resource): model_provider_service = ModelProviderService() provider_list = model_provider_service.get_provider_list(tenant_id=tenant_id, model_type=args.get("model_type")) + return jsonable_encoder({"data": provider_list}) diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py index c581a39200..ceea178214 100644 --- a/api/controllers/console/workspace/tool_providers.py +++ b/api/controllers/console/workspace/tool_providers.py @@ -371,12 +371,12 @@ class ToolBuiltinProviderCredentialsSchemaApi(Resource): @setup_required @login_required @account_initialization_required - def get(self, provider): + def get(self, provider, credential_type): user = current_user tenant_id = user.current_tenant_id - return BuiltinToolManageService.list_builtin_provider_credentials_schema(provider, tenant_id) + return BuiltinToolManageService.list_builtin_provider_credentials_schema(provider, credential_type, tenant_id) class ToolApiProviderSchemaApi(Resource): @@ -789,7 +789,7 @@ api.add_resource( ) api.add_resource( ToolBuiltinProviderCredentialsSchemaApi, - "/workspaces/current/tool-provider/builtin//credentials_schema", + "/workspaces/current/tool-provider/builtin///credentials_schema", ) api.add_resource(ToolBuiltinProviderIconApi, "/workspaces/current/tool-provider/builtin//icon") diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index 86ffa01667..bd4a635923 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -20,7 +20,6 @@ from core.tools.workflow_as_tool.provider import WorkflowToolProviderController if TYPE_CHECKING: from core.workflow.nodes.tool.entities import ToolEntity - from configs import dify_config from core.agent.entities import AgentToolEntity from core.app.entities.app_invoke_entities import InvokeFrom @@ -35,18 +34,10 @@ from core.tools.custom_tool.provider import ApiToolProviderController from core.tools.custom_tool.tool import ApiTool from core.tools.entities.api_entities import ToolProviderApiEntity, ToolProviderTypeApiLiteral from core.tools.entities.common_entities import I18nObject -from core.tools.entities.tool_entities import ( - ApiProviderAuthType, - ToolInvokeFrom, - ToolParameter, - ToolProviderType, -) -from core.tools.errors import ToolNotFoundError, ToolProviderNotFoundError +from core.tools.entities.tool_entities import ApiProviderAuthType, ToolInvokeFrom, ToolParameter, ToolProviderType +from core.tools.errors import ToolProviderNotFoundError from core.tools.tool_label_manager import ToolLabelManager -from core.tools.utils.configuration import ( - ProviderConfigEncrypter, - ToolParameterConfigurationManager, -) +from core.tools.utils.configuration import ProviderConfigEncrypter, ToolParameterConfigurationManager from core.tools.workflow_as_tool.tool import WorkflowTool from extensions.ext_database import db from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider @@ -64,8 +55,11 @@ class ToolManager: @classmethod def get_hardcoded_provider(cls, provider: str) -> BuiltinToolProviderController: """ + get the hardcoded provider + """ + if len(cls._hardcoded_providers) == 0: # init the builtin providers cls.load_hardcoded_providers_cache() @@ -109,7 +103,12 @@ class ToolManager: contexts.plugin_tool_providers.set({}) contexts.plugin_tool_providers_lock.set(Lock()) + plugin_tool_providers = contexts.plugin_tool_providers.get() + if provider in plugin_tool_providers: + return plugin_tool_providers[provider] + with contexts.plugin_tool_providers_lock.get(): + # double check plugin_tool_providers = contexts.plugin_tool_providers.get() if provider in plugin_tool_providers: return plugin_tool_providers[provider] @@ -127,25 +126,7 @@ class ToolManager: ) plugin_tool_providers[provider] = controller - - return controller - - @classmethod - def get_builtin_tool(cls, provider: str, tool_name: str, tenant_id: str) -> BuiltinTool | PluginTool | None: - """ - get the builtin tool - - :param provider: the name of the provider - :param tool_name: the name of the tool - :param tenant_id: the id of the tenant - :return: the provider, the tool - """ - provider_controller = cls.get_builtin_provider(provider, tenant_id) - tool = provider_controller.get_tool(tool_name) - if tool is None: - raise ToolNotFoundError(f"tool {tool_name} not found") - - return tool + return controller @classmethod def get_tool_runtime( @@ -563,6 +544,22 @@ class ToolManager: return cls._builtin_tools_labels[tool_name] + @classmethod + def list_default_builtin_providers(cls, tenant_id: str) -> list[BuiltinToolProvider]: + """ + list all the builtin providers + """ + # according to multi credentials, select the one with is_default=True first, then created_at oldest + # for compatibility with old version + sql = """ + SELECT DISTINCT ON (tenant_id, provider) id + FROM tool_builtin_providers + WHERE tenant_id = :tenant_id + ORDER BY tenant_id, provider, is_default DESC, created_at DESC + """ + ids = [row.id for row in db.session.execute(db.text(sql), {"tenant_id": tenant_id}).all()] + return db.session.query(BuiltinToolProvider).filter(BuiltinToolProvider.id.in_(ids)).all() + @classmethod def list_providers_from_api( cls, user_id: str, tenant_id: str, typ: ToolProviderTypeApiLiteral @@ -577,30 +574,13 @@ class ToolManager: with db.session.no_autoflush: if "builtin" in filters: - - def get_builtin_providers(tenant_id): - # according to multi credentials, select the one with is_default=True first, then created_at oldest - # for compatibility with old version - sql = """ - SELECT DISTINCT ON (tenant_id, provider) id - FROM tool_builtin_providers - WHERE tenant_id = :tenant_id - ORDER BY tenant_id, provider, is_default DESC, created_at DESC - """ - ids = [row.id for row in db.session.execute(db.text(sql), {"tenant_id": tenant_id}).all()] - return db.session.query(BuiltinToolProvider).filter(BuiltinToolProvider.id.in_(ids)).all() - builtin_providers = cls.list_builtin_providers(tenant_id) - # get builtin providers - db_builtin_providers = get_builtin_providers(tenant_id) - - # rewrite db_builtin_providers - for db_provider in db_builtin_providers: - db_provider.provider = str(ToolProviderID(db_provider.provider)) - - def find_db_builtin_provider(provider): - return next((x for x in db_builtin_providers if x.provider == provider), None) + # key: provider name, value: provider + db_builtin_providers = { + str(ToolProviderID(provider.provider)): provider + for provider in cls.list_default_builtin_providers(tenant_id) + } # append builtin providers for provider in builtin_providers: @@ -612,10 +592,9 @@ class ToolManager: name_func=lambda x: x.identity.name, ): continue - user_provider = ToolTransformService.builtin_provider_to_user_provider( provider_controller=provider, - db_provider=find_db_builtin_provider(provider.entity.identity.name), + db_provider=db_builtin_providers.get(provider.entity.identity.name), decrypt_credentials=False, ) @@ -625,7 +604,6 @@ class ToolManager: result_providers[f"builtin_provider.{user_provider.name}"] = user_provider # get db api providers - if "api" in filters: db_api_providers: list[ApiToolProvider] = ( db.session.query(ApiToolProvider).filter(ApiToolProvider.tenant_id == tenant_id).all() diff --git a/api/dify-plugin-sdks/python/examples/github/provider/github.py b/api/dify-plugin-sdks/python/examples/github/provider/github.py deleted file mode 100644 index 7fb7bd33df..0000000000 --- a/api/dify-plugin-sdks/python/examples/github/provider/github.py +++ /dev/null @@ -1,67 +0,0 @@ -import secrets -import urllib.parse -from collections.abc import Mapping -from typing import Any - -import requests -from dify_plugin import ToolProvider -from dify_plugin.errors.tool import ToolProviderCredentialValidationError -from werkzeug import Request - - -class GithubProvider(ToolProvider): - _AUTH_URL = "https://github.com/login/oauth/authorize" - _TOKEN_URL = "https://github.com/login/oauth/access_token" - _API_USER_URL = "https://api.github.com/user" - - def _oauth_get_authorization_url(self, system_credentials: Mapping[str, Any]) -> str: - """ - Generate the authorization URL for the Github OAuth. - """ - state = secrets.token_urlsafe(16) - params = { - "client_id": system_credentials["client_id"], - "redirect_uri": system_credentials["redirect_uri"], - "scope": system_credentials.get("scope", "read:user"), - "state": state, - # Optionally: allow_signup, login, etc. - } - return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}" - - def _oauth_get_credentials(self, system_credentials: Mapping[str, Any], request: Request) -> Mapping[str, Any]: - """ - Exchange code for access_token. - """ - code = request.args.get("code") - state = request.args.get("state") - if not code: - raise ValueError("No code provided") - # Optionally: validate state here - - data = { - "client_id": system_credentials["client_id"], - "client_secret": system_credentials["client_secret"], - "code": code, - "redirect_uri": system_credentials["redirect_uri"], - } - headers = {"Accept": "application/json"} - response = requests.post(self._TOKEN_URL, data=data, headers=headers, timeout=10) - response_json = response.json() - access_token = response_json.get("access_token") - if not access_token: - raise ValueError(f"Error in GitHub OAuth: {response_json}") - return {"access_token": access_token} - - def _validate_credentials(self, credentials: dict) -> None: - try: - if "access_token" not in credentials or not credentials.get("access_token"): - raise ToolProviderCredentialValidationError("GitHub API Access Token is required.") - headers = { - "Authorization": f"Bearer {credentials['access_token']}", - "Accept": "application/vnd.github+json", - } - response = requests.get(self._API_USER_URL, headers=headers, timeout=10) - if response.status_code != 200: - raise ToolProviderCredentialValidationError(response.json().get("message")) - except Exception as e: - raise ToolProviderCredentialValidationError(str(e)) diff --git a/api/services/tools/builtin_tools_manage_service.py b/api/services/tools/builtin_tools_manage_service.py index b4f043c647..0137e13b20 100644 --- a/api/services/tools/builtin_tools_manage_service.py +++ b/api/services/tools/builtin_tools_manage_service.py @@ -2,6 +2,7 @@ import json import logging import re from pathlib import Path +from typing import Optional, Union from sqlalchemy import ColumnExpressionArgument from sqlalchemy.orm import Session @@ -11,6 +12,7 @@ from core.helper.position_helper import is_filtered from core.model_runtime.utils.encoders import jsonable_encoder from core.plugin.entities.plugin import ToolProviderID from core.plugin.impl.exc import PluginDaemonClientSideError +from core.tools.__base.tool_provider import ToolProviderController from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity, ToolProviderCredentialApiEntity from core.tools.entities.tool_entities import ToolProviderCredentialType @@ -40,12 +42,7 @@ class BuiltinToolManageService: provider_controller = ToolManager.get_builtin_provider(provider, tenant_id) tools = provider_controller.get_tools() - tool_provider_configurations = ProviderConfigEncrypter( - tenant_id=tenant_id, - config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()], - provider_type=provider_controller.provider_type.value, - provider_identity=provider_controller.entity.identity.name, - ) + tool_configuration = BuiltinToolManageService._create_tool_configuration(tenant_id, provider_controller) # check if user has added the provider builtin_provider = BuiltinToolManageService._fetch_builtin_provider(provider, tenant_id) @@ -53,7 +50,7 @@ class BuiltinToolManageService: if builtin_provider is not None: # get credentials credentials = builtin_provider.credentials - credentials = tool_provider_configurations.decrypt(credentials) + credentials = tool_configuration.decrypt(credentials) result: list[ToolApiEntity] = [] for tool in tools or []: @@ -74,12 +71,7 @@ class BuiltinToolManageService: get builtin tool provider info """ provider_controller = ToolManager.get_builtin_provider(provider, tenant_id) - tool_provider_configurations = ProviderConfigEncrypter( - tenant_id=tenant_id, - config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()], - provider_type=provider_controller.provider_type.value, - provider_identity=provider_controller.entity.identity.name, - ) + tool_configuration = BuiltinToolManageService._create_tool_configuration(tenant_id, provider_controller) # check if user has added the provider builtin_provider = BuiltinToolManageService._fetch_builtin_provider(provider, tenant_id) @@ -87,7 +79,7 @@ class BuiltinToolManageService: if builtin_provider is not None: # get credentials credentials = builtin_provider.credentials - credentials = tool_provider_configurations.decrypt(credentials) + credentials = tool_configuration.decrypt(credentials) entity = ToolTransformService.builtin_provider_to_user_provider( provider_controller=provider_controller, @@ -100,7 +92,7 @@ class BuiltinToolManageService: return entity @staticmethod - def list_builtin_provider_credentials_schema(provider_name: str, tenant_id: str): + def list_builtin_provider_credentials_schema(provider_name: str, credential_type: str, tenant_id: str): """ list builtin provider credentials schema @@ -123,35 +115,28 @@ class BuiltinToolManageService: if provider is None: raise ValueError(f"you have not added provider {provider_name}") - + try: if ToolProviderCredentialType.of(provider.credential_type).is_editable(): provider_controller = ToolManager.get_builtin_provider(provider_name, tenant_id) if not provider_controller.need_credentials: raise ValueError(f"provider {provider_name} does not need credentials") - tool_configuration = ProviderConfigEncrypter( - tenant_id=tenant_id, - config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()], - provider_type=provider_controller.provider_type.value, - provider_identity=provider_controller.entity.identity.name, - ) + tool_configuration = BuiltinToolManageService._create_tool_configuration(tenant_id, provider_controller) # Decrypt and restore original credentials for masked values original_credentials = tool_configuration.decrypt(provider.credentials) masked_credentials = tool_configuration.mask_tool_credentials(original_credentials) # check if the credential has changed, save the original credential - for name, value in credentials.items(): - if name in masked_credentials and value == masked_credentials[name]: # type: ignore - credentials[name] = original_credentials[name] # type: ignore + for key, value in credentials.items(): + if key in masked_credentials and value == masked_credentials[key]: + credentials[key] = original_credentials[key] # Encrypt and save the credentials BuiltinToolManageService._encrypt_and_save_credentials( provider_controller, tool_configuration, provider, credentials, user_id ) - else: - raise ValueError(f"provider {provider_name} is not editable, you can only delete it and add a new one") # update name if provided if name is not None and provider.name != name: @@ -180,8 +165,8 @@ class BuiltinToolManageService: """ add builtin tool provider """ - lock_name = f"builtin_tool_provider_credential_lock_{tenant_id}_{provider_name}_{api_type.value}" - with redis_client.lock(lock_name, timeout=20): + lock = f"builtin_tool_provider_create_lock:{tenant_id}_{provider_name}" + with redis_client.lock(lock, timeout=20): if name is None: name = BuiltinToolManageService.get_next_builtin_tool_provider_name(tenant_id, provider_name, api_type) @@ -198,12 +183,7 @@ class BuiltinToolManageService: if not provider_controller.need_credentials: raise ValueError(f"provider {provider_name} does not need credentials") - tool_configuration = ProviderConfigEncrypter( - tenant_id=tenant_id, - config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()], - provider_type=provider_controller.provider_type.value, - provider_identity=provider_controller.entity.identity.name, - ) + tool_configuration = BuiltinToolManageService._create_tool_configuration(tenant_id, provider_controller) # Encrypt and save the credentials BuiltinToolManageService._encrypt_and_save_credentials( @@ -268,23 +248,17 @@ class BuiltinToolManageService: return [] provider_controller = ToolManager.get_builtin_provider(providers[0].provider, tenant_id) - tool_configuration = ProviderConfigEncrypter( - tenant_id=tenant_id, - config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()], - provider_type=provider_controller.provider_type.value, - provider_identity=provider_controller.entity.identity.name, - ) + tool_configuration = BuiltinToolManageService._create_tool_configuration(tenant_id, provider_controller) credentials: list[ToolProviderCredentialApiEntity] = [] for provider in providers: decrypt_credential = tool_configuration.mask_tool_credentials( tool_configuration.decrypt(provider.credentials) ) - credentials.append( - ToolTransformService.convert_builtin_provider_to_credential_api_entity( - provider=provider, - credentials=decrypt_credential, - ) + credential_entity = ToolTransformService.convert_builtin_provider_to_credential_entity( + provider=provider, + credentials=decrypt_credential, ) + credentials.append(credential_entity) return credentials @staticmethod @@ -292,22 +266,17 @@ class BuiltinToolManageService: """ delete tool provider """ - provider_obj = BuiltinToolManageService._fetch_builtin_provider_by_id(tenant_id, credential_id) + tool_provider = BuiltinToolManageService._fetch_builtin_provider_by_id(tenant_id, credential_id) - if provider_obj is None: + if tool_provider is None: raise ValueError(f"you have not added provider {provider_name}") - db.session.delete(provider_obj) + db.session.delete(tool_provider) db.session.commit() # delete cache provider_controller = ToolManager.get_builtin_provider(provider_name, tenant_id) - tool_configuration = ProviderConfigEncrypter( - tenant_id=tenant_id, - config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()], - provider_type=provider_controller.provider_type.value, - provider_identity=provider_controller.entity.identity.name, - ) + tool_configuration = BuiltinToolManageService._create_tool_configuration(tenant_id, provider_controller) tool_configuration.delete_tool_credentials_cache() return {"result": "success"} @@ -334,7 +303,9 @@ class BuiltinToolManageService: return {"result": "success"} @staticmethod - def get_builtin_tool_oauth_client(tenant_id: str, provider: str, plugin_id: str): + def get_builtin_tool_oauth_client( + tenant_id: str, provider: str, plugin_id: str + ) -> Union[ToolOAuthTenantClient, ToolOAuthSystemClient]: """ get builtin tool provider """ @@ -350,14 +321,12 @@ class BuiltinToolManageService: .first() ) if user_client: - plugin_oauth_config = user_client - else: - plugin_oauth_config = session.query(ToolOAuthSystemClient).filter_by(provider=provider).first() - - if plugin_oauth_config: - return plugin_oauth_config + return user_client - raise ValueError("no oauth available config found for this plugin") + system_client = session.query(ToolOAuthSystemClient).filter_by(provider=provider).first() + if system_client is None: + raise ValueError("no oauth available client config found for this tool provider") + return system_client @staticmethod def get_builtin_tool_provider_icon(provider: str): @@ -379,9 +348,7 @@ class BuiltinToolManageService: with db.session.no_autoflush: # get all user added providers - db_providers: list[BuiltinToolProvider] = ( - db.session.query(BuiltinToolProvider).filter(BuiltinToolProvider.tenant_id == tenant_id).all() or [] - ) + db_providers: list[BuiltinToolProvider] = ToolManager.list_default_builtin_providers(tenant_id) # rewrite db_providers for db_provider in db_providers: @@ -432,8 +399,8 @@ class BuiltinToolManageService: return BuiltinToolProviderSort.sort(result) @staticmethod - def _fetch_builtin_provider_by_id(tenant_id: str, credential_id: str) -> BuiltinToolProvider | None: - provider = ( + def _fetch_builtin_provider_by_id(tenant_id: str, credential_id: str) -> Optional[BuiltinToolProvider]: + provider: Optional[BuiltinToolProvider] = ( db.session.query(BuiltinToolProvider) .filter( BuiltinToolProvider.tenant_id == tenant_id, @@ -444,14 +411,14 @@ class BuiltinToolManageService: return provider @staticmethod - def _fetch_builtin_provider(provider_name: str, tenant_id: str) -> BuiltinToolProvider | None: + def _fetch_builtin_provider(provider_name: str, tenant_id: str) -> Optional[BuiltinToolProvider]: """ This method is used to fetch the builtin provider from the database 1.if the default provider exists, return the default provider 2.if the default provider does not exist, return the oldest provider """ - def _query(provider_filters: list[ColumnExpressionArgument[bool]]): + def _query(provider_filters: list[ColumnExpressionArgument[bool]]) -> Optional[BuiltinToolProvider]: return ( db.session.query(BuiltinToolProvider) .filter(BuiltinToolProvider.tenant_id == tenant_id, *provider_filters) @@ -484,21 +451,16 @@ class BuiltinToolManageService: return provider except Exception: # it's an old provider without organization - provider_obj = _query([BuiltinToolProvider.provider == provider_name]) - return provider_obj + return _query([BuiltinToolProvider.provider == provider_name]) @staticmethod - def _decrypt_and_restore_credentials(tool_configuration, provider, credentials): - """ - Decrypt original credentials and restore masked values from the input credentials - - :param tool_configuration: the tool configuration encrypter - :param provider: the provider object from database - :param credentials: the input credentials from user - :return: the processed credentials with original values restored - """ - - return credentials + def _create_tool_configuration(tenant_id: str, provider_controller: ToolProviderController): + return ProviderConfigEncrypter( + tenant_id=tenant_id, + config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()], + provider_type=provider_controller.provider_type.value, + provider_identity=provider_controller.entity.identity.name, + ) @staticmethod def _encrypt_and_save_credentials(provider_controller, tool_configuration, provider, credentials, user_id): diff --git a/api/services/tools/tools_transform_service.py b/api/services/tools/tools_transform_service.py index b896f6c88f..66be67dbe6 100644 --- a/api/services/tools/tools_transform_service.py +++ b/api/services/tools/tools_transform_service.py @@ -307,7 +307,7 @@ class ToolTransformService: ) @staticmethod - def convert_builtin_provider_to_credential_api_entity( + def convert_builtin_provider_to_credential_entity( provider: BuiltinToolProvider, credentials: dict ) -> ToolProviderCredentialApiEntity: return ToolProviderCredentialApiEntity( From f4f6e41074fcf9b6071b5927065cbd22868e8515 Mon Sep 17 00:00:00 2001 From: Harry Date: Thu, 26 Jun 2025 13:27:34 +0800 Subject: [PATCH 08/15] feat(oauth): add oauth redirect_uri parameters --- api/core/plugin/impl/oauth.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/api/core/plugin/impl/oauth.py b/api/core/plugin/impl/oauth.py index b006bf1d4b..4338c9cf1f 100644 --- a/api/core/plugin/impl/oauth.py +++ b/api/core/plugin/impl/oauth.py @@ -15,6 +15,7 @@ class OAuthHandler(BasePluginClient): user_id: str, plugin_id: str, provider: str, + redirect_uri: str, system_credentials: Mapping[str, Any], ) -> PluginOAuthAuthorizationUrlResponse: response = self._request_with_plugin_daemon_response_stream( @@ -25,6 +26,7 @@ class OAuthHandler(BasePluginClient): "user_id": user_id, "data": { "provider": provider, + "redirect_uri": redirect_uri, "system_credentials": system_credentials, }, }, @@ -43,6 +45,7 @@ class OAuthHandler(BasePluginClient): user_id: str, plugin_id: str, provider: str, + redirect_uri: str, system_credentials: Mapping[str, Any], request: Request, ) -> PluginOAuthCredentialsResponse: @@ -61,6 +64,7 @@ class OAuthHandler(BasePluginClient): "user_id": user_id, "data": { "provider": provider, + "redirect_uri": redirect_uri, "system_credentials": system_credentials, # for json serialization "raw_http_request": binascii.hexlify(raw_request_bytes).decode(), From daec82bd44e5ff57348445cd8dca591f6a1e804e Mon Sep 17 00:00:00 2001 From: Harry Date: Fri, 27 Jun 2025 13:17:09 +0800 Subject: [PATCH 09/15] feat(oauth): refactor tool provider methods and enhance credential handling --- .../console/workspace/tool_providers.py | 138 +++++---- api/core/tools/builtin_tool/provider.py | 30 +- api/core/tools/entities/tool_entities.py | 14 +- api/core/tools/tool_manager.py | 15 +- api/models/tools.py | 9 +- api/services/plugin/oauth_service.py | 3 + .../tools/api_tools_manage_service.py | 4 +- .../tools/builtin_tools_manage_service.py | 261 +++++++++++------- api/services/tools/tools_transform_service.py | 3 +- 9 files changed, 308 insertions(+), 169 deletions(-) diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py index ceea178214..5da20c3d29 100644 --- a/api/controllers/console/workspace/tool_providers.py +++ b/api/controllers/console/workspace/tool_providers.py @@ -1,6 +1,6 @@ import io -from flask import redirect, request, send_file +from flask import make_response, redirect, request, send_file from flask_login import current_user from flask_restful import ( Resource, @@ -17,6 +17,7 @@ from controllers.console.wraps import ( setup_required, ) from core.model_runtime.utils.encoders import jsonable_encoder +from core.plugin.entities.plugin import ToolProviderID from core.plugin.impl.oauth import OAuthHandler from core.tools.entities.tool_entities import ToolProviderCredentialType from extensions.ext_database import db @@ -127,7 +128,7 @@ class ToolBuiltinProviderAddApi(Resource): return BuiltinToolManageService.add_builtin_tool_provider( user_id=user_id, tenant_id=tenant_id, - provider_name=provider, + provider=provider, credentials=args["credentials"], name=args["name"], api_type=ToolProviderCredentialType.of(args["type"]), @@ -373,10 +374,11 @@ class ToolBuiltinProviderCredentialsSchemaApi(Resource): @account_initialization_required def get(self, provider, credential_type): user = current_user - tenant_id = user.current_tenant_id - return BuiltinToolManageService.list_builtin_provider_credentials_schema(provider, credential_type, tenant_id) + return BuiltinToolManageService.list_builtin_provider_credentials_schema( + provider, ToolProviderCredentialType.of(credential_type), tenant_id + ) class ToolApiProviderSchemaApi(Resource): @@ -613,15 +615,12 @@ class ToolApiListApi(Resource): @account_initialization_required def get(self): user = current_user - - user_id = user.id tenant_id = user.current_tenant_id return jsonable_encoder( [ provider.to_dict() for provider in ApiToolManageService.list_api_tools( - user_id, tenant_id, ) ] @@ -662,13 +661,10 @@ class ToolPluginOAuthApi(Resource): @setup_required @login_required @account_initialization_required - def get(self): - parser = reqparse.RequestParser() - parser.add_argument("provider", type=str, required=True, nullable=False, location="args") - parser.add_argument("plugin_id", type=str, required=True, nullable=False, location="args") - args = parser.parse_args() - provider = args["provider"] - plugin_id = args["plugin_id"] + def get(self, provider): + tool_provider = ToolProviderID(provider) + plugin_id = tool_provider.plugin_id + provider_name = tool_provider.provider_name # todo check permission user = current_user @@ -679,63 +675,66 @@ class ToolPluginOAuthApi(Resource): tenant_id = user.current_tenant_id plugin_oauth_config = BuiltinToolManageService.get_builtin_tool_oauth_client( tenant_id=tenant_id, - provider=provider, + provider=provider_name, plugin_id=plugin_id, ) oauth_handler = OAuthHandler() context_id = OAuthProxyService.create_proxy_context( - user_id=current_user.id, tenant_id=tenant_id, plugin_id=plugin_id, provider=provider + user_id=current_user.id, tenant_id=tenant_id, plugin_id=plugin_id, provider=provider_name ) - # todo decrypt oauth params + # TODO decrypt oauth params oauth_params = plugin_oauth_config.oauth_params - redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/tool/callback?context_id={context_id}" - oauth_params["redirect_uri"] = redirect_uri - - response = oauth_handler.get_authorization_url( - tenant_id, - user.id, - plugin_id, - provider, + redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider}/tool/callback" + authorization_url_response = oauth_handler.get_authorization_url( + tenant_id=tenant_id, + user_id=user.id, + plugin_id=plugin_id, + provider=provider_name, + redirect_uri=redirect_uri, system_credentials=oauth_params, ) - return response.model_dump() + response = make_response(jsonable_encoder(authorization_url_response)) + response.set_cookie( + "context_id", + context_id, + httponly=True, + samesite="Lax", + max_age=OAuthProxyService.__MAX_AGE__, + ) + return response class ToolOAuthCallback(Resource): @setup_required - def get(self): - args = ( - reqparse.RequestParser() - .add_argument("context_id", type=str, required=True, nullable=False, location="args") - .parse_args() - ) - context_id = args["context_id"] + def get(self, provider): + context_id = request.cookies.get("context_id") + if not context_id: + raise Forbidden("context_id not found") + context = OAuthProxyService.use_proxy_context(context_id) if context is None: raise Forbidden("Invalid context_id") - user_id, tenant_id, plugin_id, provider = ( - context.get("user_id"), - context.get("tenant_id"), - context.get("plugin_id"), - context.get("provider"), - ) + tool_provider = ToolProviderID(provider) + plugin_id = tool_provider.plugin_id + provider_name = tool_provider.provider_name + user_id, tenant_id = context.get("user_id"), context.get("tenant_id") + oauth_handler = OAuthHandler() plugin_oauth_config = BuiltinToolManageService.get_builtin_tool_oauth_client( tenant_id=tenant_id, - provider=provider, + provider=provider_name, plugin_id=plugin_id, ) oauth_params = plugin_oauth_config.oauth_params - redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/tool/callback?context_id={context_id}" - oauth_params["redirect_uri"] = redirect_uri - + redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider}/tool/callback" credentials = oauth_handler.get_credentials( - tenant_id, - user_id, - plugin_id, - provider, + tenant_id=tenant_id, + user_id=user_id, + plugin_id=plugin_id, + provider=provider_name, + redirect_uri=redirect_uri, system_credentials=oauth_params, request=request, ).credentials @@ -747,12 +746,11 @@ class ToolOAuthCallback(Resource): BuiltinToolManageService.add_builtin_tool_provider( user_id=user_id, tenant_id=tenant_id, - provider_name=provider, + provider=provider, credentials=dict(credentials), - name=provider, api_type=ToolProviderCredentialType.OAUTH2, ) - return redirect(f"{dify_config.CONSOLE_WEB_URL}") + return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth/plugin/{provider}/tool/success") class ToolBuiltinProviderSetDefaultApi(Resource): @@ -768,9 +766,41 @@ class ToolBuiltinProviderSetDefaultApi(Resource): ) +class ToolOAuthCustomClient(Resource): + @setup_required + @login_required + @account_initialization_required + def post(self, provider): + parser = reqparse.RequestParser() + parser.add_argument("client_params", type=dict, required=True, nullable=False, location="json") + args = parser.parse_args() + + user = current_user + + if not user.is_admin_or_owner: + raise Forbidden() + + return BuiltinToolManageService.setup_oauth_custom_client( + tenant_id=user.current_tenant_id, + user_id=user.id, + provider=provider, + client_params=args["client_params"], + ) + + @setup_required + @login_required + @account_initialization_required + def get(self, provider): + return BuiltinToolManageService.get_builtin_tool_provider_credentials( + tenant_id=current_user.current_tenant_id, provider_name=provider + ) + + # tool oauth -api.add_resource(ToolPluginOAuthApi, "/oauth/plugin/tool") -api.add_resource(ToolOAuthCallback, "/oauth/plugin/tool/callback") +api.add_resource(ToolPluginOAuthApi, "/oauth/plugin//tool/authorization-url") +api.add_resource(ToolOAuthCallback, "/oauth/plugin//tool/callback") + +api.add_resource(ToolOAuthCustomClient, "/workspaces/current/tool-provider/builtin//oauth/custom-client") # tool provider api.add_resource(ToolProviderListApi, "/workspaces/current/tool-providers") @@ -782,14 +812,14 @@ api.add_resource(ToolBuiltinProviderAddApi, "/workspaces/current/tool-provider/b api.add_resource(ToolBuiltinProviderDeleteApi, "/workspaces/current/tool-provider/builtin//delete") api.add_resource(ToolBuiltinProviderUpdateApi, "/workspaces/current/tool-provider/builtin//update") api.add_resource( - ToolBuiltinProviderSetDefaultApi, "/workspaces/current/tool-provider/builtin//set-default" + ToolBuiltinProviderSetDefaultApi, "/workspaces/current/tool-provider/builtin//default-credential" ) api.add_resource( ToolBuiltinProviderGetCredentialsApi, "/workspaces/current/tool-provider/builtin//credentials" ) api.add_resource( ToolBuiltinProviderCredentialsSchemaApi, - "/workspaces/current/tool-provider/builtin///credentials_schema", + "/workspaces/current/tool-provider/builtin//credentials_schema/", ) api.add_resource(ToolBuiltinProviderIconApi, "/workspaces/current/tool-provider/builtin//icon") diff --git a/api/core/tools/builtin_tool/provider.py b/api/core/tools/builtin_tool/provider.py index cf75bd3d7e..9e3c13849f 100644 --- a/api/core/tools/builtin_tool/provider.py +++ b/api/core/tools/builtin_tool/provider.py @@ -7,7 +7,13 @@ from core.helper.module_import_helper import load_single_subclass_from_source from core.tools.__base.tool_provider import ToolProviderController from core.tools.__base.tool_runtime import ToolRuntime from core.tools.builtin_tool.tool import BuiltinTool -from core.tools.entities.tool_entities import ToolEntity, ToolProviderEntity, ToolProviderType +from core.tools.entities.tool_entities import ( + OAuthSchema, + ToolEntity, + ToolProviderCredentialType, + ToolProviderEntity, + ToolProviderType, +) from core.tools.entities.values import ToolLabelEnum, default_tool_label_dict from core.tools.errors import ( ToolProviderNotFoundError, @@ -39,10 +45,18 @@ class BuiltinToolProviderController(ToolProviderController): credential_dict = provider_yaml.get("credentials_for_provider", {}).get(credential, {}) credentials_schema.append(credential_dict) + oauth_schema = None + if provider_yaml.get("oauth_schema", None) is not None: + oauth_schema = OAuthSchema( + client_schema=provider_yaml.get("oauth_schema", {}).get("client_schema", []), + credentials_schema=provider_yaml.get("oauth_schema", {}).get("credentials_schema", []), + ) + super().__init__( entity=ToolProviderEntity( identity=provider_yaml["identity"], credentials_schema=credentials_schema, + oauth_schema=oauth_schema, ), ) @@ -91,16 +105,20 @@ class BuiltinToolProviderController(ToolProviderController): """ return self.tools - def get_credentials_schema(self) -> list[ProviderConfig]: + def get_credentials_schema( + self, credential_type: ToolProviderCredentialType = ToolProviderCredentialType.API_KEY + ) -> list[ProviderConfig]: """ returns the credentials schema of the provider :return: the credentials schema """ - if not self.entity.credentials_schema: - return [] - - return self.entity.credentials_schema.copy() + if credential_type == ToolProviderCredentialType.OAUTH2: + return self.entity.oauth_schema.credentials_schema.copy() if self.entity.oauth_schema else [] + elif credential_type == ToolProviderCredentialType.API_KEY: + return self.entity.credentials_schema.copy() if self.entity.credentials_schema else [] + else: + raise ValueError(f"Invalid credential type: {credential_type}") def get_tools(self) -> list[BuiltinTool]: """ diff --git a/api/core/tools/entities/tool_entities.py b/api/core/tools/entities/tool_entities.py index 5094519b6f..922e30b2e0 100644 --- a/api/core/tools/entities/tool_entities.py +++ b/api/core/tools/entities/tool_entities.py @@ -344,10 +344,18 @@ class ToolEntity(BaseModel): return v or [] +class OAuthSchema(BaseModel): + client_schema: list[ProviderConfig] = Field(default_factory=list, description="The schema of the OAuth client") + credentials_schema: list[ProviderConfig] = Field( + default_factory=list, description="The schema of the OAuth credentials" + ) + + class ToolProviderEntity(BaseModel): identity: ToolProviderIdentity plugin_id: Optional[str] = None credentials_schema: list[ProviderConfig] = Field(default_factory=list) + oauth_schema: Optional[OAuthSchema] = None class ToolProviderEntityWithPlugin(ToolProviderEntity): @@ -437,7 +445,7 @@ class ToolSelector(BaseModel): class ToolProviderCredentialType(enum.StrEnum): - API_KEY = "api_key" + API_KEY = "api-key" OAUTH2 = "oauth2" def get_name(self): @@ -446,7 +454,7 @@ class ToolProviderCredentialType(enum.StrEnum): elif self == ToolProviderCredentialType.OAUTH2: return "AUTH" else: - return self.value.replace("_", " ").upper() + return self.value.replace("-", " ").upper() def is_editable(self): return self == ToolProviderCredentialType.API_KEY @@ -461,7 +469,7 @@ class ToolProviderCredentialType(enum.StrEnum): @classmethod def of(cls, credential_type: str) -> "ToolProviderCredentialType": type_name = credential_type.lower() - if type_name == "api_key": + if type_name == "api-key": return cls.API_KEY elif type_name == "oauth2": return cls.OAUTH2 diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index bd4a635923..35d4eb0c7e 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -34,7 +34,13 @@ from core.tools.custom_tool.provider import ApiToolProviderController from core.tools.custom_tool.tool import ApiTool from core.tools.entities.api_entities import ToolProviderApiEntity, ToolProviderTypeApiLiteral from core.tools.entities.common_entities import I18nObject -from core.tools.entities.tool_entities import ApiProviderAuthType, ToolInvokeFrom, ToolParameter, ToolProviderType +from core.tools.entities.tool_entities import ( + ApiProviderAuthType, + ToolInvokeFrom, + ToolParameter, + ToolProviderCredentialType, + ToolProviderType, +) from core.tools.errors import ToolProviderNotFoundError from core.tools.tool_label_manager import ToolLabelManager from core.tools.utils.configuration import ProviderConfigEncrypter, ToolParameterConfigurationManager @@ -202,7 +208,12 @@ class ToolManager: credentials = builtin_provider.credentials tool_configuration = ProviderConfigEncrypter( tenant_id=tenant_id, - config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()], + config=[ + x.to_basic_provider_config() + for x in provider_controller.get_credentials_schema( + ToolProviderCredentialType.of(builtin_provider.credential_type) + ) + ], provider_type=provider_controller.provider_type.value, provider_identity=provider_controller.entity.identity.name, ) diff --git a/api/models/tools.py b/api/models/tools.py index b2979a69dc..ef2f7bcdde 100644 --- a/api/models/tools.py +++ b/api/models/tools.py @@ -64,7 +64,10 @@ class BuiltinToolProvider(Base): """ __tablename__ = "tool_builtin_providers" - __table_args__ = (db.PrimaryKeyConstraint("id", name="tool_builtin_provider_pkey"),) + __table_args__ = ( + db.PrimaryKeyConstraint("id", name="tool_builtin_provider_pkey"), + db.UniqueConstraint("tenant_id", "provider", "name", name="unique_builtin_tool_provider"), + ) # id of the tool provider id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) @@ -86,9 +89,9 @@ class BuiltinToolProvider(Base): db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)") ) is_default: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("false")) - # credential type, e.g., "api_key", "oauth2" + # credential type, e.g., "api-key", "oauth2" credential_type: Mapped[str] = mapped_column( - db.String(32), nullable=False, server_default=db.text("'api_key'::character varying") + db.String(32), nullable=False, server_default=db.text("'api-key'::character varying") ) @property diff --git a/api/services/plugin/oauth_service.py b/api/services/plugin/oauth_service.py index 4ad3335ff6..b84dd0afc5 100644 --- a/api/services/plugin/oauth_service.py +++ b/api/services/plugin/oauth_service.py @@ -1,3 +1,6 @@ +import json +import uuid + from core.plugin.impl.base import BasePluginClient from extensions.ext_redis import redis_client diff --git a/api/services/tools/api_tools_manage_service.py b/api/services/tools/api_tools_manage_service.py index 6f848d49c4..b429851349 100644 --- a/api/services/tools/api_tools_manage_service.py +++ b/api/services/tools/api_tools_manage_service.py @@ -446,7 +446,7 @@ class ApiToolManageService: return {"result": result or "empty response"} @staticmethod - def list_api_tools(user_id: str, tenant_id: str) -> list[ToolProviderApiEntity]: + def list_api_tools(tenant_id: str) -> list[ToolProviderApiEntity]: """ list api tools """ @@ -474,7 +474,7 @@ class ApiToolManageService: for tool in tools or []: user_provider.tools.append( ToolTransformService.convert_tool_entity_to_api_entity( - tenant_id=tenant_id, tool=tool, credentials=user_provider.original_credentials, labels=labels + tenant_id=tenant_id, tool=tool, labels=labels ) ) diff --git a/api/services/tools/builtin_tools_manage_service.py b/api/services/tools/builtin_tools_manage_service.py index 0137e13b20..80ee9b080c 100644 --- a/api/services/tools/builtin_tools_manage_service.py +++ b/api/services/tools/builtin_tools_manage_service.py @@ -4,7 +4,6 @@ import re from pathlib import Path from typing import Optional, Union -from sqlalchemy import ColumnExpressionArgument from sqlalchemy.orm import Session from configs import dify_config @@ -13,10 +12,12 @@ from core.model_runtime.utils.encoders import jsonable_encoder from core.plugin.entities.plugin import ToolProviderID from core.plugin.impl.exc import PluginDaemonClientSideError from core.tools.__base.tool_provider import ToolProviderController +from core.tools.builtin_tool.provider import BuiltinToolProviderController from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity, ToolProviderCredentialApiEntity from core.tools.entities.tool_entities import ToolProviderCredentialType from core.tools.errors import ToolNotFoundError, ToolProviderCredentialValidationError, ToolProviderNotFoundError +from core.tools.plugin_tool.provider import PluginToolProviderController from core.tools.tool_label_manager import ToolLabelManager from core.tools.tool_manager import ToolManager from core.tools.utils.configuration import ProviderConfigEncrypter @@ -29,6 +30,8 @@ logger = logging.getLogger(__name__) class BuiltinToolManageService: + __MAX_BUILTIN_TOOL_PROVIDER_COUNT__ = 100 + @staticmethod def list_builtin_tool_provider_tools(tenant_id: str, provider: str) -> list[ToolApiEntity]: """ @@ -42,22 +45,11 @@ class BuiltinToolManageService: provider_controller = ToolManager.get_builtin_provider(provider, tenant_id) tools = provider_controller.get_tools() - tool_configuration = BuiltinToolManageService._create_tool_configuration(tenant_id, provider_controller) - # check if user has added the provider - builtin_provider = BuiltinToolManageService._fetch_builtin_provider(provider, tenant_id) - - credentials = {} - if builtin_provider is not None: - # get credentials - credentials = builtin_provider.credentials - credentials = tool_configuration.decrypt(credentials) - result: list[ToolApiEntity] = [] for tool in tools or []: result.append( ToolTransformService.convert_tool_entity_to_api_entity( tool=tool, - credentials=credentials, tenant_id=tenant_id, labels=ToolLabelManager.get_tool_labels(provider_controller), ) @@ -73,7 +65,7 @@ class BuiltinToolManageService: provider_controller = ToolManager.get_builtin_provider(provider, tenant_id) tool_configuration = BuiltinToolManageService._create_tool_configuration(tenant_id, provider_controller) # check if user has added the provider - builtin_provider = BuiltinToolManageService._fetch_builtin_provider(provider, tenant_id) + builtin_provider = BuiltinToolManageService.get_builtin_provider(provider, tenant_id) credentials = {} if builtin_provider is not None: @@ -92,16 +84,19 @@ class BuiltinToolManageService: return entity @staticmethod - def list_builtin_provider_credentials_schema(provider_name: str, credential_type: str, tenant_id: str): + def list_builtin_provider_credentials_schema( + provider_name: str, credential_type: ToolProviderCredentialType, tenant_id: str + ): """ list builtin provider credentials schema + :param credential_type: credential type :param provider_name: the name of the provider :param tenant_id: the id of the tenant :return: the list of tool providers """ provider = ToolManager.get_builtin_provider(provider_name, tenant_id) - return jsonable_encoder(provider.get_credentials_schema()) + return jsonable_encoder(provider.get_credentials_schema(credential_type)) @staticmethod def update_builtin_tool_provider( @@ -111,11 +106,11 @@ class BuiltinToolManageService: update builtin tool provider """ # get if the provider exists - provider = BuiltinToolManageService._fetch_builtin_provider_by_id(tenant_id, credential_id) + provider = BuiltinToolManageService.get_builtin_provider_by_id(tenant_id, credential_id) if provider is None: raise ValueError(f"you have not added provider {provider_name}") - + try: if ToolProviderCredentialType.of(provider.credential_type).is_editable(): provider_controller = ToolManager.get_builtin_provider(provider_name, tenant_id) @@ -133,10 +128,12 @@ class BuiltinToolManageService: if key in masked_credentials and value == masked_credentials[key]: credentials[key] = original_credentials[key] - # Encrypt and save the credentials - BuiltinToolManageService._encrypt_and_save_credentials( - provider_controller, tool_configuration, provider, credentials, user_id - ) + provider_controller.validate_credentials(user_id, credentials) + + # encrypt credentials + encrypted_credentials = tool_configuration.encrypt(credentials) + provider.encrypted_credentials = json.dumps(encrypted_credentials) + tool_configuration.delete_tool_credentials_cache() # update name if provided if name is not None and provider.name != name: @@ -158,68 +155,84 @@ class BuiltinToolManageService: user_id: str, api_type: ToolProviderCredentialType, tenant_id: str, - provider_name: str, + provider: str, credentials: dict, name: str | None = None, ): """ add builtin tool provider """ - lock = f"builtin_tool_provider_create_lock:{tenant_id}_{provider_name}" + lock = f"builtin_tool_provider_create_lock:{tenant_id}_{provider}" with redis_client.lock(lock, timeout=20): - if name is None: - name = BuiltinToolManageService.get_next_builtin_tool_provider_name(tenant_id, provider_name, api_type) + # check if the provider count is over the limit + provider_count = ( + db.session.query(BuiltinToolProvider).filter_by(tenant_id=tenant_id, provider=provider).count() + ) + if provider_count >= BuiltinToolManageService.__MAX_BUILTIN_TOOL_PROVIDER_COUNT__: + raise ValueError(f"you have reached the maximum number of providers for {provider}") + + # TODO should we get name from oauth authentication? + name = ( + name + if name + else BuiltinToolManageService.generate_builtin_tool_provider_name( + tenant_id, provider, credential_type=api_type + ) + ) - provider = BuiltinToolProvider( + db_provider = BuiltinToolProvider( tenant_id=tenant_id, user_id=user_id, - provider=provider_name, + provider=provider, encrypted_credentials=json.dumps(credentials), credential_type=api_type.value, name=name, ) - provider_controller = ToolManager.get_builtin_provider(provider_name, tenant_id) + provider_controller = ToolManager.get_builtin_provider(provider, tenant_id) if not provider_controller.need_credentials: - raise ValueError(f"provider {provider_name} does not need credentials") + raise ValueError(f"provider {provider} does not need credentials") tool_configuration = BuiltinToolManageService._create_tool_configuration(tenant_id, provider_controller) # Encrypt and save the credentials BuiltinToolManageService._encrypt_and_save_credentials( - provider_controller, tool_configuration, provider, credentials, user_id + provider_controller=provider_controller, + tool_configuration=tool_configuration, + provider=db_provider, + credentials=credentials, + user_id=user_id, ) - db.session.add(provider) + db.session.add(db_provider) db.session.commit() return {"result": "success"} @staticmethod - def get_next_builtin_tool_provider_name( - tenant_id: str, provider_name: str, type: ToolProviderCredentialType + def generate_builtin_tool_provider_name( + tenant_id: str, provider: str, credential_type: ToolProviderCredentialType ) -> str: try: - providers = ( + db_providers = ( db.session.query(BuiltinToolProvider) .filter_by( tenant_id=tenant_id, - provider=provider_name, - credential_type=type.value, + provider=provider, + credential_type=credential_type.value, ) .order_by(BuiltinToolProvider.created_at.desc()) - .limit(10) .all() ) # Get the default name pattern - default_pattern = type.get_name() + default_pattern = f"{credential_type.get_name()}" # Find all names that match the default pattern: "{default_pattern} {number}" pattern = rf"^{re.escape(default_pattern)}\s+(\d+)$" numbers = [] - for provider in providers: - if provider.name: - match = re.match(pattern, provider.name.strip()) + for db_provider in db_providers: + if db_provider.name: + match = re.match(pattern, db_provider.name.strip()) if match: numbers.append(int(match.group(1))) @@ -231,9 +244,9 @@ class BuiltinToolManageService: max_number = max(numbers) return f"{default_pattern} {max_number + 1}" except Exception as e: - logger.warning(f"Error generating next provider name for {provider_name}: {str(e)}") + logger.warning(f"Error generating next provider name for {provider}: {str(e)}") # fallback - return f"{type.get_name()} 1" + return f"{credential_type.get_name()} 1" @staticmethod def get_builtin_tool_provider_credentials( @@ -242,31 +255,43 @@ class BuiltinToolManageService: """ get builtin tool provider credentials """ - providers = db.session.query(BuiltinToolProvider).filter_by(tenant_id=tenant_id, provider=provider_name).all() + with db.session.no_autoflush: + providers = ( + db.session.query(BuiltinToolProvider).filter_by(tenant_id=tenant_id, provider=provider_name).all() + ) - if len(providers) == 0: - return [] + if len(providers) == 0: + return [] - provider_controller = ToolManager.get_builtin_provider(providers[0].provider, tenant_id) - tool_configuration = BuiltinToolManageService._create_tool_configuration(tenant_id, provider_controller) - credentials: list[ToolProviderCredentialApiEntity] = [] - for provider in providers: - decrypt_credential = tool_configuration.mask_tool_credentials( - tool_configuration.decrypt(provider.credentials) - ) - credential_entity = ToolTransformService.convert_builtin_provider_to_credential_entity( - provider=provider, - credentials=decrypt_credential, - ) - credentials.append(credential_entity) - return credentials + default_provider = sorted( + providers, + key=lambda p: ( + not getattr(p, "is_default", False), + getattr(p, "created_at", None) or 0, + ), + )[0] + + default_provider.is_default = True + provider_controller = ToolManager.get_builtin_provider(default_provider.provider, tenant_id) + tool_configuration = BuiltinToolManageService._create_tool_configuration(tenant_id, provider_controller) + credentials: list[ToolProviderCredentialApiEntity] = [] + for provider in providers: + decrypt_credential = tool_configuration.mask_tool_credentials( + tool_configuration.decrypt(provider.credentials) + ) + credential_entity = ToolTransformService.convert_builtin_provider_to_credential_entity( + provider=provider, + credentials=decrypt_credential, + ) + credentials.append(credential_entity) + return credentials @staticmethod def delete_builtin_tool_provider(tenant_id: str, provider_name: str, credential_id: str): """ delete tool provider """ - tool_provider = BuiltinToolManageService._fetch_builtin_provider_by_id(tenant_id, credential_id) + tool_provider = BuiltinToolManageService.get_builtin_provider_by_id(tenant_id, credential_id) if tool_provider is None: raise ValueError(f"you have not added provider {provider_name}") @@ -387,7 +412,6 @@ class BuiltinToolManageService: ToolTransformService.convert_tool_entity_to_api_entity( tenant_id=tenant_id, tool=tool, - credentials=user_builtin_provider.original_credentials, labels=ToolLabelManager.get_tool_labels(provider_controller), ) ) @@ -399,7 +423,7 @@ class BuiltinToolManageService: return BuiltinToolProviderSort.sort(result) @staticmethod - def _fetch_builtin_provider_by_id(tenant_id: str, credential_id: str) -> Optional[BuiltinToolProvider]: + def get_builtin_provider_by_id(tenant_id: str, credential_id: str) -> Optional[BuiltinToolProvider]: provider: Optional[BuiltinToolProvider] = ( db.session.query(BuiltinToolProvider) .filter( @@ -411,48 +435,63 @@ class BuiltinToolManageService: return provider @staticmethod - def _fetch_builtin_provider(provider_name: str, tenant_id: str) -> Optional[BuiltinToolProvider]: + def get_builtin_provider(provider_name: str, tenant_id: str) -> Optional[BuiltinToolProvider]: """ This method is used to fetch the builtin provider from the database 1.if the default provider exists, return the default provider 2.if the default provider does not exist, return the oldest provider """ + with Session(db.engine) as session: + try: + full_provider_name = provider_name + provider_id_entity = ToolProviderID(provider_name) + provider_name = provider_id_entity.provider_name + + if provider_id_entity.organization != "langgenius": + provider = ( + session.query(BuiltinToolProvider) + .filter( + BuiltinToolProvider.tenant_id == tenant_id, + BuiltinToolProvider.provider == full_provider_name, + ) + .order_by( + BuiltinToolProvider.is_default.desc(), # default=True first + BuiltinToolProvider.created_at.asc(), # oldest first + ) + .first() + ) + else: + provider = ( + session.query(BuiltinToolProvider) + .filter( + BuiltinToolProvider.tenant_id == tenant_id, + (BuiltinToolProvider.provider == provider_name) + | (BuiltinToolProvider.provider == full_provider_name), + ) + .order_by( + BuiltinToolProvider.is_default.desc(), # default=True first + BuiltinToolProvider.created_at.asc(), # oldest first + ) + .first() + ) - def _query(provider_filters: list[ColumnExpressionArgument[bool]]) -> Optional[BuiltinToolProvider]: - return ( - db.session.query(BuiltinToolProvider) - .filter(BuiltinToolProvider.tenant_id == tenant_id, *provider_filters) - .order_by( - BuiltinToolProvider.is_default.desc(), # default=True first - BuiltinToolProvider.created_at.asc(), # oldest first - ) - .first() - ) - - try: - full_provider_name = provider_name - provider_id_entity = ToolProviderID(provider_name) - provider_name = provider_id_entity.provider_name - - if provider_id_entity.organization != "langgenius": - provider = _query([BuiltinToolProvider.provider == full_provider_name]) - else: - provider = _query( - [ - (BuiltinToolProvider.provider == provider_name) - | (BuiltinToolProvider.provider == full_provider_name) - ] + if provider is None: + return None + + provider.provider = ToolProviderID(provider.provider).to_string() + return provider + except Exception: + # it's an old provider without organization + return ( + session.query(BuiltinToolProvider) + .filter(BuiltinToolProvider.tenant_id == tenant_id, BuiltinToolProvider.provider == provider_name) + .order_by( + BuiltinToolProvider.is_default.desc(), # default=True first + BuiltinToolProvider.created_at.asc(), # oldest first + ) + .first() ) - if provider is None: - return None - - provider.provider = ToolProviderID(provider.provider).to_string() - return provider - except Exception: - # it's an old provider without organization - return _query([BuiltinToolProvider.provider == provider_name]) - @staticmethod def _create_tool_configuration(tenant_id: str, provider_controller: ToolProviderController): return ProviderConfigEncrypter( @@ -463,7 +502,13 @@ class BuiltinToolManageService: ) @staticmethod - def _encrypt_and_save_credentials(provider_controller, tool_configuration, provider, credentials, user_id): + def _encrypt_and_save_credentials( + provider_controller: BuiltinToolProviderController | PluginToolProviderController, + tool_configuration: ProviderConfigEncrypter, + provider: BuiltinToolProvider, + credentials: dict, + user_id: str, + ): """ Validate and encrypt credentials, then save to database @@ -480,3 +525,25 @@ class BuiltinToolManageService: encrypted_credentials = tool_configuration.encrypt(credentials) provider.encrypted_credentials = json.dumps(encrypted_credentials) tool_configuration.delete_tool_credentials_cache() + + @staticmethod + def setup_oauth_custom_client(tenant_id: str, user_id: str, provider: str, client_params: dict): + """ + setup oauth custom client + """ + provider_controller = ToolManager.get_builtin_provider(provider, tenant_id) + if not provider_controller: + raise ToolProviderNotFoundError(f"Provider {provider} not found") + + tool_configuration = BuiltinToolManageService._create_tool_configuration(tenant_id, provider_controller) + + # Validate and encrypt credentials + BuiltinToolManageService._encrypt_and_save_credentials( + provider_controller=provider_controller, + tool_configuration=tool_configuration, + provider=None, # No need to save in DB + credentials=client_params, + user_id=user_id, + ) + + return {"result": "success"} diff --git a/api/services/tools/tools_transform_service.py b/api/services/tools/tools_transform_service.py index 66be67dbe6..160352c4c0 100644 --- a/api/services/tools/tools_transform_service.py +++ b/api/services/tools/tools_transform_service.py @@ -255,7 +255,6 @@ class ToolTransformService: def convert_tool_entity_to_api_entity( tool: Union[ApiToolBundle, WorkflowTool, Tool], tenant_id: str, - credentials: dict | None = None, labels: list[str] | None = None, ) -> ToolApiEntity: """ @@ -265,7 +264,7 @@ class ToolTransformService: # fork tool runtime tool = tool.fork_tool_runtime( runtime=ToolRuntime( - credentials=credentials or {}, + credentials= {}, tenant_id=tenant_id, ) ) From 7951a1c4dff0a6cbeca7244566b1cf6287c40175 Mon Sep 17 00:00:00 2001 From: Harry Date: Wed, 2 Jul 2025 10:04:57 +0800 Subject: [PATCH 10/15] refactor(tool): implement multi provider credentials support --- .../console/workspace/tool_providers.py | 5 +- api/core/helper/provider_cache.py | 77 ++++++++ api/core/helper/tool_provider_cache.py | 51 ----- .../plugin/backwards_invocation/encrypt.py | 6 +- api/core/tools/builtin_tool/provider.py | 34 +++- api/core/tools/tool_manager.py | 32 ++-- api/core/tools/utils/configuration.py | 72 ++++--- .../tools/api_tools_manage_service.py | 18 +- .../tools/builtin_tools_manage_service.py | 179 +++++++++--------- api/services/tools/tools_transform_service.py | 40 ++-- 10 files changed, 296 insertions(+), 218 deletions(-) create mode 100644 api/core/helper/provider_cache.py delete mode 100644 api/core/helper/tool_provider_cache.py diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py index 5da20c3d29..090d5f3cee 100644 --- a/api/controllers/console/workspace/tool_providers.py +++ b/api/controllers/console/workspace/tool_providers.py @@ -82,7 +82,7 @@ class ToolBuiltinProviderInfoApi(Resource): user_id = user.id tenant_id = user.current_tenant_id - return jsonable_encoder(BuiltinToolManageService.get_builtin_tool_provider_info(user_id, tenant_id, provider)) + return jsonable_encoder(BuiltinToolManageService.get_builtin_tool_provider_info(tenant_id, provider)) class ToolBuiltinProviderDeleteApi(Resource): @@ -159,7 +159,7 @@ class ToolBuiltinProviderUpdateApi(Resource): result = BuiltinToolManageService.update_builtin_tool_provider( user_id=user_id, tenant_id=tenant_id, - provider_name=provider, + provider=provider, credentials=args["credentials"], credential_id=args["credential_id"], name=args["name"], @@ -782,7 +782,6 @@ class ToolOAuthCustomClient(Resource): return BuiltinToolManageService.setup_oauth_custom_client( tenant_id=user.current_tenant_id, - user_id=user.id, provider=provider, client_params=args["client_params"], ) diff --git a/api/core/helper/provider_cache.py b/api/core/helper/provider_cache.py new file mode 100644 index 0000000000..3e70ea5341 --- /dev/null +++ b/api/core/helper/provider_cache.py @@ -0,0 +1,77 @@ +import json +from abc import ABC, abstractmethod +from json import JSONDecodeError +from typing import Any, Optional + +from extensions.ext_redis import redis_client + + +class ProviderCredentialsCache(ABC): + """Base class for provider credentials cache""" + + def __init__(self, **kwargs): + self.cache_key = self._generate_cache_key(**kwargs) + + @abstractmethod + def _generate_cache_key(self, **kwargs) -> str: + """Generate cache key based on subclass implementation""" + pass + + def get(self) -> Optional[dict]: + """Get cached provider credentials""" + cached_credentials = redis_client.get(self.cache_key) + if cached_credentials: + try: + cached_credentials = cached_credentials.decode("utf-8") + return dict(json.loads(cached_credentials)) + except JSONDecodeError: + return None + return None + + def set(self, config: dict[str, Any]) -> None: + """Cache provider credentials""" + redis_client.setex(self.cache_key, 86400, json.dumps(config)) + + def delete(self) -> None: + """Delete cached provider credentials""" + redis_client.delete(self.cache_key) + + +class GenericProviderCredentialsCache(ProviderCredentialsCache): + """Cache for generic provider credentials""" + + def __init__(self, tenant_id: str, identity_id: str): + super().__init__(tenant_id=tenant_id, identity_id=identity_id) + + def _generate_cache_key(self, **kwargs) -> str: + tenant_id = kwargs["tenant_id"] + identity_id = kwargs["identity_id"] + return f"generic_provider_credentials:tenant_id:{tenant_id}:id:{identity_id}" + +class ToolProviderCredentialsCache(ProviderCredentialsCache): + """Cache for tool provider credentials""" + + def __init__(self, tenant_id: str, provider: str, credential_id: str): + super().__init__(tenant_id=tenant_id, provider=provider, credential_id=credential_id) + + def _generate_cache_key(self, **kwargs) -> str: + tenant_id = kwargs["tenant_id"] + provider = kwargs["provider"] + credential_id = kwargs["credential_id"] + return f"provider_credentials:tenant_id:{tenant_id}:provider:{provider}:credential_id:{credential_id}" + + +class NoOpProviderCredentialCache: + """No-op provider credential cache""" + + def get(self) -> Optional[dict]: + """Get cached provider credentials""" + return None + + def set(self, config: dict[str, Any]) -> None: + """Cache provider credentials""" + pass + + def delete(self) -> None: + """Delete cached provider credentials""" + pass diff --git a/api/core/helper/tool_provider_cache.py b/api/core/helper/tool_provider_cache.py deleted file mode 100644 index 2e4a04c579..0000000000 --- a/api/core/helper/tool_provider_cache.py +++ /dev/null @@ -1,51 +0,0 @@ -import json -from enum import Enum -from json import JSONDecodeError -from typing import Optional - -from extensions.ext_redis import redis_client - - -class ToolProviderCredentialsCacheType(Enum): - PROVIDER = "tool_provider" - ENDPOINT = "endpoint" - - -class ToolProviderCredentialsCache: - def __init__(self, tenant_id: str, identity_id: str, cache_type: ToolProviderCredentialsCacheType): - self.cache_key = f"{cache_type.value}_credentials:tenant_id:{tenant_id}:id:{identity_id}" - - def get(self) -> Optional[dict]: - """ - Get cached model provider credentials. - - :return: - """ - cached_provider_credentials = redis_client.get(self.cache_key) - if cached_provider_credentials: - try: - cached_provider_credentials = cached_provider_credentials.decode("utf-8") - cached_provider_credentials = json.loads(cached_provider_credentials) - except JSONDecodeError: - return None - - return dict(cached_provider_credentials) - else: - return None - - def set(self, credentials: dict) -> None: - """ - Cache model provider credentials. - - :param credentials: provider credentials - :return: - """ - redis_client.setex(self.cache_key, 86400, json.dumps(credentials)) - - def delete(self) -> None: - """ - Delete cached model provider credentials. - - :return: - """ - redis_client.delete(self.cache_key) diff --git a/api/core/plugin/backwards_invocation/encrypt.py b/api/core/plugin/backwards_invocation/encrypt.py index 81a5d033a0..bfe9ffa4b0 100644 --- a/api/core/plugin/backwards_invocation/encrypt.py +++ b/api/core/plugin/backwards_invocation/encrypt.py @@ -1,12 +1,12 @@ from core.plugin.entities.request import RequestInvokeEncrypt -from core.tools.utils.configuration import ProviderConfigEncrypter +from core.tools.utils.configuration import create_generic_encrypter from models.account import Tenant class PluginEncrypter: @classmethod def invoke_encrypt(cls, tenant: Tenant, payload: RequestInvokeEncrypt) -> dict: - encrypter = ProviderConfigEncrypter( + encrypter, cache = create_generic_encrypter( tenant_id=tenant.id, config=payload.config, provider_type=payload.namespace, @@ -22,7 +22,7 @@ class PluginEncrypter: "data": encrypter.decrypt(payload.data), } elif payload.opt == "clear": - encrypter.delete_tool_credentials_cache() + cache.delete() return { "data": {}, } diff --git a/api/core/tools/builtin_tool/provider.py b/api/core/tools/builtin_tool/provider.py index 9e3c13849f..53affe9e97 100644 --- a/api/core/tools/builtin_tool/provider.py +++ b/api/core/tools/builtin_tool/provider.py @@ -105,20 +105,34 @@ class BuiltinToolProviderController(ToolProviderController): """ return self.tools - def get_credentials_schema( - self, credential_type: ToolProviderCredentialType = ToolProviderCredentialType.API_KEY - ) -> list[ProviderConfig]: + def get_credentials_schema(self) -> list[ProviderConfig]: """ returns the credentials schema of the provider :return: the credentials schema """ - if credential_type == ToolProviderCredentialType.OAUTH2: + return self.get_credentials_schema_by_type(ToolProviderCredentialType.API_KEY.value) + + def get_credentials_schema_by_type(self, credential_type: str) -> list[ProviderConfig]: + """ + returns the credentials schema of the provider + + :param credential_type: the type of the credential + :return: the credentials schema of the provider + """ + if credential_type == ToolProviderCredentialType.OAUTH2.value: return self.entity.oauth_schema.credentials_schema.copy() if self.entity.oauth_schema else [] - elif credential_type == ToolProviderCredentialType.API_KEY: + if credential_type == ToolProviderCredentialType.API_KEY.value: return self.entity.credentials_schema.copy() if self.entity.credentials_schema else [] - else: - raise ValueError(f"Invalid credential type: {credential_type}") + raise ValueError(f"Invalid credential type: {credential_type}") + + def get_oauth_client_schema(self) -> list[ProviderConfig]: + """ + returns the oauth client schema of the provider + + :return: the oauth client schema + """ + return self.entity.oauth_schema.client_schema.copy() if self.entity.oauth_schema else [] def get_tools(self) -> list[BuiltinTool]: """ @@ -141,7 +155,11 @@ class BuiltinToolProviderController(ToolProviderController): :return: whether the provider needs credentials """ - return self.entity.credentials_schema is not None and len(self.entity.credentials_schema) != 0 + return ( + self.entity.credentials_schema is not None + and len(self.entity.credentials_schema) != 0 + or (self.entity.oauth_schema is not None and len(self.entity.oauth_schema.credentials_schema) != 0) + ) @property def provider_type(self) -> ToolProviderType: diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index 35d4eb0c7e..e9423a6c49 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -9,6 +9,7 @@ from typing import TYPE_CHECKING, Any, Union, cast from yarl import URL import contexts +from core.helper.provider_cache import ToolProviderCredentialsCache from core.plugin.entities.plugin import ToolProviderID from core.plugin.impl.tool import PluginToolManager from core.tools.__base.tool_provider import ToolProviderController @@ -38,12 +39,16 @@ from core.tools.entities.tool_entities import ( ApiProviderAuthType, ToolInvokeFrom, ToolParameter, - ToolProviderCredentialType, ToolProviderType, ) from core.tools.errors import ToolProviderNotFoundError from core.tools.tool_label_manager import ToolLabelManager -from core.tools.utils.configuration import ProviderConfigEncrypter, ToolParameterConfigurationManager +from core.tools.utils.configuration import ( + ProviderConfigEncrypter, + ToolParameterConfigurationManager, + create_encrypter, + create_generic_encrypter, +) from core.tools.workflow_as_tool.tool import WorkflowTool from extensions.ext_database import db from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider @@ -206,19 +211,18 @@ class ToolManager: # decrypt the credentials credentials = builtin_provider.credentials - tool_configuration = ProviderConfigEncrypter( + encrypter, _ = create_encrypter( tenant_id=tenant_id, config=[ x.to_basic_provider_config() - for x in provider_controller.get_credentials_schema( - ToolProviderCredentialType.of(builtin_provider.credential_type) - ) + for x in provider_controller.get_credentials_schema_by_type(builtin_provider.credential_type) ], - provider_type=provider_controller.provider_type.value, - provider_identity=provider_controller.entity.identity.name, + cache=ToolProviderCredentialsCache( + tenant_id=tenant_id, provider=provider_id, credential_id=builtin_provider.id + ), ) - decrypted_credentials = tool_configuration.decrypt(credentials) + decrypted_credentials = encrypter.decrypt(credentials) return cast( BuiltinTool, @@ -235,22 +239,18 @@ class ToolManager: elif provider_type == ToolProviderType.API: api_provider, credentials = cls.get_api_provider_controller(tenant_id, provider_id) - - # decrypt the credentials - tool_configuration = ProviderConfigEncrypter( + encrypter, _ = create_generic_encrypter( tenant_id=tenant_id, config=[x.to_basic_provider_config() for x in api_provider.get_credentials_schema()], provider_type=api_provider.provider_type.value, provider_identity=api_provider.entity.identity.name, ) - decrypted_credentials = tool_configuration.decrypt(credentials) - return cast( ApiTool, api_provider.get_tool(tool_name).fork_tool_runtime( runtime=ToolRuntime( tenant_id=tenant_id, - credentials=decrypted_credentials, + credentials=encrypter.decrypt(credentials), invoke_from=invoke_from, tool_invoke_from=tool_invoke_from, ) @@ -730,7 +730,7 @@ class ToolManager: ApiProviderAuthType.API_KEY if credentials["auth_type"] == "api_key" else ApiProviderAuthType.NONE, ) # init tool configuration - tool_configuration = ProviderConfigEncrypter( + tool_configuration = ProviderConfigEncrypter.create_cached( tenant_id=tenant_id, config=[x.to_basic_provider_config() for x in controller.get_credentials_schema()], provider_type=controller.provider_type.value, diff --git a/api/core/tools/utils/configuration.py b/api/core/tools/utils/configuration.py index 6a5fba65bd..2b64703321 100644 --- a/api/core/tools/utils/configuration.py +++ b/api/core/tools/utils/configuration.py @@ -1,12 +1,10 @@ from copy import deepcopy -from typing import Any - -from pydantic import BaseModel +from typing import Any, Optional, Protocol from core.entities.provider_entities import BasicProviderConfig from core.helper import encrypter +from core.helper.provider_cache import GenericProviderCredentialsCache from core.helper.tool_parameter_cache import ToolParameterCache, ToolParameterCacheType -from core.helper.tool_provider_cache import ToolProviderCredentialsCache, ToolProviderCredentialsCacheType from core.tools.__base.tool import Tool from core.tools.entities.tool_entities import ( ToolParameter, @@ -14,11 +12,38 @@ from core.tools.entities.tool_entities import ( ) -class ProviderConfigEncrypter(BaseModel): +class ProviderConfigCache(Protocol): + """ + Interface for provider configuration cache operations + """ + + def get(self) -> Optional[dict]: + """Get cached provider configuration""" + ... + + def set(self, config: dict[str, Any]) -> None: + """Cache provider configuration""" + ... + + def delete(self) -> None: + """Delete cached provider configuration""" + ... + + +class ProviderConfigEncrypter: tenant_id: str config: list[BasicProviderConfig] - provider_type: str - provider_identity: str + provider_config_cache: ProviderConfigCache + + def __init__( + self, + tenant_id: str, + config: list[BasicProviderConfig], + provider_config_cache: ProviderConfigCache, + ): + self.tenant_id = tenant_id + self.config = config + self.provider_config_cache = provider_config_cache def _deep_copy(self, data: dict[str, str]) -> dict[str, str]: """ @@ -72,18 +97,13 @@ class ProviderConfigEncrypter(BaseModel): return data - def decrypt(self, data: dict[str, str]) -> dict[str, str]: + def decrypt(self, data: dict[str, str]) -> dict[str, Any]: """ decrypt tool credentials with tenant id return a deep copy of credentials with decrypted values """ - cache = ToolProviderCredentialsCache( - tenant_id=self.tenant_id, - identity_id=f"{self.provider_type}.{self.provider_identity}", - cache_type=ToolProviderCredentialsCacheType.PROVIDER, - ) - cached_credentials = cache.get() + cached_credentials = self.provider_config_cache.get() if cached_credentials: return cached_credentials data = self._deep_copy(data) @@ -104,16 +124,24 @@ class ProviderConfigEncrypter(BaseModel): except Exception: pass - cache.set(data) + self.provider_config_cache.set(data) return data - def delete_tool_credentials_cache(self): - cache = ToolProviderCredentialsCache( - tenant_id=self.tenant_id, - identity_id=f"{self.provider_type}.{self.provider_identity}", - cache_type=ToolProviderCredentialsCacheType.PROVIDER, - ) - cache.delete() + +def create_encrypter( + tenant_id: str, config: list[BasicProviderConfig], cache: ProviderConfigCache +): + return ProviderConfigEncrypter( + tenant_id=tenant_id, config=config, provider_config_cache=cache + ), cache + + +def create_generic_encrypter( + tenant_id: str, config: list[BasicProviderConfig], provider_type: str, provider_identity: str +): + cache = GenericProviderCredentialsCache(tenant_id=tenant_id, identity_id=f"{provider_type}.{provider_identity}") + encrypt = ProviderConfigEncrypter(tenant_id=tenant_id, config=config, provider_config_cache=cache) + return encrypt, cache class ToolParameterConfigurationManager: diff --git a/api/services/tools/api_tools_manage_service.py b/api/services/tools/api_tools_manage_service.py index b429851349..ff84b4318b 100644 --- a/api/services/tools/api_tools_manage_service.py +++ b/api/services/tools/api_tools_manage_service.py @@ -18,7 +18,7 @@ from core.tools.entities.tool_entities import ( ) from core.tools.tool_label_manager import ToolLabelManager from core.tools.tool_manager import ToolManager -from core.tools.utils.configuration import ProviderConfigEncrypter +from core.tools.utils.configuration import ProviderConfigEncrypter, create_generic_encrypter from core.tools.utils.parser import ApiBasedToolSchemaParser from extensions.ext_database import db from models.tools import ApiToolProvider @@ -297,28 +297,28 @@ class ApiToolManageService: provider_controller.load_bundled_tools(tool_bundles) # get original credentials if exists - tool_configuration = ProviderConfigEncrypter( + encrypter, cache = create_generic_encrypter( tenant_id=tenant_id, config=list(provider_controller.get_credentials_schema()), provider_type=provider_controller.provider_type.value, provider_identity=provider_controller.entity.identity.name, ) - original_credentials = tool_configuration.decrypt(provider.credentials) - masked_credentials = tool_configuration.mask_tool_credentials(original_credentials) + original_credentials = encrypter.decrypt(provider.credentials) + masked_credentials = encrypter.mask_tool_credentials(original_credentials) # check if the credential has changed, save the original credential for name, value in credentials.items(): if name in masked_credentials and value == masked_credentials[name]: credentials[name] = original_credentials[name] - credentials = tool_configuration.encrypt(credentials) + credentials = encrypter.encrypt(credentials) provider.credentials_str = json.dumps(credentials) db.session.add(provider) db.session.commit() # delete cache - tool_configuration.delete_tool_credentials_cache() + cache.delete() # update labels ToolLabelManager.update_tool_labels(provider_controller, labels) @@ -416,15 +416,15 @@ class ApiToolManageService: # decrypt credentials if db_provider.id: - tool_configuration = ProviderConfigEncrypter( + encrypter, _ = create_generic_encrypter( tenant_id=tenant_id, config=list(provider_controller.get_credentials_schema()), provider_type=provider_controller.provider_type.value, provider_identity=provider_controller.entity.identity.name, ) - decrypted_credentials = tool_configuration.decrypt(credentials) + decrypted_credentials = encrypter.decrypt(credentials) # check if the credential has changed, save the original credential - masked_credentials = tool_configuration.mask_tool_credentials(decrypted_credentials) + masked_credentials = encrypter.mask_tool_credentials(decrypted_credentials) for name, value in credentials.items(): if name in masked_credentials and value == masked_credentials[name]: credentials[name] = decrypted_credentials[name] diff --git a/api/services/tools/builtin_tools_manage_service.py b/api/services/tools/builtin_tools_manage_service.py index 80ee9b080c..17c1a4b421 100644 --- a/api/services/tools/builtin_tools_manage_service.py +++ b/api/services/tools/builtin_tools_manage_service.py @@ -8,19 +8,18 @@ from sqlalchemy.orm import Session from configs import dify_config from core.helper.position_helper import is_filtered +from core.helper.provider_cache import NoOpProviderCredentialCache, ToolProviderCredentialsCache from core.model_runtime.utils.encoders import jsonable_encoder from core.plugin.entities.plugin import ToolProviderID from core.plugin.impl.exc import PluginDaemonClientSideError -from core.tools.__base.tool_provider import ToolProviderController from core.tools.builtin_tool.provider import BuiltinToolProviderController from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity, ToolProviderCredentialApiEntity from core.tools.entities.tool_entities import ToolProviderCredentialType from core.tools.errors import ToolNotFoundError, ToolProviderCredentialValidationError, ToolProviderNotFoundError -from core.tools.plugin_tool.provider import PluginToolProviderController from core.tools.tool_label_manager import ToolLabelManager from core.tools.tool_manager import ToolManager -from core.tools.utils.configuration import ProviderConfigEncrypter +from core.tools.utils.configuration import create_encrypter from extensions.ext_database import db from extensions.ext_redis import redis_client from models.tools import BuiltinToolProvider, ToolOAuthSystemClient, ToolOAuthTenantClient @@ -58,20 +57,15 @@ class BuiltinToolManageService: return result @staticmethod - def get_builtin_tool_provider_info(user_id: str, tenant_id: str, provider: str): + def get_builtin_tool_provider_info(tenant_id: str, provider: str): """ get builtin tool provider info """ provider_controller = ToolManager.get_builtin_provider(provider, tenant_id) - tool_configuration = BuiltinToolManageService._create_tool_configuration(tenant_id, provider_controller) # check if user has added the provider builtin_provider = BuiltinToolManageService.get_builtin_provider(provider, tenant_id) - - credentials = {} - if builtin_provider is not None: - # get credentials - credentials = builtin_provider.credentials - credentials = tool_configuration.decrypt(credentials) + if builtin_provider is None: + raise ValueError(f"you have not added provider {provider}") entity = ToolTransformService.builtin_provider_to_user_provider( provider_controller=provider_controller, @@ -80,7 +74,6 @@ class BuiltinToolManageService: ) entity.original_credentials = {} - return entity @staticmethod @@ -96,32 +89,34 @@ class BuiltinToolManageService: :return: the list of tool providers """ provider = ToolManager.get_builtin_provider(provider_name, tenant_id) - return jsonable_encoder(provider.get_credentials_schema(credential_type)) + return jsonable_encoder(provider.get_credentials_schema_by_type(credential_type)) @staticmethod def update_builtin_tool_provider( - user_id: str, tenant_id: str, provider_name: str, credentials: dict, credential_id: str, name: str | None = None + user_id: str, tenant_id: str, provider: str, credentials: dict, credential_id: str, name: str | None = None ): """ update builtin tool provider """ # get if the provider exists - provider = BuiltinToolManageService.get_builtin_provider_by_id(tenant_id, credential_id) + db_provider = BuiltinToolManageService.get_builtin_provider_by_id(tenant_id, credential_id) - if provider is None: - raise ValueError(f"you have not added provider {provider_name}") + if db_provider is None: + raise ValueError(f"you have not added provider {provider}") try: - if ToolProviderCredentialType.of(provider.credential_type).is_editable(): - provider_controller = ToolManager.get_builtin_provider(provider_name, tenant_id) + if ToolProviderCredentialType.of(db_provider.credential_type).is_editable(): + provider_controller = ToolManager.get_builtin_provider(provider, tenant_id) if not provider_controller.need_credentials: - raise ValueError(f"provider {provider_name} does not need credentials") + raise ValueError(f"provider {provider} does not need credentials") - tool_configuration = BuiltinToolManageService._create_tool_configuration(tenant_id, provider_controller) + encrypter, cache = BuiltinToolManageService.create_tool_encrypter( + tenant_id, db_provider, provider, provider_controller + ) # Decrypt and restore original credentials for masked values - original_credentials = tool_configuration.decrypt(provider.credentials) - masked_credentials = tool_configuration.mask_tool_credentials(original_credentials) + original_credentials = encrypter.decrypt(db_provider.credentials) + masked_credentials = encrypter.mask_tool_credentials(original_credentials) # check if the credential has changed, save the original credential for key, value in credentials.items(): @@ -131,13 +126,13 @@ class BuiltinToolManageService: provider_controller.validate_credentials(user_id, credentials) # encrypt credentials - encrypted_credentials = tool_configuration.encrypt(credentials) - provider.encrypted_credentials = json.dumps(encrypted_credentials) - tool_configuration.delete_tool_credentials_cache() + db_provider.encrypted_credentials = json.dumps(encrypter.encrypt(credentials)) + + cache.delete() # update name if provided - if name is not None and provider.name != name: - provider.name = name + if name is not None and db_provider.name != name: + db_provider.name = name db.session.commit() except ( @@ -176,7 +171,7 @@ class BuiltinToolManageService: name if name else BuiltinToolManageService.generate_builtin_tool_provider_name( - tenant_id, provider, credential_type=api_type + tenant_id=tenant_id, provider=provider, credential_type=api_type ) ) @@ -193,20 +188,35 @@ class BuiltinToolManageService: if not provider_controller.need_credentials: raise ValueError(f"provider {provider} does not need credentials") - tool_configuration = BuiltinToolManageService._create_tool_configuration(tenant_id, provider_controller) - - # Encrypt and save the credentials - BuiltinToolManageService._encrypt_and_save_credentials( - provider_controller=provider_controller, - tool_configuration=tool_configuration, - provider=db_provider, - credentials=credentials, - user_id=user_id, + encrypter, cache = BuiltinToolManageService.create_tool_encrypter( + tenant_id, db_provider, provider, provider_controller ) + + # encrypt credentials + db_provider.encrypted_credentials = json.dumps(encrypter.encrypt(credentials)) + + cache.delete() db.session.add(db_provider) db.session.commit() return {"result": "success"} + @staticmethod + def create_tool_encrypter( + tenant_id: str, + db_provider: BuiltinToolProvider, + provider: str, + provider_controller: BuiltinToolProviderController, + ): + encrypter, cache = create_encrypter( + tenant_id=tenant_id, + config=[ + x.to_basic_provider_config() + for x in provider_controller.get_credentials_schema_by_type(db_provider.credential_type) + ], + cache=ToolProviderCredentialsCache(tenant_id=tenant_id, provider=provider, credential_id=db_provider.id), + ) + return encrypter, cache + @staticmethod def generate_builtin_tool_provider_name( tenant_id: str, provider: str, credential_type: ToolProviderCredentialType @@ -273,12 +283,13 @@ class BuiltinToolManageService: default_provider.is_default = True provider_controller = ToolManager.get_builtin_provider(default_provider.provider, tenant_id) - tool_configuration = BuiltinToolManageService._create_tool_configuration(tenant_id, provider_controller) + encrypter, cache = BuiltinToolManageService.create_tool_encrypter( + tenant_id, default_provider, default_provider.provider, provider_controller + ) + credentials: list[ToolProviderCredentialApiEntity] = [] for provider in providers: - decrypt_credential = tool_configuration.mask_tool_credentials( - tool_configuration.decrypt(provider.credentials) - ) + decrypt_credential = encrypter.mask_tool_credentials(encrypter.decrypt(provider.credentials)) credential_entity = ToolTransformService.convert_builtin_provider_to_credential_entity( provider=provider, credentials=decrypt_credential, @@ -287,22 +298,24 @@ class BuiltinToolManageService: return credentials @staticmethod - def delete_builtin_tool_provider(tenant_id: str, provider_name: str, credential_id: str): + def delete_builtin_tool_provider(tenant_id: str, provider: str, credential_id: str): """ delete tool provider """ tool_provider = BuiltinToolManageService.get_builtin_provider_by_id(tenant_id, credential_id) if tool_provider is None: - raise ValueError(f"you have not added provider {provider_name}") + raise ValueError(f"you have not added provider {provider}") db.session.delete(tool_provider) db.session.commit() # delete cache - provider_controller = ToolManager.get_builtin_provider(provider_name, tenant_id) - tool_configuration = BuiltinToolManageService._create_tool_configuration(tenant_id, provider_controller) - tool_configuration.delete_tool_credentials_cache() + provider_controller = ToolManager.get_builtin_provider(provider, tenant_id) + _, cache = BuiltinToolManageService.create_tool_encrypter( + tenant_id, tool_provider, provider, provider_controller + ) + cache.delete() return {"result": "success"} @@ -493,57 +506,35 @@ class BuiltinToolManageService: ) @staticmethod - def _create_tool_configuration(tenant_id: str, provider_controller: ToolProviderController): - return ProviderConfigEncrypter( - tenant_id=tenant_id, - config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()], - provider_type=provider_controller.provider_type.value, - provider_identity=provider_controller.entity.identity.name, - ) - - @staticmethod - def _encrypt_and_save_credentials( - provider_controller: BuiltinToolProviderController | PluginToolProviderController, - tool_configuration: ProviderConfigEncrypter, - provider: BuiltinToolProvider, - credentials: dict, - user_id: str, - ): - """ - Validate and encrypt credentials, then save to database - - :param provider_controller: the provider controller - :param tool_configuration: the tool configuration encrypter - :param provider: the provider object from database - :param credentials: the credentials to encrypt and save - :param user_id: the user id for validation - """ - if ToolProviderCredentialType.of(provider.credential_type).is_validate_allowed(): - provider_controller.validate_credentials(user_id, credentials) - - # encrypt credentials - encrypted_credentials = tool_configuration.encrypt(credentials) - provider.encrypted_credentials = json.dumps(encrypted_credentials) - tool_configuration.delete_tool_credentials_cache() - - @staticmethod - def setup_oauth_custom_client(tenant_id: str, user_id: str, provider: str, client_params: dict): + def setup_oauth_custom_client(tenant_id: str, provider: str, client_params: dict): """ setup oauth custom client """ - provider_controller = ToolManager.get_builtin_provider(provider, tenant_id) - if not provider_controller: - raise ToolProviderNotFoundError(f"Provider {provider} not found") + with Session(db.engine) as session: + tool_provider = ToolProviderID(provider) + provider_controller = ToolManager.get_builtin_provider(provider, tenant_id) + if not provider_controller: + raise ToolProviderNotFoundError(f"Provider {provider} not found") - tool_configuration = BuiltinToolManageService._create_tool_configuration(tenant_id, provider_controller) + if not isinstance(provider_controller, BuiltinToolProviderController): + raise ValueError(f"Provider {provider} is not a builtin or plugin provider") - # Validate and encrypt credentials - BuiltinToolManageService._encrypt_and_save_credentials( - provider_controller=provider_controller, - tool_configuration=tool_configuration, - provider=None, # No need to save in DB - credentials=client_params, - user_id=user_id, - ) + encrypter, _ = create_encrypter( + tenant_id=tenant_id, + config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()], + cache=NoOpProviderCredentialCache(), + ) + # encrypt credentials + encrypted_credentials = encrypter.encrypt(client_params) + session.add( + ToolOAuthTenantClient( + tenant_id=tenant_id, + plugin_id=tool_provider.plugin_id, + provider=tool_provider.provider_name, + enabled=True, + encrypted_oauth_params=json.dumps(encrypted_credentials), + ) + ) + session.commit() return {"result": "success"} diff --git a/api/services/tools/tools_transform_service.py b/api/services/tools/tools_transform_service.py index 160352c4c0..1c3ef3d48c 100644 --- a/api/services/tools/tools_transform_service.py +++ b/api/services/tools/tools_transform_service.py @@ -5,6 +5,7 @@ from typing import Optional, Union, cast from yarl import URL from configs import dify_config +from core.helper.provider_cache import ToolProviderCredentialsCache from core.tools.__base.tool import Tool from core.tools.__base.tool_runtime import ToolRuntime from core.tools.builtin_tool.provider import BuiltinToolProviderController @@ -19,7 +20,7 @@ from core.tools.entities.tool_entities import ( ToolProviderType, ) from core.tools.plugin_tool.provider import PluginToolProviderController -from core.tools.utils.configuration import ProviderConfigEncrypter +from core.tools.utils.configuration import create_encrypter, create_generic_encrypter from core.tools.workflow_as_tool.provider import WorkflowToolProviderController from core.tools.workflow_as_tool.tool import WorkflowTool from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider @@ -109,7 +110,14 @@ class ToolTransformService: result.plugin_unique_identifier = provider_controller.plugin_unique_identifier # get credentials schema - schema = {x.to_basic_provider_config().name: x for x in provider_controller.get_credentials_schema()} + schema = { + x.to_basic_provider_config().name: x + for x in provider_controller.get_credentials_schema_by_type( + ToolProviderCredentialType.of(db_provider.credential_type) + if db_provider + else ToolProviderCredentialType.API_KEY + ) + } for name, value in schema.items(): if result.masked_credentials: @@ -126,15 +134,23 @@ class ToolTransformService: credentials = db_provider.credentials # init tool configuration - tool_configuration = ProviderConfigEncrypter( + encrypter, _ = create_encrypter( tenant_id=db_provider.tenant_id, - config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()], - provider_type=provider_controller.provider_type.value, - provider_identity=provider_controller.entity.identity.name, + config=[ + x.to_basic_provider_config() + for x in provider_controller.get_credentials_schema_by_type( + ToolProviderCredentialType.of(db_provider.credential_type) + ) + ], + cache=ToolProviderCredentialsCache( + tenant_id=db_provider.tenant_id, + provider=db_provider.provider, + credential_id=db_provider.id, + ), ) # decrypt the credentials and mask the credentials - decrypted_credentials = tool_configuration.decrypt(data=credentials) - masked_credentials = tool_configuration.mask_tool_credentials(data=decrypted_credentials) + decrypted_credentials = encrypter.decrypt(data=credentials) + masked_credentials = encrypter.mask_tool_credentials(data=decrypted_credentials) result.masked_credentials = masked_credentials result.original_credentials = decrypted_credentials @@ -236,7 +252,7 @@ class ToolTransformService: if decrypt_credentials: # init tool configuration - tool_configuration = ProviderConfigEncrypter( + encrypter, _ = create_generic_encrypter( tenant_id=db_provider.tenant_id, config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()], provider_type=provider_controller.provider_type.value, @@ -244,8 +260,8 @@ class ToolTransformService: ) # decrypt the credentials and mask the credentials - decrypted_credentials = tool_configuration.decrypt(data=credentials) - masked_credentials = tool_configuration.mask_tool_credentials(data=decrypted_credentials) + decrypted_credentials = encrypter.decrypt(data=credentials) + masked_credentials = encrypter.mask_tool_credentials(data=decrypted_credentials) result.masked_credentials = masked_credentials @@ -264,7 +280,7 @@ class ToolTransformService: # fork tool runtime tool = tool.fork_tool_runtime( runtime=ToolRuntime( - credentials= {}, + credentials={}, tenant_id=tenant_id, ) ) From 826bf25abf839702dbb98800b7ab26598df8a2fb Mon Sep 17 00:00:00 2001 From: efrey kong Date: Wed, 2 Jul 2025 14:43:01 +0800 Subject: [PATCH 11/15] Fix: prevent SQL errors when metadata filter Constant value is None or blank (#21803) --- api/core/rag/retrieval/dataset_retrieval.py | 3 +++ .../nodes/knowledge_retrieval/knowledge_retrieval_node.py | 3 +++ 2 files changed, 6 insertions(+) diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index 38c0b540d5..3fca48be22 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -1010,6 +1010,9 @@ class DatasetRetrieval: def _process_metadata_filter_func( self, sequence: int, condition: str, metadata_name: str, value: Optional[Any], filters: list ): + if value is None: + return + key = f"{metadata_name}_{sequence}" key_value = f"{metadata_name}_{sequence}_value" match condition: diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index 0b9e98f28a..b34d62d669 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -490,6 +490,9 @@ class KnowledgeRetrievalNode(LLMNode): def _process_metadata_filter_func( self, sequence: int, condition: str, metadata_name: str, value: Optional[Any], filters: list ): + if value is None: + return + key = f"{metadata_name}_{sequence}" key_value = f"{metadata_name}_{sequence}_value" match condition: From 6ef1e017df8871654463d680c11bbdc759a245a8 Mon Sep 17 00:00:00 2001 From: Harry Date: Wed, 2 Jul 2025 14:58:44 +0800 Subject: [PATCH 12/15] feat(oauth): add support for retrieving credential info and OAuth client schema --- .../console/workspace/tool_providers.py | 42 +++++++++++++++++-- api/core/tools/builtin_tool/provider.py | 11 +++++ api/core/tools/entities/api_entities.py | 5 +++ .../tools/builtin_tools_manage_service.py | 33 +++++++++++++-- 4 files changed, 85 insertions(+), 6 deletions(-) diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py index 090d5f3cee..e94fcc195f 100644 --- a/api/controllers/console/workspace/tool_providers.py +++ b/api/controllers/console/workspace/tool_providers.py @@ -376,8 +376,10 @@ class ToolBuiltinProviderCredentialsSchemaApi(Resource): user = current_user tenant_id = user.current_tenant_id - return BuiltinToolManageService.list_builtin_provider_credentials_schema( - provider, ToolProviderCredentialType.of(credential_type), tenant_id + return jsonable_encoder( + BuiltinToolManageService.list_builtin_provider_credentials_schema( + provider, ToolProviderCredentialType.of(credential_type), tenant_id + ) ) @@ -795,6 +797,33 @@ class ToolOAuthCustomClient(Resource): ) +class ToolBuiltinProviderGetOauthClientSchemaApi(Resource): + @setup_required + @login_required + @account_initialization_required + def get(self, provider): + return jsonable_encoder( + BuiltinToolManageService.get_builtin_tool_provider_oauth_client_schema( + tenant_id=current_user.current_tenant_id, provider_name=provider + ) + ) + + +class ToolBuiltinProviderGetCredentialInfoApi(Resource): + @setup_required + @login_required + @account_initialization_required + def get(self, provider): + tenant_id = current_user.current_tenant_id + + return jsonable_encoder( + BuiltinToolManageService.get_builtin_tool_provider_credential_info( + tenant_id=tenant_id, + provider=provider, + ) + ) + + # tool oauth api.add_resource(ToolPluginOAuthApi, "/oauth/plugin//tool/authorization-url") api.add_resource(ToolOAuthCallback, "/oauth/plugin//tool/callback") @@ -813,12 +842,19 @@ api.add_resource(ToolBuiltinProviderUpdateApi, "/workspaces/current/tool-provide api.add_resource( ToolBuiltinProviderSetDefaultApi, "/workspaces/current/tool-provider/builtin//default-credential" ) +api.add_resource( + ToolBuiltinProviderGetCredentialInfoApi, "/workspaces/current/tool-provider/builtin//credential/info" +) api.add_resource( ToolBuiltinProviderGetCredentialsApi, "/workspaces/current/tool-provider/builtin//credentials" ) api.add_resource( ToolBuiltinProviderCredentialsSchemaApi, - "/workspaces/current/tool-provider/builtin//credentials_schema/", + "/workspaces/current/tool-provider/builtin//credential/schema/", +) +api.add_resource( + ToolBuiltinProviderGetOauthClientSchemaApi, + "/workspaces/current/tool-provider/builtin//oauth/client-schema", ) api.add_resource(ToolBuiltinProviderIconApi, "/workspaces/current/tool-provider/builtin//icon") diff --git a/api/core/tools/builtin_tool/provider.py b/api/core/tools/builtin_tool/provider.py index 53affe9e97..ce85a37501 100644 --- a/api/core/tools/builtin_tool/provider.py +++ b/api/core/tools/builtin_tool/provider.py @@ -134,6 +134,17 @@ class BuiltinToolProviderController(ToolProviderController): """ return self.entity.oauth_schema.client_schema.copy() if self.entity.oauth_schema else [] + def get_supported_credential_types(self) -> list[str]: + """ + returns the credential support type of the provider + """ + types = [] + if self.entity.credentials_schema is not None: + types.append(ToolProviderCredentialType.API_KEY.value) + if self.entity.oauth_schema is not None: + types.append(ToolProviderCredentialType.OAUTH2.value) + return types + def get_tools(self) -> list[BuiltinTool]: """ returns a list of tools that the provider can provide diff --git a/api/core/tools/entities/api_entities.py b/api/core/tools/entities/api_entities.py index eaadd4d214..ebb503a8b3 100644 --- a/api/core/tools/entities/api_entities.py +++ b/api/core/tools/entities/api_entities.py @@ -81,3 +81,8 @@ class ToolProviderCredentialApiEntity(BaseModel): default=False, description="Whether the credential is the default credential for the provider in the workspace" ) credentials: dict = Field(description="The credentials of the provider") + + +class ToolProviderCredentialInfoApiEntity(BaseModel): + supported_credential_types: list[str] = Field(description="The supported credential types of the provider") + credentials: list[ToolProviderCredentialApiEntity] = Field(description="The credentials of the provider") \ No newline at end of file diff --git a/api/services/tools/builtin_tools_manage_service.py b/api/services/tools/builtin_tools_manage_service.py index 17c1a4b421..2abb234a83 100644 --- a/api/services/tools/builtin_tools_manage_service.py +++ b/api/services/tools/builtin_tools_manage_service.py @@ -9,12 +9,16 @@ from sqlalchemy.orm import Session from configs import dify_config from core.helper.position_helper import is_filtered from core.helper.provider_cache import NoOpProviderCredentialCache, ToolProviderCredentialsCache -from core.model_runtime.utils.encoders import jsonable_encoder from core.plugin.entities.plugin import ToolProviderID from core.plugin.impl.exc import PluginDaemonClientSideError from core.tools.builtin_tool.provider import BuiltinToolProviderController from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort -from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity, ToolProviderCredentialApiEntity +from core.tools.entities.api_entities import ( + ToolApiEntity, + ToolProviderApiEntity, + ToolProviderCredentialApiEntity, + ToolProviderCredentialInfoApiEntity, +) from core.tools.entities.tool_entities import ToolProviderCredentialType from core.tools.errors import ToolNotFoundError, ToolProviderCredentialValidationError, ToolProviderNotFoundError from core.tools.tool_label_manager import ToolLabelManager @@ -31,6 +35,14 @@ logger = logging.getLogger(__name__) class BuiltinToolManageService: __MAX_BUILTIN_TOOL_PROVIDER_COUNT__ = 100 + @staticmethod + def get_builtin_tool_provider_oauth_client_schema(tenant_id: str, provider_name: str): + """ + get builtin tool provider oauth client schema + """ + provider = ToolManager.get_builtin_provider(provider_name, tenant_id) + return provider.get_oauth_client_schema() + @staticmethod def list_builtin_tool_provider_tools(tenant_id: str, provider: str) -> list[ToolApiEntity]: """ @@ -89,7 +101,7 @@ class BuiltinToolManageService: :return: the list of tool providers """ provider = ToolManager.get_builtin_provider(provider_name, tenant_id) - return jsonable_encoder(provider.get_credentials_schema_by_type(credential_type)) + return provider.get_credentials_schema_by_type(credential_type) @staticmethod def update_builtin_tool_provider( @@ -297,6 +309,21 @@ class BuiltinToolManageService: credentials.append(credential_entity) return credentials + @staticmethod + def get_builtin_tool_provider_credential_info(tenant_id: str, provider: str) -> ToolProviderCredentialInfoApiEntity: + """ + get builtin tool provider credential info + """ + provider_controller = ToolManager.get_builtin_provider(provider, tenant_id) + supported_credential_types = provider_controller.get_supported_credential_types() + credentials = BuiltinToolManageService.get_builtin_tool_provider_credentials(tenant_id, provider) + credential_info = ToolProviderCredentialInfoApiEntity( + supported_credential_types=supported_credential_types, + credentials=credentials, + ) + + return credential_info + @staticmethod def delete_builtin_tool_provider(tenant_id: str, provider: str, credential_id: str): """ From 988a76066d67ea92ec1a1e28e3c65be4af4d8986 Mon Sep 17 00:00:00 2001 From: Harry Date: Wed, 2 Jul 2025 20:19:04 +0800 Subject: [PATCH 13/15] feat(oauth): enhance OAuth client handling and add custom client support --- .../console/workspace/tool_providers.py | 37 ++-- api/core/tools/entities/api_entities.py | 5 +- .../tools/builtin_tools_manage_service.py | 166 ++++++++++++++---- 3 files changed, 150 insertions(+), 58 deletions(-) diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py index e94fcc195f..c782a4c37f 100644 --- a/api/controllers/console/workspace/tool_providers.py +++ b/api/controllers/console/workspace/tool_providers.py @@ -675,18 +675,17 @@ class ToolPluginOAuthApi(Resource): raise Forbidden() tenant_id = user.current_tenant_id - plugin_oauth_config = BuiltinToolManageService.get_builtin_tool_oauth_client( + oauth_client_params = BuiltinToolManageService.get_oauth_client( tenant_id=tenant_id, - provider=provider_name, - plugin_id=plugin_id, + provider=provider ) + if oauth_client_params is None: + raise Forbidden("no oauth available client config found for this tool provider") oauth_handler = OAuthHandler() context_id = OAuthProxyService.create_proxy_context( user_id=current_user.id, tenant_id=tenant_id, plugin_id=plugin_id, provider=provider_name ) - # TODO decrypt oauth params - oauth_params = plugin_oauth_config.oauth_params redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider}/tool/callback" authorization_url_response = oauth_handler.get_authorization_url( tenant_id=tenant_id, @@ -694,7 +693,7 @@ class ToolPluginOAuthApi(Resource): plugin_id=plugin_id, provider=provider_name, redirect_uri=redirect_uri, - system_credentials=oauth_params, + system_credentials=oauth_client_params, ) response = make_response(jsonable_encoder(authorization_url_response)) response.set_cookie( @@ -724,12 +723,10 @@ class ToolOAuthCallback(Resource): user_id, tenant_id = context.get("user_id"), context.get("tenant_id") oauth_handler = OAuthHandler() - plugin_oauth_config = BuiltinToolManageService.get_builtin_tool_oauth_client( - tenant_id=tenant_id, - provider=provider_name, - plugin_id=plugin_id, - ) - oauth_params = plugin_oauth_config.oauth_params + oauth_client_params = BuiltinToolManageService.get_oauth_client(tenant_id, provider) + if oauth_client_params is None: + raise Forbidden("no oauth available client config found for this tool provider") + redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider}/tool/callback" credentials = oauth_handler.get_credentials( tenant_id=tenant_id, @@ -737,7 +734,7 @@ class ToolOAuthCallback(Resource): plugin_id=plugin_id, provider=provider_name, redirect_uri=redirect_uri, - system_credentials=oauth_params, + system_credentials=oauth_client_params, request=request, ).credentials @@ -774,7 +771,8 @@ class ToolOAuthCustomClient(Resource): @account_initialization_required def post(self, provider): parser = reqparse.RequestParser() - parser.add_argument("client_params", type=dict, required=True, nullable=False, location="json") + parser.add_argument("client_params", type=dict, required=False, nullable=True, location="json") + parser.add_argument("enable_oauth_custom_client", type=bool, required=False, nullable=True, location="json") args = parser.parse_args() user = current_user @@ -782,18 +780,21 @@ class ToolOAuthCustomClient(Resource): if not user.is_admin_or_owner: raise Forbidden() - return BuiltinToolManageService.setup_oauth_custom_client( + return BuiltinToolManageService.save_custom_oauth_client_params( tenant_id=user.current_tenant_id, provider=provider, - client_params=args["client_params"], + client_params=args.get("client_params", {}), + enable_oauth_custom_client=args.get("enable_oauth_custom_client", True), ) @setup_required @login_required @account_initialization_required def get(self, provider): - return BuiltinToolManageService.get_builtin_tool_provider_credentials( - tenant_id=current_user.current_tenant_id, provider_name=provider + return jsonable_encoder( + BuiltinToolManageService.get_custom_oauth_client_params( + tenant_id=current_user.current_tenant_id, provider=provider + ) ) diff --git a/api/core/tools/entities/api_entities.py b/api/core/tools/entities/api_entities.py index ebb503a8b3..483fbe13d7 100644 --- a/api/core/tools/entities/api_entities.py +++ b/api/core/tools/entities/api_entities.py @@ -85,4 +85,7 @@ class ToolProviderCredentialApiEntity(BaseModel): class ToolProviderCredentialInfoApiEntity(BaseModel): supported_credential_types: list[str] = Field(description="The supported credential types of the provider") - credentials: list[ToolProviderCredentialApiEntity] = Field(description="The credentials of the provider") \ No newline at end of file + is_oauth_custom_client_enabled: bool = Field( + default=False, description="Whether the OAuth custom client is enabled for the provider" + ) + credentials: list[ToolProviderCredentialApiEntity] = Field(description="The credentials of the provider") diff --git a/api/services/tools/builtin_tools_manage_service.py b/api/services/tools/builtin_tools_manage_service.py index 2abb234a83..4058e576f0 100644 --- a/api/services/tools/builtin_tools_manage_service.py +++ b/api/services/tools/builtin_tools_manage_service.py @@ -2,7 +2,7 @@ import json import logging import re from pathlib import Path -from typing import Optional, Union +from typing import Any, Optional from sqlalchemy.orm import Session @@ -21,6 +21,7 @@ from core.tools.entities.api_entities import ( ) from core.tools.entities.tool_entities import ToolProviderCredentialType from core.tools.errors import ToolNotFoundError, ToolProviderCredentialValidationError, ToolProviderNotFoundError +from core.tools.plugin_tool.provider import PluginToolProviderController from core.tools.tool_label_manager import ToolLabelManager from core.tools.tool_manager import ToolManager from core.tools.utils.configuration import create_encrypter @@ -41,7 +42,12 @@ class BuiltinToolManageService: get builtin tool provider oauth client schema """ provider = ToolManager.get_builtin_provider(provider_name, tenant_id) - return provider.get_oauth_client_schema() + return { + "schema": provider.get_oauth_client_schema(), + "is_oauth_custom_client_enabled": BuiltinToolManageService.is_oauth_custom_client_enabled( + tenant_id, provider_name + ), + } @staticmethod def list_builtin_tool_provider_tools(tenant_id: str, provider: str) -> list[ToolApiEntity]: @@ -139,7 +145,7 @@ class BuiltinToolManageService: # encrypt credentials db_provider.encrypted_credentials = json.dumps(encrypter.encrypt(credentials)) - + cache.delete() # update name if provided @@ -279,20 +285,16 @@ class BuiltinToolManageService: """ with db.session.no_autoflush: providers = ( - db.session.query(BuiltinToolProvider).filter_by(tenant_id=tenant_id, provider=provider_name).all() + db.session.query(BuiltinToolProvider) + .filter_by(tenant_id=tenant_id, provider=provider_name) + .order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc()) + .all() ) if len(providers) == 0: return [] - default_provider = sorted( - providers, - key=lambda p: ( - not getattr(p, "is_default", False), - getattr(p, "created_at", None) or 0, - ), - )[0] - + default_provider = providers[0] default_provider.is_default = True provider_controller = ToolManager.get_builtin_provider(default_provider.provider, tenant_id) encrypter, cache = BuiltinToolManageService.create_tool_encrypter( @@ -319,6 +321,7 @@ class BuiltinToolManageService: credentials = BuiltinToolManageService.get_builtin_tool_provider_credentials(tenant_id, provider) credential_info = ToolProviderCredentialInfoApiEntity( supported_credential_types=supported_credential_types, + is_oauth_custom_client_enabled=BuiltinToolManageService.is_oauth_custom_client_enabled(tenant_id, provider), credentials=credentials, ) @@ -368,30 +371,61 @@ class BuiltinToolManageService: return {"result": "success"} @staticmethod - def get_builtin_tool_oauth_client( - tenant_id: str, provider: str, plugin_id: str - ) -> Union[ToolOAuthTenantClient, ToolOAuthSystemClient]: + def is_oauth_custom_client_enabled(tenant_id: str, provider: str) -> bool: + """ + check if oauth custom client is enabled + """ + tool_provider = ToolProviderID(provider) + with Session(db.engine).no_autoflush as session: + user_client: ToolOAuthTenantClient | None = ( + session.query(ToolOAuthTenantClient) + .filter_by( + tenant_id=tenant_id, + provider=tool_provider.provider_name, + plugin_id=tool_provider.plugin_id, + enabled=True, + ) + .first() + ) + return user_client is not None and user_client.enabled + + @staticmethod + def get_oauth_client(tenant_id: str, provider: str) -> dict[str, Any] | None: """ get builtin tool provider """ - with Session(db.engine) as session: - user_client = ( + tool_provider = ToolProviderID(provider) + provider_controller = ToolManager.get_builtin_provider(provider, tenant_id) + encrypter, _ = create_encrypter( + tenant_id=tenant_id, + config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()], + cache=NoOpProviderCredentialCache(), + ) + with Session(db.engine).no_autoflush as session: + user_client: ToolOAuthTenantClient | None = ( session.query(ToolOAuthTenantClient) .filter_by( tenant_id=tenant_id, - provider=provider, - plugin_id=plugin_id, + provider=tool_provider.provider_name, + plugin_id=tool_provider.plugin_id, enabled=True, ) .first() ) + oauth_params: dict[str, Any] | None = None if user_client: - return user_client + oauth_params = encrypter.decrypt(user_client.oauth_params) + return oauth_params + + system_client: ToolOAuthSystemClient | None = ( + session.query(ToolOAuthSystemClient) + .filter_by(plugin_id=tool_provider.plugin_id, provider=tool_provider.provider_name) + .first() + ) + if system_client: + oauth_params = encrypter.decrypt(system_client.oauth_params) - system_client = session.query(ToolOAuthSystemClient).filter_by(provider=provider).first() - if system_client is None: - raise ValueError("no oauth available client config found for this tool provider") - return system_client + return oauth_params @staticmethod def get_builtin_tool_provider_icon(provider: str): @@ -533,12 +567,79 @@ class BuiltinToolManageService: ) @staticmethod - def setup_oauth_custom_client(tenant_id: str, provider: str, client_params: dict): + def save_custom_oauth_client_params( + tenant_id: str, + provider: str, + client_params: Optional[dict] = None, + enable_oauth_custom_client: Optional[bool] = None, + ): """ setup oauth custom client """ + if client_params is None and enable_oauth_custom_client is None: + return {"result": "success"} + + tool_provider = ToolProviderID(provider) + provider_controller = ToolManager.get_builtin_provider(provider, tenant_id) + if not provider_controller: + raise ToolProviderNotFoundError(f"Provider {provider} not found") + + if not isinstance(provider_controller, (BuiltinToolProviderController, PluginToolProviderController)): + raise ValueError(f"Provider {provider} is not a builtin or plugin provider") + + with Session(db.engine) as session: + custom_client_params = ( + session.query(ToolOAuthTenantClient) + .filter_by( + tenant_id=tenant_id, + plugin_id=tool_provider.plugin_id, + provider=tool_provider.provider_name, + ) + .first() + ) + + # if the record does not exist, create a basic record + if custom_client_params is None: + custom_client_params = ToolOAuthTenantClient( + tenant_id=tenant_id, + plugin_id=tool_provider.plugin_id, + provider=tool_provider.provider_name, + ) + session.add(custom_client_params) + + if client_params is not None: + encrypter, _ = create_encrypter( + tenant_id=tenant_id, + config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()], + cache=NoOpProviderCredentialCache(), + ) + custom_client_params.encrypted_oauth_params = json.dumps(encrypter.encrypt(client_params)) + + if enable_oauth_custom_client is not None: + custom_client_params.enabled = enable_oauth_custom_client + + session.commit() + return {"result": "success"} + + @staticmethod + def get_custom_oauth_client_params(tenant_id: str, provider: str): + """ + get custom oauth client params + """ with Session(db.engine) as session: tool_provider = ToolProviderID(provider) + custom_oauth_client_params: ToolOAuthTenantClient | None = ( + session.query(ToolOAuthTenantClient) + .filter_by( + tenant_id=tenant_id, + plugin_id=tool_provider.plugin_id, + provider=tool_provider.provider_name, + ) + .first() + ) + if custom_oauth_client_params is None: + return {} + provider_controller = ToolManager.get_builtin_provider(provider, tenant_id) if not provider_controller: raise ToolProviderNotFoundError(f"Provider {provider} not found") @@ -551,17 +652,4 @@ class BuiltinToolManageService: config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()], cache=NoOpProviderCredentialCache(), ) - - # encrypt credentials - encrypted_credentials = encrypter.encrypt(client_params) - session.add( - ToolOAuthTenantClient( - tenant_id=tenant_id, - plugin_id=tool_provider.plugin_id, - provider=tool_provider.provider_name, - enabled=True, - encrypted_oauth_params=json.dumps(encrypted_credentials), - ) - ) - session.commit() - return {"result": "success"} + return encrypter.mask_tool_credentials(encrypter.decrypt(custom_oauth_client_params.oauth_params)) From 9ce6f34dc4ba75ea8a238f5ebd175d1fe529a125 Mon Sep 17 00:00:00 2001 From: Harry Date: Fri, 4 Jul 2025 14:25:33 +0800 Subject: [PATCH 14/15] feat(oauth): add multi credentials support --- api/core/plugin/impl/tool.py | 4 +- api/core/tools/__base/tool_runtime.py | 3 +- api/core/tools/plugin_tool/tool.py | 1 + api/core/tools/tool_manager.py | 48 +++++++++++++++--------- api/core/workflow/nodes/tool/entities.py | 1 + api/services/app_dsl_service.py | 5 +++ 6 files changed, 42 insertions(+), 20 deletions(-) diff --git a/api/core/plugin/impl/tool.py b/api/core/plugin/impl/tool.py index 19b26c8fe3..f84e8c6c5e 100644 --- a/api/core/plugin/impl/tool.py +++ b/api/core/plugin/impl/tool.py @@ -6,7 +6,7 @@ from pydantic import BaseModel from core.plugin.entities.plugin import GenericProviderID, ToolProviderID from core.plugin.entities.plugin_daemon import PluginBasicBooleanResponse, PluginToolProviderEntity from core.plugin.impl.base import BasePluginClient -from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter +from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolProviderCredentialType class PluginToolManager(BasePluginClient): @@ -78,6 +78,7 @@ class PluginToolManager(BasePluginClient): tool_provider: str, tool_name: str, credentials: dict[str, Any], + credential_type: ToolProviderCredentialType, tool_parameters: dict[str, Any], conversation_id: Optional[str] = None, app_id: Optional[str] = None, @@ -102,6 +103,7 @@ class PluginToolManager(BasePluginClient): "provider": tool_provider_id.provider_name, "tool": tool_name, "credentials": credentials, + "credential_type": credential_type, "tool_parameters": tool_parameters, }, }, diff --git a/api/core/tools/__base/tool_runtime.py b/api/core/tools/__base/tool_runtime.py index c9e157cb77..51e339bed1 100644 --- a/api/core/tools/__base/tool_runtime.py +++ b/api/core/tools/__base/tool_runtime.py @@ -4,7 +4,7 @@ from openai import BaseModel from pydantic import Field from core.app.entities.app_invoke_entities import InvokeFrom -from core.tools.entities.tool_entities import ToolInvokeFrom +from core.tools.entities.tool_entities import ToolInvokeFrom, ToolProviderCredentialType class ToolRuntime(BaseModel): @@ -17,6 +17,7 @@ class ToolRuntime(BaseModel): invoke_from: Optional[InvokeFrom] = None tool_invoke_from: Optional[ToolInvokeFrom] = None credentials: dict[str, Any] = Field(default_factory=dict) + credential_type: Optional[ToolProviderCredentialType] = ToolProviderCredentialType.API_KEY runtime_parameters: dict[str, Any] = Field(default_factory=dict) diff --git a/api/core/tools/plugin_tool/tool.py b/api/core/tools/plugin_tool/tool.py index d21e3d7d1c..aef2677c36 100644 --- a/api/core/tools/plugin_tool/tool.py +++ b/api/core/tools/plugin_tool/tool.py @@ -44,6 +44,7 @@ class PluginTool(Tool): tool_provider=self.entity.identity.provider, tool_name=self.entity.identity.name, credentials=self.runtime.credentials, + credential_type=self.runtime.credential_type, tool_parameters=tool_parameters, conversation_id=conversation_id, app_id=app_id, diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index e9423a6c49..7e37192979 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -4,7 +4,7 @@ import mimetypes from collections.abc import Generator from os import listdir, path from threading import Lock -from typing import TYPE_CHECKING, Any, Union, cast +from typing import TYPE_CHECKING, Any, Optional, Union, cast from yarl import URL @@ -39,6 +39,7 @@ from core.tools.entities.tool_entities import ( ApiProviderAuthType, ToolInvokeFrom, ToolParameter, + ToolProviderCredentialType, ToolProviderType, ) from core.tools.errors import ToolProviderNotFoundError @@ -148,6 +149,7 @@ class ToolManager: tenant_id: str, invoke_from: InvokeFrom = InvokeFrom.DEBUGGER, tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT, + credential_id: Optional[str] = None, ) -> Union[BuiltinTool, PluginTool, ApiTool, WorkflowTool]: """ get the tool runtime @@ -158,6 +160,7 @@ class ToolManager: :param tenant_id: the tenant id :param invoke_from: invoke from :param tool_invoke_from: the tool invoke from + :param credential_id: the credential id :return: the tool """ @@ -185,19 +188,31 @@ class ToolManager: if isinstance(provider_controller, PluginToolProviderController): provider_id_entity = ToolProviderID(provider_id) # get credentials - builtin_provider: BuiltinToolProvider | None = ( - db.session.query(BuiltinToolProvider) - .filter( - BuiltinToolProvider.tenant_id == tenant_id, - (BuiltinToolProvider.provider == str(provider_id_entity)) - | (BuiltinToolProvider.provider == provider_id_entity.provider_name), + if credential_id: + builtin_provider = ( + db.session.query(BuiltinToolProvider) + .filter( + BuiltinToolProvider.tenant_id == tenant_id, + BuiltinToolProvider.id == credential_id, + ) + .first() + ) + if builtin_provider is None: + raise ToolProviderNotFoundError(f"builtin provider {credential_id} not found") + else: + builtin_provider = ( + db.session.query(BuiltinToolProvider) + .filter( + BuiltinToolProvider.tenant_id == tenant_id, + (BuiltinToolProvider.provider == str(provider_id_entity)) + | (BuiltinToolProvider.provider == provider_id_entity.provider_name), + ) + .order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc()) + .first() ) - .order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc()) - .first() - ) - if builtin_provider is None: - raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found") + if builtin_provider is None: + raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found") else: builtin_provider = ( db.session.query(BuiltinToolProvider) @@ -209,8 +224,6 @@ class ToolManager: if builtin_provider is None: raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found") - # decrypt the credentials - credentials = builtin_provider.credentials encrypter, _ = create_encrypter( tenant_id=tenant_id, config=[ @@ -221,15 +234,13 @@ class ToolManager: tenant_id=tenant_id, provider=provider_id, credential_id=builtin_provider.id ), ) - - decrypted_credentials = encrypter.decrypt(credentials) - return cast( BuiltinTool, builtin_tool.fork_tool_runtime( runtime=ToolRuntime( tenant_id=tenant_id, - credentials=decrypted_credentials, + credentials=encrypter.decrypt(builtin_provider.credentials), + credential_type=ToolProviderCredentialType.of(builtin_provider.credential_type), runtime_parameters={}, invoke_from=invoke_from, tool_invoke_from=tool_invoke_from, @@ -362,6 +373,7 @@ class ToolManager: tenant_id=tenant_id, invoke_from=invoke_from, tool_invoke_from=ToolInvokeFrom.WORKFLOW, + credential_id=workflow_tool.credential_id, ) runtime_parameters = {} parameters = tool_runtime.get_merged_runtime_parameters() diff --git a/api/core/workflow/nodes/tool/entities.py b/api/core/workflow/nodes/tool/entities.py index 21023d4ab7..2ce6ac3fc1 100644 --- a/api/core/workflow/nodes/tool/entities.py +++ b/api/core/workflow/nodes/tool/entities.py @@ -14,6 +14,7 @@ class ToolEntity(BaseModel): tool_name: str tool_label: str # redundancy tool_configurations: dict[str, Any] + credential_id: str | None = None plugin_unique_identifier: str | None = None # redundancy @field_validator("tool_configurations", mode="before") diff --git a/api/services/app_dsl_service.py b/api/services/app_dsl_service.py index 20257fa345..f53048a690 100644 --- a/api/services/app_dsl_service.py +++ b/api/services/app_dsl_service.py @@ -582,6 +582,11 @@ class AppDslService: cls.encrypt_dataset_id(dataset_id=dataset_id, tenant_id=app_model.tenant_id) for dataset_id in dataset_ids ] + # filter credential id from tool node + if node.get("data", {}).get("type", "") == NodeType.TOOL.value: + node["data"]["credential_id"] = None + + export_data["workflow"] = workflow_dict dependencies = cls._extract_dependencies_from_workflow(workflow) export_data["dependencies"] = [ From 9b25b7a8d8420bae1cb43dc47a75b2c1e71d31c4 Mon Sep 17 00:00:00 2001 From: Harry Date: Fri, 4 Jul 2025 14:29:17 +0800 Subject: [PATCH 15/15] feat(oauth): rename ToolProviderCredentialType to CredentialType for consistency --- api/controllers/console/workspace/tool_providers.py | 10 +++++----- api/core/plugin/impl/tool.py | 4 ++-- api/core/tools/__base/tool_runtime.py | 4 ++-- api/core/tools/builtin_tool/provider.py | 12 ++++++------ api/core/tools/entities/api_entities.py | 4 ++-- api/core/tools/entities/tool_entities.py | 12 ++++++------ api/core/tools/tool_manager.py | 4 ++-- api/services/tools/builtin_tools_manage_service.py | 10 +++++----- api/services/tools/tools_transform_service.py | 10 +++++----- 9 files changed, 35 insertions(+), 35 deletions(-) diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py index c782a4c37f..f71cf34d4a 100644 --- a/api/controllers/console/workspace/tool_providers.py +++ b/api/controllers/console/workspace/tool_providers.py @@ -19,7 +19,7 @@ from controllers.console.wraps import ( from core.model_runtime.utils.encoders import jsonable_encoder from core.plugin.entities.plugin import ToolProviderID from core.plugin.impl.oauth import OAuthHandler -from core.tools.entities.tool_entities import ToolProviderCredentialType +from core.tools.entities.tool_entities import CredentialType from extensions.ext_database import db from libs.helper import alphanumeric, uuid_value from libs.login import login_required @@ -122,7 +122,7 @@ class ToolBuiltinProviderAddApi(Resource): parser.add_argument("type", type=str, required=True, nullable=False, location="json") args = parser.parse_args() - if args["type"] not in ToolProviderCredentialType.values(): + if args["type"] not in CredentialType.values(): raise ValueError(f"Invalid credential type: {args['type']}") return BuiltinToolManageService.add_builtin_tool_provider( @@ -131,7 +131,7 @@ class ToolBuiltinProviderAddApi(Resource): provider=provider, credentials=args["credentials"], name=args["name"], - api_type=ToolProviderCredentialType.of(args["type"]), + api_type=CredentialType.of(args["type"]), ) @@ -378,7 +378,7 @@ class ToolBuiltinProviderCredentialsSchemaApi(Resource): return jsonable_encoder( BuiltinToolManageService.list_builtin_provider_credentials_schema( - provider, ToolProviderCredentialType.of(credential_type), tenant_id + provider, CredentialType.of(credential_type), tenant_id ) ) @@ -747,7 +747,7 @@ class ToolOAuthCallback(Resource): tenant_id=tenant_id, provider=provider, credentials=dict(credentials), - api_type=ToolProviderCredentialType.OAUTH2, + api_type=CredentialType.OAUTH2, ) return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth/plugin/{provider}/tool/success") diff --git a/api/core/plugin/impl/tool.py b/api/core/plugin/impl/tool.py index f84e8c6c5e..04225f95ee 100644 --- a/api/core/plugin/impl/tool.py +++ b/api/core/plugin/impl/tool.py @@ -6,7 +6,7 @@ from pydantic import BaseModel from core.plugin.entities.plugin import GenericProviderID, ToolProviderID from core.plugin.entities.plugin_daemon import PluginBasicBooleanResponse, PluginToolProviderEntity from core.plugin.impl.base import BasePluginClient -from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolProviderCredentialType +from core.tools.entities.tool_entities import CredentialType, ToolInvokeMessage, ToolParameter class PluginToolManager(BasePluginClient): @@ -78,7 +78,7 @@ class PluginToolManager(BasePluginClient): tool_provider: str, tool_name: str, credentials: dict[str, Any], - credential_type: ToolProviderCredentialType, + credential_type: CredentialType, tool_parameters: dict[str, Any], conversation_id: Optional[str] = None, app_id: Optional[str] = None, diff --git a/api/core/tools/__base/tool_runtime.py b/api/core/tools/__base/tool_runtime.py index 51e339bed1..1068b07062 100644 --- a/api/core/tools/__base/tool_runtime.py +++ b/api/core/tools/__base/tool_runtime.py @@ -4,7 +4,7 @@ from openai import BaseModel from pydantic import Field from core.app.entities.app_invoke_entities import InvokeFrom -from core.tools.entities.tool_entities import ToolInvokeFrom, ToolProviderCredentialType +from core.tools.entities.tool_entities import CredentialType, ToolInvokeFrom class ToolRuntime(BaseModel): @@ -17,7 +17,7 @@ class ToolRuntime(BaseModel): invoke_from: Optional[InvokeFrom] = None tool_invoke_from: Optional[ToolInvokeFrom] = None credentials: dict[str, Any] = Field(default_factory=dict) - credential_type: Optional[ToolProviderCredentialType] = ToolProviderCredentialType.API_KEY + credential_type: Optional[CredentialType] = CredentialType.API_KEY runtime_parameters: dict[str, Any] = Field(default_factory=dict) diff --git a/api/core/tools/builtin_tool/provider.py b/api/core/tools/builtin_tool/provider.py index ce85a37501..f9a03e40ae 100644 --- a/api/core/tools/builtin_tool/provider.py +++ b/api/core/tools/builtin_tool/provider.py @@ -8,9 +8,9 @@ from core.tools.__base.tool_provider import ToolProviderController from core.tools.__base.tool_runtime import ToolRuntime from core.tools.builtin_tool.tool import BuiltinTool from core.tools.entities.tool_entities import ( + CredentialType, OAuthSchema, ToolEntity, - ToolProviderCredentialType, ToolProviderEntity, ToolProviderType, ) @@ -111,7 +111,7 @@ class BuiltinToolProviderController(ToolProviderController): :return: the credentials schema """ - return self.get_credentials_schema_by_type(ToolProviderCredentialType.API_KEY.value) + return self.get_credentials_schema_by_type(CredentialType.API_KEY.value) def get_credentials_schema_by_type(self, credential_type: str) -> list[ProviderConfig]: """ @@ -120,9 +120,9 @@ class BuiltinToolProviderController(ToolProviderController): :param credential_type: the type of the credential :return: the credentials schema of the provider """ - if credential_type == ToolProviderCredentialType.OAUTH2.value: + if credential_type == CredentialType.OAUTH2.value: return self.entity.oauth_schema.credentials_schema.copy() if self.entity.oauth_schema else [] - if credential_type == ToolProviderCredentialType.API_KEY.value: + if credential_type == CredentialType.API_KEY.value: return self.entity.credentials_schema.copy() if self.entity.credentials_schema else [] raise ValueError(f"Invalid credential type: {credential_type}") @@ -140,9 +140,9 @@ class BuiltinToolProviderController(ToolProviderController): """ types = [] if self.entity.credentials_schema is not None: - types.append(ToolProviderCredentialType.API_KEY.value) + types.append(CredentialType.API_KEY.value) if self.entity.oauth_schema is not None: - types.append(ToolProviderCredentialType.OAUTH2.value) + types.append(CredentialType.OAUTH2.value) return types def get_tools(self) -> list[BuiltinTool]: diff --git a/api/core/tools/entities/api_entities.py b/api/core/tools/entities/api_entities.py index 483fbe13d7..687883ce19 100644 --- a/api/core/tools/entities/api_entities.py +++ b/api/core/tools/entities/api_entities.py @@ -5,7 +5,7 @@ from pydantic import BaseModel, Field, field_validator from core.model_runtime.utils.encoders import jsonable_encoder from core.tools.__base.tool import ToolParameter from core.tools.entities.common_entities import I18nObject -from core.tools.entities.tool_entities import ToolProviderCredentialType, ToolProviderType +from core.tools.entities.tool_entities import CredentialType, ToolProviderType class ToolApiEntity(BaseModel): @@ -76,7 +76,7 @@ class ToolProviderCredentialApiEntity(BaseModel): id: str = Field(description="The unique id of the credential") name: str = Field(description="The name of the credential") provider: str = Field(description="The provider of the credential") - credential_type: ToolProviderCredentialType = Field(description="The type of the credential") + credential_type: CredentialType = Field(description="The type of the credential") is_default: bool = Field( default=False, description="Whether the credential is the default credential for the provider in the workspace" ) diff --git a/api/core/tools/entities/tool_entities.py b/api/core/tools/entities/tool_entities.py index f5cb768205..aad2320a25 100644 --- a/api/core/tools/entities/tool_entities.py +++ b/api/core/tools/entities/tool_entities.py @@ -445,30 +445,30 @@ class ToolSelector(BaseModel): return self.model_dump() -class ToolProviderCredentialType(enum.StrEnum): +class CredentialType(enum.StrEnum): API_KEY = "api-key" OAUTH2 = "oauth2" def get_name(self): - if self == ToolProviderCredentialType.API_KEY: + if self == CredentialType.API_KEY: return "API KEY" - elif self == ToolProviderCredentialType.OAUTH2: + elif self == CredentialType.OAUTH2: return "AUTH" else: return self.value.replace("-", " ").upper() def is_editable(self): - return self == ToolProviderCredentialType.API_KEY + return self == CredentialType.API_KEY def is_validate_allowed(self): - return self == ToolProviderCredentialType.API_KEY + return self == CredentialType.API_KEY @classmethod def values(cls): return [item.value for item in cls] @classmethod - def of(cls, credential_type: str) -> "ToolProviderCredentialType": + def of(cls, credential_type: str) -> "CredentialType": type_name = credential_type.lower() if type_name == "api-key": return cls.API_KEY diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index 7e37192979..d9010ce217 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -37,9 +37,9 @@ from core.tools.entities.api_entities import ToolProviderApiEntity, ToolProvider from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ( ApiProviderAuthType, + CredentialType, ToolInvokeFrom, ToolParameter, - ToolProviderCredentialType, ToolProviderType, ) from core.tools.errors import ToolProviderNotFoundError @@ -240,7 +240,7 @@ class ToolManager: runtime=ToolRuntime( tenant_id=tenant_id, credentials=encrypter.decrypt(builtin_provider.credentials), - credential_type=ToolProviderCredentialType.of(builtin_provider.credential_type), + credential_type=CredentialType.of(builtin_provider.credential_type), runtime_parameters={}, invoke_from=invoke_from, tool_invoke_from=tool_invoke_from, diff --git a/api/services/tools/builtin_tools_manage_service.py b/api/services/tools/builtin_tools_manage_service.py index 4058e576f0..469a415ae8 100644 --- a/api/services/tools/builtin_tools_manage_service.py +++ b/api/services/tools/builtin_tools_manage_service.py @@ -19,7 +19,7 @@ from core.tools.entities.api_entities import ( ToolProviderCredentialApiEntity, ToolProviderCredentialInfoApiEntity, ) -from core.tools.entities.tool_entities import ToolProviderCredentialType +from core.tools.entities.tool_entities import CredentialType from core.tools.errors import ToolNotFoundError, ToolProviderCredentialValidationError, ToolProviderNotFoundError from core.tools.plugin_tool.provider import PluginToolProviderController from core.tools.tool_label_manager import ToolLabelManager @@ -96,7 +96,7 @@ class BuiltinToolManageService: @staticmethod def list_builtin_provider_credentials_schema( - provider_name: str, credential_type: ToolProviderCredentialType, tenant_id: str + provider_name: str, credential_type: CredentialType, tenant_id: str ): """ list builtin provider credentials schema @@ -123,7 +123,7 @@ class BuiltinToolManageService: raise ValueError(f"you have not added provider {provider}") try: - if ToolProviderCredentialType.of(db_provider.credential_type).is_editable(): + if CredentialType.of(db_provider.credential_type).is_editable(): provider_controller = ToolManager.get_builtin_provider(provider, tenant_id) if not provider_controller.need_credentials: raise ValueError(f"provider {provider} does not need credentials") @@ -166,7 +166,7 @@ class BuiltinToolManageService: @staticmethod def add_builtin_tool_provider( user_id: str, - api_type: ToolProviderCredentialType, + api_type: CredentialType, tenant_id: str, provider: str, credentials: dict, @@ -237,7 +237,7 @@ class BuiltinToolManageService: @staticmethod def generate_builtin_tool_provider_name( - tenant_id: str, provider: str, credential_type: ToolProviderCredentialType + tenant_id: str, provider: str, credential_type: CredentialType ) -> str: try: db_providers = ( diff --git a/api/services/tools/tools_transform_service.py b/api/services/tools/tools_transform_service.py index 1c3ef3d48c..2d35b769cd 100644 --- a/api/services/tools/tools_transform_service.py +++ b/api/services/tools/tools_transform_service.py @@ -15,8 +15,8 @@ from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_bundle import ApiToolBundle from core.tools.entities.tool_entities import ( ApiProviderAuthType, + CredentialType, ToolParameter, - ToolProviderCredentialType, ToolProviderType, ) from core.tools.plugin_tool.provider import PluginToolProviderController @@ -113,9 +113,9 @@ class ToolTransformService: schema = { x.to_basic_provider_config().name: x for x in provider_controller.get_credentials_schema_by_type( - ToolProviderCredentialType.of(db_provider.credential_type) + CredentialType.of(db_provider.credential_type) if db_provider - else ToolProviderCredentialType.API_KEY + else CredentialType.API_KEY ) } @@ -139,7 +139,7 @@ class ToolTransformService: config=[ x.to_basic_provider_config() for x in provider_controller.get_credentials_schema_by_type( - ToolProviderCredentialType.of(db_provider.credential_type) + CredentialType.of(db_provider.credential_type) ) ], cache=ToolProviderCredentialsCache( @@ -329,7 +329,7 @@ class ToolTransformService: id=provider.id, name=provider.name, provider=provider.provider, - credential_type=ToolProviderCredentialType.of(provider.credential_type), + credential_type=CredentialType.of(provider.credential_type), is_default=provider.is_default, credentials=credentials, )