diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py index e94fcc195f..c782a4c37f 100644 --- a/api/controllers/console/workspace/tool_providers.py +++ b/api/controllers/console/workspace/tool_providers.py @@ -675,18 +675,17 @@ class ToolPluginOAuthApi(Resource): raise Forbidden() tenant_id = user.current_tenant_id - plugin_oauth_config = BuiltinToolManageService.get_builtin_tool_oauth_client( + oauth_client_params = BuiltinToolManageService.get_oauth_client( tenant_id=tenant_id, - provider=provider_name, - plugin_id=plugin_id, + provider=provider ) + if oauth_client_params is None: + raise Forbidden("no oauth available client config found for this tool provider") oauth_handler = OAuthHandler() context_id = OAuthProxyService.create_proxy_context( user_id=current_user.id, tenant_id=tenant_id, plugin_id=plugin_id, provider=provider_name ) - # TODO decrypt oauth params - oauth_params = plugin_oauth_config.oauth_params redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider}/tool/callback" authorization_url_response = oauth_handler.get_authorization_url( tenant_id=tenant_id, @@ -694,7 +693,7 @@ class ToolPluginOAuthApi(Resource): plugin_id=plugin_id, provider=provider_name, redirect_uri=redirect_uri, - system_credentials=oauth_params, + system_credentials=oauth_client_params, ) response = make_response(jsonable_encoder(authorization_url_response)) response.set_cookie( @@ -724,12 +723,10 @@ class ToolOAuthCallback(Resource): user_id, tenant_id = context.get("user_id"), context.get("tenant_id") oauth_handler = OAuthHandler() - plugin_oauth_config = BuiltinToolManageService.get_builtin_tool_oauth_client( - tenant_id=tenant_id, - provider=provider_name, - plugin_id=plugin_id, - ) - oauth_params = plugin_oauth_config.oauth_params + oauth_client_params = BuiltinToolManageService.get_oauth_client(tenant_id, provider) + if oauth_client_params is None: + raise Forbidden("no oauth available client config found for this tool provider") + redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider}/tool/callback" credentials = oauth_handler.get_credentials( tenant_id=tenant_id, @@ -737,7 +734,7 @@ class ToolOAuthCallback(Resource): plugin_id=plugin_id, provider=provider_name, redirect_uri=redirect_uri, - system_credentials=oauth_params, + system_credentials=oauth_client_params, request=request, ).credentials @@ -774,7 +771,8 @@ class ToolOAuthCustomClient(Resource): @account_initialization_required def post(self, provider): parser = reqparse.RequestParser() - parser.add_argument("client_params", type=dict, required=True, nullable=False, location="json") + parser.add_argument("client_params", type=dict, required=False, nullable=True, location="json") + parser.add_argument("enable_oauth_custom_client", type=bool, required=False, nullable=True, location="json") args = parser.parse_args() user = current_user @@ -782,18 +780,21 @@ class ToolOAuthCustomClient(Resource): if not user.is_admin_or_owner: raise Forbidden() - return BuiltinToolManageService.setup_oauth_custom_client( + return BuiltinToolManageService.save_custom_oauth_client_params( tenant_id=user.current_tenant_id, provider=provider, - client_params=args["client_params"], + client_params=args.get("client_params", {}), + enable_oauth_custom_client=args.get("enable_oauth_custom_client", True), ) @setup_required @login_required @account_initialization_required def get(self, provider): - return BuiltinToolManageService.get_builtin_tool_provider_credentials( - tenant_id=current_user.current_tenant_id, provider_name=provider + return jsonable_encoder( + BuiltinToolManageService.get_custom_oauth_client_params( + tenant_id=current_user.current_tenant_id, provider=provider + ) ) diff --git a/api/core/tools/entities/api_entities.py b/api/core/tools/entities/api_entities.py index ebb503a8b3..483fbe13d7 100644 --- a/api/core/tools/entities/api_entities.py +++ b/api/core/tools/entities/api_entities.py @@ -85,4 +85,7 @@ class ToolProviderCredentialApiEntity(BaseModel): class ToolProviderCredentialInfoApiEntity(BaseModel): supported_credential_types: list[str] = Field(description="The supported credential types of the provider") - credentials: list[ToolProviderCredentialApiEntity] = Field(description="The credentials of the provider") \ No newline at end of file + is_oauth_custom_client_enabled: bool = Field( + default=False, description="Whether the OAuth custom client is enabled for the provider" + ) + credentials: list[ToolProviderCredentialApiEntity] = Field(description="The credentials of the provider") diff --git a/api/services/tools/builtin_tools_manage_service.py b/api/services/tools/builtin_tools_manage_service.py index 2abb234a83..4058e576f0 100644 --- a/api/services/tools/builtin_tools_manage_service.py +++ b/api/services/tools/builtin_tools_manage_service.py @@ -2,7 +2,7 @@ import json import logging import re from pathlib import Path -from typing import Optional, Union +from typing import Any, Optional from sqlalchemy.orm import Session @@ -21,6 +21,7 @@ from core.tools.entities.api_entities import ( ) from core.tools.entities.tool_entities import ToolProviderCredentialType from core.tools.errors import ToolNotFoundError, ToolProviderCredentialValidationError, ToolProviderNotFoundError +from core.tools.plugin_tool.provider import PluginToolProviderController from core.tools.tool_label_manager import ToolLabelManager from core.tools.tool_manager import ToolManager from core.tools.utils.configuration import create_encrypter @@ -41,7 +42,12 @@ class BuiltinToolManageService: get builtin tool provider oauth client schema """ provider = ToolManager.get_builtin_provider(provider_name, tenant_id) - return provider.get_oauth_client_schema() + return { + "schema": provider.get_oauth_client_schema(), + "is_oauth_custom_client_enabled": BuiltinToolManageService.is_oauth_custom_client_enabled( + tenant_id, provider_name + ), + } @staticmethod def list_builtin_tool_provider_tools(tenant_id: str, provider: str) -> list[ToolApiEntity]: @@ -139,7 +145,7 @@ class BuiltinToolManageService: # encrypt credentials db_provider.encrypted_credentials = json.dumps(encrypter.encrypt(credentials)) - + cache.delete() # update name if provided @@ -279,20 +285,16 @@ class BuiltinToolManageService: """ with db.session.no_autoflush: providers = ( - db.session.query(BuiltinToolProvider).filter_by(tenant_id=tenant_id, provider=provider_name).all() + db.session.query(BuiltinToolProvider) + .filter_by(tenant_id=tenant_id, provider=provider_name) + .order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc()) + .all() ) if len(providers) == 0: return [] - default_provider = sorted( - providers, - key=lambda p: ( - not getattr(p, "is_default", False), - getattr(p, "created_at", None) or 0, - ), - )[0] - + default_provider = providers[0] default_provider.is_default = True provider_controller = ToolManager.get_builtin_provider(default_provider.provider, tenant_id) encrypter, cache = BuiltinToolManageService.create_tool_encrypter( @@ -319,6 +321,7 @@ class BuiltinToolManageService: credentials = BuiltinToolManageService.get_builtin_tool_provider_credentials(tenant_id, provider) credential_info = ToolProviderCredentialInfoApiEntity( supported_credential_types=supported_credential_types, + is_oauth_custom_client_enabled=BuiltinToolManageService.is_oauth_custom_client_enabled(tenant_id, provider), credentials=credentials, ) @@ -368,30 +371,61 @@ class BuiltinToolManageService: return {"result": "success"} @staticmethod - def get_builtin_tool_oauth_client( - tenant_id: str, provider: str, plugin_id: str - ) -> Union[ToolOAuthTenantClient, ToolOAuthSystemClient]: + def is_oauth_custom_client_enabled(tenant_id: str, provider: str) -> bool: + """ + check if oauth custom client is enabled + """ + tool_provider = ToolProviderID(provider) + with Session(db.engine).no_autoflush as session: + user_client: ToolOAuthTenantClient | None = ( + session.query(ToolOAuthTenantClient) + .filter_by( + tenant_id=tenant_id, + provider=tool_provider.provider_name, + plugin_id=tool_provider.plugin_id, + enabled=True, + ) + .first() + ) + return user_client is not None and user_client.enabled + + @staticmethod + def get_oauth_client(tenant_id: str, provider: str) -> dict[str, Any] | None: """ get builtin tool provider """ - with Session(db.engine) as session: - user_client = ( + tool_provider = ToolProviderID(provider) + provider_controller = ToolManager.get_builtin_provider(provider, tenant_id) + encrypter, _ = create_encrypter( + tenant_id=tenant_id, + config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()], + cache=NoOpProviderCredentialCache(), + ) + with Session(db.engine).no_autoflush as session: + user_client: ToolOAuthTenantClient | None = ( session.query(ToolOAuthTenantClient) .filter_by( tenant_id=tenant_id, - provider=provider, - plugin_id=plugin_id, + provider=tool_provider.provider_name, + plugin_id=tool_provider.plugin_id, enabled=True, ) .first() ) + oauth_params: dict[str, Any] | None = None if user_client: - return user_client + oauth_params = encrypter.decrypt(user_client.oauth_params) + return oauth_params + + system_client: ToolOAuthSystemClient | None = ( + session.query(ToolOAuthSystemClient) + .filter_by(plugin_id=tool_provider.plugin_id, provider=tool_provider.provider_name) + .first() + ) + if system_client: + oauth_params = encrypter.decrypt(system_client.oauth_params) - system_client = session.query(ToolOAuthSystemClient).filter_by(provider=provider).first() - if system_client is None: - raise ValueError("no oauth available client config found for this tool provider") - return system_client + return oauth_params @staticmethod def get_builtin_tool_provider_icon(provider: str): @@ -533,12 +567,79 @@ class BuiltinToolManageService: ) @staticmethod - def setup_oauth_custom_client(tenant_id: str, provider: str, client_params: dict): + def save_custom_oauth_client_params( + tenant_id: str, + provider: str, + client_params: Optional[dict] = None, + enable_oauth_custom_client: Optional[bool] = None, + ): """ setup oauth custom client """ + if client_params is None and enable_oauth_custom_client is None: + return {"result": "success"} + + tool_provider = ToolProviderID(provider) + provider_controller = ToolManager.get_builtin_provider(provider, tenant_id) + if not provider_controller: + raise ToolProviderNotFoundError(f"Provider {provider} not found") + + if not isinstance(provider_controller, (BuiltinToolProviderController, PluginToolProviderController)): + raise ValueError(f"Provider {provider} is not a builtin or plugin provider") + + with Session(db.engine) as session: + custom_client_params = ( + session.query(ToolOAuthTenantClient) + .filter_by( + tenant_id=tenant_id, + plugin_id=tool_provider.plugin_id, + provider=tool_provider.provider_name, + ) + .first() + ) + + # if the record does not exist, create a basic record + if custom_client_params is None: + custom_client_params = ToolOAuthTenantClient( + tenant_id=tenant_id, + plugin_id=tool_provider.plugin_id, + provider=tool_provider.provider_name, + ) + session.add(custom_client_params) + + if client_params is not None: + encrypter, _ = create_encrypter( + tenant_id=tenant_id, + config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()], + cache=NoOpProviderCredentialCache(), + ) + custom_client_params.encrypted_oauth_params = json.dumps(encrypter.encrypt(client_params)) + + if enable_oauth_custom_client is not None: + custom_client_params.enabled = enable_oauth_custom_client + + session.commit() + return {"result": "success"} + + @staticmethod + def get_custom_oauth_client_params(tenant_id: str, provider: str): + """ + get custom oauth client params + """ with Session(db.engine) as session: tool_provider = ToolProviderID(provider) + custom_oauth_client_params: ToolOAuthTenantClient | None = ( + session.query(ToolOAuthTenantClient) + .filter_by( + tenant_id=tenant_id, + plugin_id=tool_provider.plugin_id, + provider=tool_provider.provider_name, + ) + .first() + ) + if custom_oauth_client_params is None: + return {} + provider_controller = ToolManager.get_builtin_provider(provider, tenant_id) if not provider_controller: raise ToolProviderNotFoundError(f"Provider {provider} not found") @@ -551,17 +652,4 @@ class BuiltinToolManageService: config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()], cache=NoOpProviderCredentialCache(), ) - - # encrypt credentials - encrypted_credentials = encrypter.encrypt(client_params) - session.add( - ToolOAuthTenantClient( - tenant_id=tenant_id, - plugin_id=tool_provider.plugin_id, - provider=tool_provider.provider_name, - enabled=True, - encrypted_oauth_params=json.dumps(encrypted_credentials), - ) - ) - session.commit() - return {"result": "success"} + return encrypter.mask_tool_credentials(encrypter.decrypt(custom_oauth_client_params.oauth_params))