From 478c156f7d82ce3da08b70e0bd992a983bf89659 Mon Sep 17 00:00:00 2001 From: Harry Date: Fri, 4 Jul 2025 15:28:36 +0800 Subject: [PATCH] feat(oauth&mcp): refactor credential encrypter --- .../console/workspace/tool_providers.py | 17 +-- .../plugin/backwards_invocation/encrypt.py | 2 +- api/core/tools/tool_manager.py | 16 +- api/core/tools/utils/configuration.py | 138 +----------------- api/core/tools/utils/encryption.py | 134 +++++++++++++++++ api/models/tools.py | 13 +- .../plugin/plugin_parameter_service.py | 2 +- .../tools/api_tools_manage_service.py | 2 +- .../tools/builtin_tools_manage_service.py | 2 +- api/services/tools/mcp_tools_mange_service.py | 2 +- api/services/tools/tools_transform_service.py | 4 +- 11 files changed, 158 insertions(+), 174 deletions(-) create mode 100644 api/core/tools/utils/encryption.py diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py index d10559513f..e6b67b732a 100644 --- a/api/controllers/console/workspace/tool_providers.py +++ b/api/controllers/console/workspace/tool_providers.py @@ -2,7 +2,6 @@ import io from urllib.parse import urlparse from flask import make_response, redirect, request, send_file -from flask import redirect, send_file from flask_login import current_user from flask_restful import ( Resource, @@ -18,11 +17,6 @@ from core.mcp.auth.auth_flow import auth, handle_callback from core.mcp.auth.auth_provider import OAuthClientProvider from core.mcp.error import MCPAuthError from core.mcp.mcp_client import MCPClient -from controllers.console.wraps import ( - account_initialization_required, - enterprise_license_required, - setup_required, -) from core.model_runtime.utils.encoders import jsonable_encoder from core.plugin.entities.plugin import ToolProviderID from core.plugin.impl.oauth import OAuthHandler @@ -97,10 +91,7 @@ class ToolBuiltinProviderInfoApi(Resource): @login_required @account_initialization_required def get(self, provider): - user = current_user - - user_id = user.id - tenant_id = user.current_tenant_id + tenant_id = current_user.current_tenant_id return jsonable_encoder(BuiltinToolManageService.get_builtin_tool_provider_info(tenant_id, provider)) @@ -695,10 +686,7 @@ class ToolPluginOAuthApi(Resource): raise Forbidden() tenant_id = user.current_tenant_id - oauth_client_params = BuiltinToolManageService.get_oauth_client( - tenant_id=tenant_id, - provider=provider - ) + oauth_client_params = BuiltinToolManageService.get_oauth_client(tenant_id=tenant_id, provider=provider) if oauth_client_params is None: raise Forbidden("no oauth available client config found for this tool provider") @@ -851,6 +839,7 @@ api.add_resource(ToolOAuthCallback, "/oauth/plugin//tool/callback api.add_resource(ToolOAuthCustomClient, "/workspaces/current/tool-provider/builtin//oauth/custom-client") + class ToolProviderMCPApi(Resource): @setup_required @login_required diff --git a/api/core/plugin/backwards_invocation/encrypt.py b/api/core/plugin/backwards_invocation/encrypt.py index bfe9ffa4b0..bc9d861111 100644 --- a/api/core/plugin/backwards_invocation/encrypt.py +++ b/api/core/plugin/backwards_invocation/encrypt.py @@ -1,5 +1,5 @@ from core.plugin.entities.request import RequestInvokeEncrypt -from core.tools.utils.configuration import create_generic_encrypter +from core.tools.utils.encryption import create_generic_encrypter from models.account import Tenant diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index 819760d783..c01041a9fc 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -46,14 +46,12 @@ from core.tools.entities.tool_entities import ( ToolParameter, ToolProviderType, ) -from core.tools.errors import ToolNotFoundError, ToolProviderNotFoundError +from core.tools.errors import ToolProviderNotFoundError from core.tools.tool_label_manager import ToolLabelManager from core.tools.utils.configuration import ( - ProviderConfigEncrypter, ToolParameterConfigurationManager, - create_encrypter, - create_generic_encrypter, ) +from core.tools.utils.encryption import create_encrypter, create_generic_encrypter from core.tools.workflow_as_tool.tool import WorkflowTool from extensions.ext_database import db from models.tools import ApiToolProvider, BuiltinToolProvider, MCPToolProvider, WorkflowToolProvider @@ -762,15 +760,15 @@ class ToolManager: ApiProviderAuthType.API_KEY if credentials["auth_type"] == "api_key" else ApiProviderAuthType.NONE, ) # init tool configuration - tool_configuration = create_encrypter( + encrypter, _ = create_encrypter( tenant_id=tenant_id, config=[x.to_basic_provider_config() for x in controller.get_credentials_schema()], - provider_type=controller.provider_type.value, - provider_identity=controller.entity.identity.name, + cache=ToolProviderCredentialsCache( + tenant_id=tenant_id, provider=provider, credential_id=provider_obj.id + ), ) - decrypted_credentials = tool_configuration.decrypt(credentials) - masked_credentials = tool_configuration.mask_tool_credentials(decrypted_credentials) + masked_credentials = encrypter.mask_tool_credentials(encrypter.decrypt(credentials)) try: icon = json.loads(provider_obj.icon) diff --git a/api/core/tools/utils/configuration.py b/api/core/tools/utils/configuration.py index 98fdfcee15..aceba6e69f 100644 --- a/api/core/tools/utils/configuration.py +++ b/api/core/tools/utils/configuration.py @@ -1,9 +1,7 @@ from copy import deepcopy -from typing import Any, Optional, Protocol +from typing import Any -from core.entities.provider_entities import BasicProviderConfig from core.helper import encrypter -from core.helper.provider_cache import GenericProviderCredentialsCache from core.helper.tool_parameter_cache import ToolParameterCache, ToolParameterCacheType from core.tools.__base.tool import Tool from core.tools.entities.tool_entities import ( @@ -12,140 +10,6 @@ from core.tools.entities.tool_entities import ( ) -class ProviderConfigCache(Protocol): - """ - Interface for provider configuration cache operations - """ - - def get(self) -> Optional[dict]: - """Get cached provider configuration""" - ... - - def set(self, config: dict[str, Any]) -> None: - """Cache provider configuration""" - ... - - def delete(self) -> None: - """Delete cached provider configuration""" - ... - - -class ProviderConfigEncrypter: - tenant_id: str - config: list[BasicProviderConfig] - provider_config_cache: ProviderConfigCache - - def __init__( - self, - tenant_id: str, - config: list[BasicProviderConfig], - provider_config_cache: ProviderConfigCache, - ): - self.tenant_id = tenant_id - self.config = config - self.provider_config_cache = provider_config_cache - - def _deep_copy(self, data: dict[str, str]) -> dict[str, str]: - """ - deep copy data - """ - return deepcopy(data) - - def encrypt(self, data: dict[str, str]) -> dict[str, str]: - """ - encrypt tool credentials with tenant id - - return a deep copy of credentials with encrypted values - """ - data = self._deep_copy(data) - - # get fields need to be decrypted - fields = dict[str, BasicProviderConfig]() - for credential in self.config: - fields[credential.name] = credential - - for field_name, field in fields.items(): - if field.type == BasicProviderConfig.Type.SECRET_INPUT: - if field_name in data: - encrypted = encrypter.encrypt_token(self.tenant_id, data[field_name] or "") - data[field_name] = encrypted - - return data - - def mask_tool_credentials(self, data: dict[str, Any]) -> dict[str, Any]: - """ - mask tool credentials - - return a deep copy of credentials with masked values - """ - data = self._deep_copy(data) - - # get fields need to be decrypted - fields = dict[str, BasicProviderConfig]() - for credential in self.config: - fields[credential.name] = credential - - for field_name, field in fields.items(): - if field.type == BasicProviderConfig.Type.SECRET_INPUT: - if field_name in data: - if len(data[field_name]) > 6: - data[field_name] = ( - data[field_name][:2] + "*" * (len(data[field_name]) - 4) + data[field_name][-2:] - ) - else: - data[field_name] = "*" * len(data[field_name]) - - return data - - def decrypt(self, data: dict[str, str], use_cache: bool = True) -> dict[str, Any]: - """ - decrypt tool credentials with tenant id - - return a deep copy of credentials with decrypted values - """ - if use_cache: - cached_credentials = self.provider_config_cache.get() - if cached_credentials: - return cached_credentials - - data = self._deep_copy(data) - # get fields need to be decrypted - fields = dict[str, BasicProviderConfig]() - for credential in self.config: - fields[credential.name] = credential - - for field_name, field in fields.items(): - if field.type == BasicProviderConfig.Type.SECRET_INPUT: - if field_name in data: - try: - # if the value is None or empty string, skip decrypt - if not data[field_name]: - continue - - data[field_name] = encrypter.decrypt_token(self.tenant_id, data[field_name]) - except Exception: - pass - if use_cache: - self.provider_config_cache.set(data) - return data - - -def create_encrypter( - tenant_id: str, config: list[BasicProviderConfig], cache: ProviderConfigCache -): - return ProviderConfigEncrypter( - tenant_id=tenant_id, config=config, provider_config_cache=cache - ), cache - - -def create_generic_encrypter( - tenant_id: str, config: list[BasicProviderConfig], provider_type: str, provider_identity: str -): - cache = GenericProviderCredentialsCache(tenant_id=tenant_id, identity_id=f"{provider_type}.{provider_identity}") - encrypt = ProviderConfigEncrypter(tenant_id=tenant_id, config=config, provider_config_cache=cache) - return encrypt, cache - - class ToolParameterConfigurationManager: """ Tool parameter configuration manager diff --git a/api/core/tools/utils/encryption.py b/api/core/tools/utils/encryption.py new file mode 100644 index 0000000000..dce68d98c8 --- /dev/null +++ b/api/core/tools/utils/encryption.py @@ -0,0 +1,134 @@ +from copy import deepcopy +from typing import Any, Optional, Protocol + +from core.entities.provider_entities import BasicProviderConfig +from core.helper import encrypter +from core.helper.provider_cache import GenericProviderCredentialsCache + + +class ProviderConfigCache(Protocol): + """ + Interface for provider configuration cache operations + """ + + def get(self) -> Optional[dict]: + """Get cached provider configuration""" + ... + + def set(self, config: dict[str, Any]) -> None: + """Cache provider configuration""" + ... + + def delete(self) -> None: + """Delete cached provider configuration""" + ... + + +class ProviderConfigEncrypter: + tenant_id: str + config: list[BasicProviderConfig] + provider_config_cache: ProviderConfigCache + + def __init__( + self, + tenant_id: str, + config: list[BasicProviderConfig], + provider_config_cache: ProviderConfigCache, + ): + self.tenant_id = tenant_id + self.config = config + self.provider_config_cache = provider_config_cache + + def _deep_copy(self, data: dict[str, str]) -> dict[str, str]: + """ + deep copy data + """ + return deepcopy(data) + + def encrypt(self, data: dict[str, str]) -> dict[str, str]: + """ + encrypt tool credentials with tenant id + + return a deep copy of credentials with encrypted values + """ + data = self._deep_copy(data) + + # get fields need to be decrypted + fields = dict[str, BasicProviderConfig]() + for credential in self.config: + fields[credential.name] = credential + + for field_name, field in fields.items(): + if field.type == BasicProviderConfig.Type.SECRET_INPUT: + if field_name in data: + encrypted = encrypter.encrypt_token(self.tenant_id, data[field_name] or "") + data[field_name] = encrypted + + return data + + def mask_tool_credentials(self, data: dict[str, Any]) -> dict[str, Any]: + """ + mask tool credentials + + return a deep copy of credentials with masked values + """ + data = self._deep_copy(data) + + # get fields need to be decrypted + fields = dict[str, BasicProviderConfig]() + for credential in self.config: + fields[credential.name] = credential + + for field_name, field in fields.items(): + if field.type == BasicProviderConfig.Type.SECRET_INPUT: + if field_name in data: + if len(data[field_name]) > 6: + data[field_name] = ( + data[field_name][:2] + "*" * (len(data[field_name]) - 4) + data[field_name][-2:] + ) + else: + data[field_name] = "*" * len(data[field_name]) + + return data + + def decrypt(self, data: dict[str, str]) -> dict[str, Any]: + """ + decrypt tool credentials with tenant id + + return a deep copy of credentials with decrypted values + """ + cached_credentials = self.provider_config_cache.get() + if cached_credentials: + return cached_credentials + + data = self._deep_copy(data) + # get fields need to be decrypted + fields = dict[str, BasicProviderConfig]() + for credential in self.config: + fields[credential.name] = credential + + for field_name, field in fields.items(): + if field.type == BasicProviderConfig.Type.SECRET_INPUT: + if field_name in data: + try: + # if the value is None or empty string, skip decrypt + if not data[field_name]: + continue + + data[field_name] = encrypter.decrypt_token(self.tenant_id, data[field_name]) + except Exception: + pass + self.provider_config_cache.set(data) + return data + + +def create_encrypter(tenant_id: str, config: list[BasicProviderConfig], cache: ProviderConfigCache): + return ProviderConfigEncrypter(tenant_id=tenant_id, config=config, provider_config_cache=cache), cache + + +def create_generic_encrypter( + tenant_id: str, config: list[BasicProviderConfig], provider_type: str, provider_identity: str +): + cache = GenericProviderCredentialsCache(tenant_id=tenant_id, identity_id=f"{provider_type}.{provider_identity}") + encrypt = ProviderConfigEncrypter(tenant_id=tenant_id, config=config, provider_config_cache=cache) + return encrypt, cache diff --git a/api/models/tools.py b/api/models/tools.py index 05a4920a9c..ded05a55b9 100644 --- a/api/models/tools.py +++ b/api/models/tools.py @@ -10,10 +10,12 @@ from sqlalchemy.orm import Mapped, mapped_column from core.file import helpers as file_helpers from core.helper import encrypter +from core.helper.provider_cache import NoOpProviderCredentialCache from core.mcp.types import Tool from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_bundle import ApiToolBundle from core.tools.entities.tool_entities import ApiProviderSchemaType, WorkflowToolParameterConfiguration +from core.tools.utils.encryption import create_encrypter from models.base import Base from .engine import db @@ -327,17 +329,14 @@ class MCPToolProvider(Base): @property def decrypted_credentials(self) -> dict: from core.tools.mcp_tool.provider import MCPToolProviderController - from core.tools.utils.configuration import ProviderConfigEncrypter provider_controller = MCPToolProviderController._from_db(self) - tool_configuration = ProviderConfigEncrypter( + return create_encrypter( tenant_id=self.tenant_id, - config=list(provider_controller.get_credentials_schema()), - provider_type=provider_controller.provider_type.value, - provider_identity=provider_controller.provider_id, - ) - return tool_configuration.decrypt(self.credentials, use_cache=False) + config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()], + cache=NoOpProviderCredentialCache(), + )[0].decrypt(self.credentials) class ToolModelInvoke(Base): diff --git a/api/services/plugin/plugin_parameter_service.py b/api/services/plugin/plugin_parameter_service.py index 393213c0e2..01f1c5de7e 100644 --- a/api/services/plugin/plugin_parameter_service.py +++ b/api/services/plugin/plugin_parameter_service.py @@ -6,7 +6,7 @@ from sqlalchemy.orm import Session from core.plugin.entities.parameters import PluginParameterOption from core.plugin.impl.dynamic_select import DynamicSelectClient from core.tools.tool_manager import ToolManager -from core.tools.utils.configuration import ProviderConfigEncrypter +from core.tools.utils.encryption import ProviderConfigEncrypter from extensions.ext_database import db from models.tools import BuiltinToolProvider diff --git a/api/services/tools/api_tools_manage_service.py b/api/services/tools/api_tools_manage_service.py index ff84b4318b..84e9930633 100644 --- a/api/services/tools/api_tools_manage_service.py +++ b/api/services/tools/api_tools_manage_service.py @@ -18,7 +18,7 @@ from core.tools.entities.tool_entities import ( ) from core.tools.tool_label_manager import ToolLabelManager from core.tools.tool_manager import ToolManager -from core.tools.utils.configuration import ProviderConfigEncrypter, create_generic_encrypter +from core.tools.utils.encryption import ProviderConfigEncrypter, create_generic_encrypter from core.tools.utils.parser import ApiBasedToolSchemaParser from extensions.ext_database import db from models.tools import ApiToolProvider diff --git a/api/services/tools/builtin_tools_manage_service.py b/api/services/tools/builtin_tools_manage_service.py index 469a415ae8..58cff3af82 100644 --- a/api/services/tools/builtin_tools_manage_service.py +++ b/api/services/tools/builtin_tools_manage_service.py @@ -24,7 +24,7 @@ from core.tools.errors import ToolNotFoundError, ToolProviderCredentialValidatio from core.tools.plugin_tool.provider import PluginToolProviderController from core.tools.tool_label_manager import ToolLabelManager from core.tools.tool_manager import ToolManager -from core.tools.utils.configuration import create_encrypter +from core.tools.utils.encryption import create_encrypter from extensions.ext_database import db from extensions.ext_redis import redis_client from models.tools import BuiltinToolProvider, ToolOAuthSystemClient, ToolOAuthTenantClient diff --git a/api/services/tools/mcp_tools_mange_service.py b/api/services/tools/mcp_tools_mange_service.py index b2a88738f6..75046a7312 100644 --- a/api/services/tools/mcp_tools_mange_service.py +++ b/api/services/tools/mcp_tools_mange_service.py @@ -12,7 +12,7 @@ from core.tools.entities.api_entities import ToolProviderApiEntity from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolProviderType from core.tools.mcp_tool.provider import MCPToolProviderController -from core.tools.utils.configuration import ProviderConfigEncrypter +from core.tools.utils.encryption import ProviderConfigEncrypter from extensions.ext_database import db from models.tools import MCPToolProvider from services.tools.tools_transform_service import ToolTransformService diff --git a/api/services/tools/tools_transform_service.py b/api/services/tools/tools_transform_service.py index 7cd4eb0c7e..94df322d0d 100644 --- a/api/services/tools/tools_transform_service.py +++ b/api/services/tools/tools_transform_service.py @@ -5,8 +5,8 @@ from typing import Any, Optional, Union, cast from yarl import URL from configs import dify_config -from core.mcp.types import Tool as MCPTool from core.helper.provider_cache import ToolProviderCredentialsCache +from core.mcp.types import Tool as MCPTool from core.tools.__base.tool import Tool from core.tools.__base.tool_runtime import ToolRuntime from core.tools.builtin_tool.provider import BuiltinToolProviderController @@ -21,7 +21,7 @@ from core.tools.entities.tool_entities import ( ToolProviderType, ) from core.tools.plugin_tool.provider import PluginToolProviderController -from core.tools.utils.configuration import create_encrypter, create_generic_encrypter +from core.tools.utils.encryption import create_encrypter, create_generic_encrypter from core.tools.workflow_as_tool.provider import WorkflowToolProviderController from core.tools.workflow_as_tool.tool import WorkflowTool from models.tools import ApiToolProvider, BuiltinToolProvider, MCPToolProvider, WorkflowToolProvider