From 26b46b88c9acb6ddedd213dee5beaee7d6ddb26b Mon Sep 17 00:00:00 2001 From: Harry Date: Fri, 4 Jul 2025 14:25:33 +0800 Subject: [PATCH 1/4] feat(oauth): add multi credentials support --- api/core/plugin/impl/tool.py | 4 +- api/core/tools/__base/tool_runtime.py | 3 +- api/core/tools/plugin_tool/tool.py | 1 + api/core/tools/tool_manager.py | 48 +++++++++++++++--------- api/core/workflow/nodes/tool/entities.py | 1 + api/services/app_dsl_service.py | 5 +++ 6 files changed, 42 insertions(+), 20 deletions(-) diff --git a/api/core/plugin/impl/tool.py b/api/core/plugin/impl/tool.py index 19b26c8fe3..f84e8c6c5e 100644 --- a/api/core/plugin/impl/tool.py +++ b/api/core/plugin/impl/tool.py @@ -6,7 +6,7 @@ from pydantic import BaseModel from core.plugin.entities.plugin import GenericProviderID, ToolProviderID from core.plugin.entities.plugin_daemon import PluginBasicBooleanResponse, PluginToolProviderEntity from core.plugin.impl.base import BasePluginClient -from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter +from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolProviderCredentialType class PluginToolManager(BasePluginClient): @@ -78,6 +78,7 @@ class PluginToolManager(BasePluginClient): tool_provider: str, tool_name: str, credentials: dict[str, Any], + credential_type: ToolProviderCredentialType, tool_parameters: dict[str, Any], conversation_id: Optional[str] = None, app_id: Optional[str] = None, @@ -102,6 +103,7 @@ class PluginToolManager(BasePluginClient): "provider": tool_provider_id.provider_name, "tool": tool_name, "credentials": credentials, + "credential_type": credential_type, "tool_parameters": tool_parameters, }, }, diff --git a/api/core/tools/__base/tool_runtime.py b/api/core/tools/__base/tool_runtime.py index c9e157cb77..51e339bed1 100644 --- a/api/core/tools/__base/tool_runtime.py +++ b/api/core/tools/__base/tool_runtime.py @@ -4,7 +4,7 @@ from openai import BaseModel from pydantic import Field from core.app.entities.app_invoke_entities import InvokeFrom -from core.tools.entities.tool_entities import ToolInvokeFrom +from core.tools.entities.tool_entities import ToolInvokeFrom, ToolProviderCredentialType class ToolRuntime(BaseModel): @@ -17,6 +17,7 @@ class ToolRuntime(BaseModel): invoke_from: Optional[InvokeFrom] = None tool_invoke_from: Optional[ToolInvokeFrom] = None credentials: dict[str, Any] = Field(default_factory=dict) + credential_type: Optional[ToolProviderCredentialType] = ToolProviderCredentialType.API_KEY runtime_parameters: dict[str, Any] = Field(default_factory=dict) diff --git a/api/core/tools/plugin_tool/tool.py b/api/core/tools/plugin_tool/tool.py index d21e3d7d1c..aef2677c36 100644 --- a/api/core/tools/plugin_tool/tool.py +++ b/api/core/tools/plugin_tool/tool.py @@ -44,6 +44,7 @@ class PluginTool(Tool): tool_provider=self.entity.identity.provider, tool_name=self.entity.identity.name, credentials=self.runtime.credentials, + credential_type=self.runtime.credential_type, tool_parameters=tool_parameters, conversation_id=conversation_id, app_id=app_id, diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index e9423a6c49..7e37192979 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -4,7 +4,7 @@ import mimetypes from collections.abc import Generator from os import listdir, path from threading import Lock -from typing import TYPE_CHECKING, Any, Union, cast +from typing import TYPE_CHECKING, Any, Optional, Union, cast from yarl import URL @@ -39,6 +39,7 @@ from core.tools.entities.tool_entities import ( ApiProviderAuthType, ToolInvokeFrom, ToolParameter, + ToolProviderCredentialType, ToolProviderType, ) from core.tools.errors import ToolProviderNotFoundError @@ -148,6 +149,7 @@ class ToolManager: tenant_id: str, invoke_from: InvokeFrom = InvokeFrom.DEBUGGER, tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT, + credential_id: Optional[str] = None, ) -> Union[BuiltinTool, PluginTool, ApiTool, WorkflowTool]: """ get the tool runtime @@ -158,6 +160,7 @@ class ToolManager: :param tenant_id: the tenant id :param invoke_from: invoke from :param tool_invoke_from: the tool invoke from + :param credential_id: the credential id :return: the tool """ @@ -185,19 +188,31 @@ class ToolManager: if isinstance(provider_controller, PluginToolProviderController): provider_id_entity = ToolProviderID(provider_id) # get credentials - builtin_provider: BuiltinToolProvider | None = ( - db.session.query(BuiltinToolProvider) - .filter( - BuiltinToolProvider.tenant_id == tenant_id, - (BuiltinToolProvider.provider == str(provider_id_entity)) - | (BuiltinToolProvider.provider == provider_id_entity.provider_name), + if credential_id: + builtin_provider = ( + db.session.query(BuiltinToolProvider) + .filter( + BuiltinToolProvider.tenant_id == tenant_id, + BuiltinToolProvider.id == credential_id, + ) + .first() + ) + if builtin_provider is None: + raise ToolProviderNotFoundError(f"builtin provider {credential_id} not found") + else: + builtin_provider = ( + db.session.query(BuiltinToolProvider) + .filter( + BuiltinToolProvider.tenant_id == tenant_id, + (BuiltinToolProvider.provider == str(provider_id_entity)) + | (BuiltinToolProvider.provider == provider_id_entity.provider_name), + ) + .order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc()) + .first() ) - .order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc()) - .first() - ) - if builtin_provider is None: - raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found") + if builtin_provider is None: + raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found") else: builtin_provider = ( db.session.query(BuiltinToolProvider) @@ -209,8 +224,6 @@ class ToolManager: if builtin_provider is None: raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found") - # decrypt the credentials - credentials = builtin_provider.credentials encrypter, _ = create_encrypter( tenant_id=tenant_id, config=[ @@ -221,15 +234,13 @@ class ToolManager: tenant_id=tenant_id, provider=provider_id, credential_id=builtin_provider.id ), ) - - decrypted_credentials = encrypter.decrypt(credentials) - return cast( BuiltinTool, builtin_tool.fork_tool_runtime( runtime=ToolRuntime( tenant_id=tenant_id, - credentials=decrypted_credentials, + credentials=encrypter.decrypt(builtin_provider.credentials), + credential_type=ToolProviderCredentialType.of(builtin_provider.credential_type), runtime_parameters={}, invoke_from=invoke_from, tool_invoke_from=tool_invoke_from, @@ -362,6 +373,7 @@ class ToolManager: tenant_id=tenant_id, invoke_from=invoke_from, tool_invoke_from=ToolInvokeFrom.WORKFLOW, + credential_id=workflow_tool.credential_id, ) runtime_parameters = {} parameters = tool_runtime.get_merged_runtime_parameters() diff --git a/api/core/workflow/nodes/tool/entities.py b/api/core/workflow/nodes/tool/entities.py index 21023d4ab7..2ce6ac3fc1 100644 --- a/api/core/workflow/nodes/tool/entities.py +++ b/api/core/workflow/nodes/tool/entities.py @@ -14,6 +14,7 @@ class ToolEntity(BaseModel): tool_name: str tool_label: str # redundancy tool_configurations: dict[str, Any] + credential_id: str | None = None plugin_unique_identifier: str | None = None # redundancy @field_validator("tool_configurations", mode="before") diff --git a/api/services/app_dsl_service.py b/api/services/app_dsl_service.py index 20257fa345..f53048a690 100644 --- a/api/services/app_dsl_service.py +++ b/api/services/app_dsl_service.py @@ -582,6 +582,11 @@ class AppDslService: cls.encrypt_dataset_id(dataset_id=dataset_id, tenant_id=app_model.tenant_id) for dataset_id in dataset_ids ] + # filter credential id from tool node + if node.get("data", {}).get("type", "") == NodeType.TOOL.value: + node["data"]["credential_id"] = None + + export_data["workflow"] = workflow_dict dependencies = cls._extract_dependencies_from_workflow(workflow) export_data["dependencies"] = [ From 9f053f3bbcf2dd20f59b644cc61b081e1c61c970 Mon Sep 17 00:00:00 2001 From: Harry Date: Fri, 4 Jul 2025 14:29:17 +0800 Subject: [PATCH 2/4] feat(oauth): rename ToolProviderCredentialType to CredentialType for consistency --- api/controllers/console/workspace/tool_providers.py | 10 +++++----- api/core/plugin/impl/tool.py | 4 ++-- api/core/tools/__base/tool_runtime.py | 4 ++-- api/core/tools/builtin_tool/provider.py | 12 ++++++------ api/core/tools/entities/api_entities.py | 4 ++-- api/core/tools/entities/tool_entities.py | 12 ++++++------ api/core/tools/tool_manager.py | 4 ++-- api/services/tools/builtin_tools_manage_service.py | 10 +++++----- api/services/tools/tools_transform_service.py | 10 +++++----- 9 files changed, 35 insertions(+), 35 deletions(-) diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py index c782a4c37f..f71cf34d4a 100644 --- a/api/controllers/console/workspace/tool_providers.py +++ b/api/controllers/console/workspace/tool_providers.py @@ -19,7 +19,7 @@ from controllers.console.wraps import ( from core.model_runtime.utils.encoders import jsonable_encoder from core.plugin.entities.plugin import ToolProviderID from core.plugin.impl.oauth import OAuthHandler -from core.tools.entities.tool_entities import ToolProviderCredentialType +from core.tools.entities.tool_entities import CredentialType from extensions.ext_database import db from libs.helper import alphanumeric, uuid_value from libs.login import login_required @@ -122,7 +122,7 @@ class ToolBuiltinProviderAddApi(Resource): parser.add_argument("type", type=str, required=True, nullable=False, location="json") args = parser.parse_args() - if args["type"] not in ToolProviderCredentialType.values(): + if args["type"] not in CredentialType.values(): raise ValueError(f"Invalid credential type: {args['type']}") return BuiltinToolManageService.add_builtin_tool_provider( @@ -131,7 +131,7 @@ class ToolBuiltinProviderAddApi(Resource): provider=provider, credentials=args["credentials"], name=args["name"], - api_type=ToolProviderCredentialType.of(args["type"]), + api_type=CredentialType.of(args["type"]), ) @@ -378,7 +378,7 @@ class ToolBuiltinProviderCredentialsSchemaApi(Resource): return jsonable_encoder( BuiltinToolManageService.list_builtin_provider_credentials_schema( - provider, ToolProviderCredentialType.of(credential_type), tenant_id + provider, CredentialType.of(credential_type), tenant_id ) ) @@ -747,7 +747,7 @@ class ToolOAuthCallback(Resource): tenant_id=tenant_id, provider=provider, credentials=dict(credentials), - api_type=ToolProviderCredentialType.OAUTH2, + api_type=CredentialType.OAUTH2, ) return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth/plugin/{provider}/tool/success") diff --git a/api/core/plugin/impl/tool.py b/api/core/plugin/impl/tool.py index f84e8c6c5e..04225f95ee 100644 --- a/api/core/plugin/impl/tool.py +++ b/api/core/plugin/impl/tool.py @@ -6,7 +6,7 @@ from pydantic import BaseModel from core.plugin.entities.plugin import GenericProviderID, ToolProviderID from core.plugin.entities.plugin_daemon import PluginBasicBooleanResponse, PluginToolProviderEntity from core.plugin.impl.base import BasePluginClient -from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolProviderCredentialType +from core.tools.entities.tool_entities import CredentialType, ToolInvokeMessage, ToolParameter class PluginToolManager(BasePluginClient): @@ -78,7 +78,7 @@ class PluginToolManager(BasePluginClient): tool_provider: str, tool_name: str, credentials: dict[str, Any], - credential_type: ToolProviderCredentialType, + credential_type: CredentialType, tool_parameters: dict[str, Any], conversation_id: Optional[str] = None, app_id: Optional[str] = None, diff --git a/api/core/tools/__base/tool_runtime.py b/api/core/tools/__base/tool_runtime.py index 51e339bed1..1068b07062 100644 --- a/api/core/tools/__base/tool_runtime.py +++ b/api/core/tools/__base/tool_runtime.py @@ -4,7 +4,7 @@ from openai import BaseModel from pydantic import Field from core.app.entities.app_invoke_entities import InvokeFrom -from core.tools.entities.tool_entities import ToolInvokeFrom, ToolProviderCredentialType +from core.tools.entities.tool_entities import CredentialType, ToolInvokeFrom class ToolRuntime(BaseModel): @@ -17,7 +17,7 @@ class ToolRuntime(BaseModel): invoke_from: Optional[InvokeFrom] = None tool_invoke_from: Optional[ToolInvokeFrom] = None credentials: dict[str, Any] = Field(default_factory=dict) - credential_type: Optional[ToolProviderCredentialType] = ToolProviderCredentialType.API_KEY + credential_type: Optional[CredentialType] = CredentialType.API_KEY runtime_parameters: dict[str, Any] = Field(default_factory=dict) diff --git a/api/core/tools/builtin_tool/provider.py b/api/core/tools/builtin_tool/provider.py index ce85a37501..f9a03e40ae 100644 --- a/api/core/tools/builtin_tool/provider.py +++ b/api/core/tools/builtin_tool/provider.py @@ -8,9 +8,9 @@ from core.tools.__base.tool_provider import ToolProviderController from core.tools.__base.tool_runtime import ToolRuntime from core.tools.builtin_tool.tool import BuiltinTool from core.tools.entities.tool_entities import ( + CredentialType, OAuthSchema, ToolEntity, - ToolProviderCredentialType, ToolProviderEntity, ToolProviderType, ) @@ -111,7 +111,7 @@ class BuiltinToolProviderController(ToolProviderController): :return: the credentials schema """ - return self.get_credentials_schema_by_type(ToolProviderCredentialType.API_KEY.value) + return self.get_credentials_schema_by_type(CredentialType.API_KEY.value) def get_credentials_schema_by_type(self, credential_type: str) -> list[ProviderConfig]: """ @@ -120,9 +120,9 @@ class BuiltinToolProviderController(ToolProviderController): :param credential_type: the type of the credential :return: the credentials schema of the provider """ - if credential_type == ToolProviderCredentialType.OAUTH2.value: + if credential_type == CredentialType.OAUTH2.value: return self.entity.oauth_schema.credentials_schema.copy() if self.entity.oauth_schema else [] - if credential_type == ToolProviderCredentialType.API_KEY.value: + if credential_type == CredentialType.API_KEY.value: return self.entity.credentials_schema.copy() if self.entity.credentials_schema else [] raise ValueError(f"Invalid credential type: {credential_type}") @@ -140,9 +140,9 @@ class BuiltinToolProviderController(ToolProviderController): """ types = [] if self.entity.credentials_schema is not None: - types.append(ToolProviderCredentialType.API_KEY.value) + types.append(CredentialType.API_KEY.value) if self.entity.oauth_schema is not None: - types.append(ToolProviderCredentialType.OAUTH2.value) + types.append(CredentialType.OAUTH2.value) return types def get_tools(self) -> list[BuiltinTool]: diff --git a/api/core/tools/entities/api_entities.py b/api/core/tools/entities/api_entities.py index 483fbe13d7..687883ce19 100644 --- a/api/core/tools/entities/api_entities.py +++ b/api/core/tools/entities/api_entities.py @@ -5,7 +5,7 @@ from pydantic import BaseModel, Field, field_validator from core.model_runtime.utils.encoders import jsonable_encoder from core.tools.__base.tool import ToolParameter from core.tools.entities.common_entities import I18nObject -from core.tools.entities.tool_entities import ToolProviderCredentialType, ToolProviderType +from core.tools.entities.tool_entities import CredentialType, ToolProviderType class ToolApiEntity(BaseModel): @@ -76,7 +76,7 @@ class ToolProviderCredentialApiEntity(BaseModel): id: str = Field(description="The unique id of the credential") name: str = Field(description="The name of the credential") provider: str = Field(description="The provider of the credential") - credential_type: ToolProviderCredentialType = Field(description="The type of the credential") + credential_type: CredentialType = Field(description="The type of the credential") is_default: bool = Field( default=False, description="Whether the credential is the default credential for the provider in the workspace" ) diff --git a/api/core/tools/entities/tool_entities.py b/api/core/tools/entities/tool_entities.py index f5cb768205..aad2320a25 100644 --- a/api/core/tools/entities/tool_entities.py +++ b/api/core/tools/entities/tool_entities.py @@ -445,30 +445,30 @@ class ToolSelector(BaseModel): return self.model_dump() -class ToolProviderCredentialType(enum.StrEnum): +class CredentialType(enum.StrEnum): API_KEY = "api-key" OAUTH2 = "oauth2" def get_name(self): - if self == ToolProviderCredentialType.API_KEY: + if self == CredentialType.API_KEY: return "API KEY" - elif self == ToolProviderCredentialType.OAUTH2: + elif self == CredentialType.OAUTH2: return "AUTH" else: return self.value.replace("-", " ").upper() def is_editable(self): - return self == ToolProviderCredentialType.API_KEY + return self == CredentialType.API_KEY def is_validate_allowed(self): - return self == ToolProviderCredentialType.API_KEY + return self == CredentialType.API_KEY @classmethod def values(cls): return [item.value for item in cls] @classmethod - def of(cls, credential_type: str) -> "ToolProviderCredentialType": + def of(cls, credential_type: str) -> "CredentialType": type_name = credential_type.lower() if type_name == "api-key": return cls.API_KEY diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index 7e37192979..d9010ce217 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -37,9 +37,9 @@ from core.tools.entities.api_entities import ToolProviderApiEntity, ToolProvider from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ( ApiProviderAuthType, + CredentialType, ToolInvokeFrom, ToolParameter, - ToolProviderCredentialType, ToolProviderType, ) from core.tools.errors import ToolProviderNotFoundError @@ -240,7 +240,7 @@ class ToolManager: runtime=ToolRuntime( tenant_id=tenant_id, credentials=encrypter.decrypt(builtin_provider.credentials), - credential_type=ToolProviderCredentialType.of(builtin_provider.credential_type), + credential_type=CredentialType.of(builtin_provider.credential_type), runtime_parameters={}, invoke_from=invoke_from, tool_invoke_from=tool_invoke_from, diff --git a/api/services/tools/builtin_tools_manage_service.py b/api/services/tools/builtin_tools_manage_service.py index 4058e576f0..469a415ae8 100644 --- a/api/services/tools/builtin_tools_manage_service.py +++ b/api/services/tools/builtin_tools_manage_service.py @@ -19,7 +19,7 @@ from core.tools.entities.api_entities import ( ToolProviderCredentialApiEntity, ToolProviderCredentialInfoApiEntity, ) -from core.tools.entities.tool_entities import ToolProviderCredentialType +from core.tools.entities.tool_entities import CredentialType from core.tools.errors import ToolNotFoundError, ToolProviderCredentialValidationError, ToolProviderNotFoundError from core.tools.plugin_tool.provider import PluginToolProviderController from core.tools.tool_label_manager import ToolLabelManager @@ -96,7 +96,7 @@ class BuiltinToolManageService: @staticmethod def list_builtin_provider_credentials_schema( - provider_name: str, credential_type: ToolProviderCredentialType, tenant_id: str + provider_name: str, credential_type: CredentialType, tenant_id: str ): """ list builtin provider credentials schema @@ -123,7 +123,7 @@ class BuiltinToolManageService: raise ValueError(f"you have not added provider {provider}") try: - if ToolProviderCredentialType.of(db_provider.credential_type).is_editable(): + if CredentialType.of(db_provider.credential_type).is_editable(): provider_controller = ToolManager.get_builtin_provider(provider, tenant_id) if not provider_controller.need_credentials: raise ValueError(f"provider {provider} does not need credentials") @@ -166,7 +166,7 @@ class BuiltinToolManageService: @staticmethod def add_builtin_tool_provider( user_id: str, - api_type: ToolProviderCredentialType, + api_type: CredentialType, tenant_id: str, provider: str, credentials: dict, @@ -237,7 +237,7 @@ class BuiltinToolManageService: @staticmethod def generate_builtin_tool_provider_name( - tenant_id: str, provider: str, credential_type: ToolProviderCredentialType + tenant_id: str, provider: str, credential_type: CredentialType ) -> str: try: db_providers = ( diff --git a/api/services/tools/tools_transform_service.py b/api/services/tools/tools_transform_service.py index 1c3ef3d48c..2d35b769cd 100644 --- a/api/services/tools/tools_transform_service.py +++ b/api/services/tools/tools_transform_service.py @@ -15,8 +15,8 @@ from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_bundle import ApiToolBundle from core.tools.entities.tool_entities import ( ApiProviderAuthType, + CredentialType, ToolParameter, - ToolProviderCredentialType, ToolProviderType, ) from core.tools.plugin_tool.provider import PluginToolProviderController @@ -113,9 +113,9 @@ class ToolTransformService: schema = { x.to_basic_provider_config().name: x for x in provider_controller.get_credentials_schema_by_type( - ToolProviderCredentialType.of(db_provider.credential_type) + CredentialType.of(db_provider.credential_type) if db_provider - else ToolProviderCredentialType.API_KEY + else CredentialType.API_KEY ) } @@ -139,7 +139,7 @@ class ToolTransformService: config=[ x.to_basic_provider_config() for x in provider_controller.get_credentials_schema_by_type( - ToolProviderCredentialType.of(db_provider.credential_type) + CredentialType.of(db_provider.credential_type) ) ], cache=ToolProviderCredentialsCache( @@ -329,7 +329,7 @@ class ToolTransformService: id=provider.id, name=provider.name, provider=provider.provider, - credential_type=ToolProviderCredentialType.of(provider.credential_type), + credential_type=CredentialType.of(provider.credential_type), is_default=provider.is_default, credentials=credentials, ) From eaefa1b7e6adf83177466f787c4ee20ec9a6af02 Mon Sep 17 00:00:00 2001 From: Harry Date: Fri, 4 Jul 2025 15:55:23 +0800 Subject: [PATCH 3/4] feat(oauth): refactor encryption --- .../plugin/backwards_invocation/encrypt.py | 2 +- api/core/tools/tool_manager.py | 4 +- api/core/tools/utils/configuration.py | 137 +----------------- api/core/tools/utils/encryption.py | 135 +++++++++++++++++ .../plugin/plugin_parameter_service.py | 2 +- .../tools/api_tools_manage_service.py | 2 +- .../tools/builtin_tools_manage_service.py | 2 +- api/services/tools/tools_transform_service.py | 2 +- 8 files changed, 142 insertions(+), 144 deletions(-) create mode 100644 api/core/tools/utils/encryption.py diff --git a/api/core/plugin/backwards_invocation/encrypt.py b/api/core/plugin/backwards_invocation/encrypt.py index bfe9ffa4b0..bc9d861111 100644 --- a/api/core/plugin/backwards_invocation/encrypt.py +++ b/api/core/plugin/backwards_invocation/encrypt.py @@ -1,5 +1,5 @@ from core.plugin.entities.request import RequestInvokeEncrypt -from core.tools.utils.configuration import create_generic_encrypter +from core.tools.utils.encryption import create_generic_encrypter from models.account import Tenant diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index d9010ce217..5b09ca2651 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -45,11 +45,9 @@ from core.tools.entities.tool_entities import ( from core.tools.errors import ToolProviderNotFoundError from core.tools.tool_label_manager import ToolLabelManager from core.tools.utils.configuration import ( - ProviderConfigEncrypter, ToolParameterConfigurationManager, - create_encrypter, - create_generic_encrypter, ) +from core.tools.utils.encryption import ProviderConfigEncrypter, create_encrypter, create_generic_encrypter from core.tools.workflow_as_tool.tool import WorkflowTool from extensions.ext_database import db from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider diff --git a/api/core/tools/utils/configuration.py b/api/core/tools/utils/configuration.py index 6bd6309205..aceba6e69f 100644 --- a/api/core/tools/utils/configuration.py +++ b/api/core/tools/utils/configuration.py @@ -1,9 +1,7 @@ from copy import deepcopy -from typing import Any, Optional, Protocol +from typing import Any -from core.entities.provider_entities import BasicProviderConfig from core.helper import encrypter -from core.helper.provider_cache import GenericProviderCredentialsCache from core.helper.tool_parameter_cache import ToolParameterCache, ToolParameterCacheType from core.tools.__base.tool import Tool from core.tools.entities.tool_entities import ( @@ -12,139 +10,6 @@ from core.tools.entities.tool_entities import ( ) -class ProviderConfigCache(Protocol): - """ - Interface for provider configuration cache operations - """ - - def get(self) -> Optional[dict]: - """Get cached provider configuration""" - ... - - def set(self, config: dict[str, Any]) -> None: - """Cache provider configuration""" - ... - - def delete(self) -> None: - """Delete cached provider configuration""" - ... - - -class ProviderConfigEncrypter: - tenant_id: str - config: list[BasicProviderConfig] - provider_config_cache: ProviderConfigCache - - def __init__( - self, - tenant_id: str, - config: list[BasicProviderConfig], - provider_config_cache: ProviderConfigCache, - ): - self.tenant_id = tenant_id - self.config = config - self.provider_config_cache = provider_config_cache - - def _deep_copy(self, data: dict[str, str]) -> dict[str, str]: - """ - deep copy data - """ - return deepcopy(data) - - def encrypt(self, data: dict[str, str]) -> dict[str, str]: - """ - encrypt tool credentials with tenant id - - return a deep copy of credentials with encrypted values - """ - data = self._deep_copy(data) - - # get fields need to be decrypted - fields = dict[str, BasicProviderConfig]() - for credential in self.config: - fields[credential.name] = credential - - for field_name, field in fields.items(): - if field.type == BasicProviderConfig.Type.SECRET_INPUT: - if field_name in data: - encrypted = encrypter.encrypt_token(self.tenant_id, data[field_name] or "") - data[field_name] = encrypted - - return data - - def mask_tool_credentials(self, data: dict[str, Any]) -> dict[str, Any]: - """ - mask tool credentials - - return a deep copy of credentials with masked values - """ - data = self._deep_copy(data) - - # get fields need to be decrypted - fields = dict[str, BasicProviderConfig]() - for credential in self.config: - fields[credential.name] = credential - - for field_name, field in fields.items(): - if field.type == BasicProviderConfig.Type.SECRET_INPUT: - if field_name in data: - if len(data[field_name]) > 6: - data[field_name] = ( - data[field_name][:2] + "*" * (len(data[field_name]) - 4) + data[field_name][-2:] - ) - else: - data[field_name] = "*" * len(data[field_name]) - - return data - - def decrypt(self, data: dict[str, str]) -> dict[str, Any]: - """ - decrypt tool credentials with tenant id - - return a deep copy of credentials with decrypted values - """ - cached_credentials = self.provider_config_cache.get() - if cached_credentials: - return cached_credentials - - data = self._deep_copy(data) - # get fields need to be decrypted - fields = dict[str, BasicProviderConfig]() - for credential in self.config: - fields[credential.name] = credential - - for field_name, field in fields.items(): - if field.type == BasicProviderConfig.Type.SECRET_INPUT: - if field_name in data: - try: - # if the value is None or empty string, skip decrypt - if not data[field_name]: - continue - - data[field_name] = encrypter.decrypt_token(self.tenant_id, data[field_name]) - except Exception: - pass - - self.provider_config_cache.set(data) - return data - - -def create_encrypter( - tenant_id: str, config: list[BasicProviderConfig], cache: ProviderConfigCache -): - return ProviderConfigEncrypter( - tenant_id=tenant_id, config=config, provider_config_cache=cache - ), cache - - -def create_generic_encrypter( - tenant_id: str, config: list[BasicProviderConfig], provider_type: str, provider_identity: str -): - cache = GenericProviderCredentialsCache(tenant_id=tenant_id, identity_id=f"{provider_type}.{provider_identity}") - encrypt = ProviderConfigEncrypter(tenant_id=tenant_id, config=config, provider_config_cache=cache) - return encrypt, cache - - class ToolParameterConfigurationManager: """ Tool parameter configuration manager diff --git a/api/core/tools/utils/encryption.py b/api/core/tools/utils/encryption.py new file mode 100644 index 0000000000..4ceb3931ce --- /dev/null +++ b/api/core/tools/utils/encryption.py @@ -0,0 +1,135 @@ +from copy import deepcopy +from typing import Any, Optional, Protocol + +from core.entities.provider_entities import BasicProviderConfig +from core.helper import encrypter +from core.helper.provider_cache import GenericProviderCredentialsCache + + +class ProviderConfigCache(Protocol): + """ + Interface for provider configuration cache operations + """ + + def get(self) -> Optional[dict]: + """Get cached provider configuration""" + ... + + def set(self, config: dict[str, Any]) -> None: + """Cache provider configuration""" + ... + + def delete(self) -> None: + """Delete cached provider configuration""" + ... + + +class ProviderConfigEncrypter: + tenant_id: str + config: list[BasicProviderConfig] + provider_config_cache: ProviderConfigCache + + def __init__( + self, + tenant_id: str, + config: list[BasicProviderConfig], + provider_config_cache: ProviderConfigCache, + ): + self.tenant_id = tenant_id + self.config = config + self.provider_config_cache = provider_config_cache + + def _deep_copy(self, data: dict[str, str]) -> dict[str, str]: + """ + deep copy data + """ + return deepcopy(data) + + def encrypt(self, data: dict[str, str]) -> dict[str, str]: + """ + encrypt tool credentials with tenant id + + return a deep copy of credentials with encrypted values + """ + data = self._deep_copy(data) + + # get fields need to be decrypted + fields = dict[str, BasicProviderConfig]() + for credential in self.config: + fields[credential.name] = credential + + for field_name, field in fields.items(): + if field.type == BasicProviderConfig.Type.SECRET_INPUT: + if field_name in data: + encrypted = encrypter.encrypt_token(self.tenant_id, data[field_name] or "") + data[field_name] = encrypted + + return data + + def mask_tool_credentials(self, data: dict[str, Any]) -> dict[str, Any]: + """ + mask tool credentials + + return a deep copy of credentials with masked values + """ + data = self._deep_copy(data) + + # get fields need to be decrypted + fields = dict[str, BasicProviderConfig]() + for credential in self.config: + fields[credential.name] = credential + + for field_name, field in fields.items(): + if field.type == BasicProviderConfig.Type.SECRET_INPUT: + if field_name in data: + if len(data[field_name]) > 6: + data[field_name] = ( + data[field_name][:2] + "*" * (len(data[field_name]) - 4) + data[field_name][-2:] + ) + else: + data[field_name] = "*" * len(data[field_name]) + + return data + + def decrypt(self, data: dict[str, str]) -> dict[str, Any]: + """ + decrypt tool credentials with tenant id + + return a deep copy of credentials with decrypted values + """ + cached_credentials = self.provider_config_cache.get() + if cached_credentials: + return cached_credentials + + data = self._deep_copy(data) + # get fields need to be decrypted + fields = dict[str, BasicProviderConfig]() + for credential in self.config: + fields[credential.name] = credential + + for field_name, field in fields.items(): + if field.type == BasicProviderConfig.Type.SECRET_INPUT: + if field_name in data: + try: + # if the value is None or empty string, skip decrypt + if not data[field_name]: + continue + + data[field_name] = encrypter.decrypt_token(self.tenant_id, data[field_name]) + except Exception: + pass + + self.provider_config_cache.set(data) + return data + + +def create_generic_encrypter( + tenant_id: str, config: list[BasicProviderConfig], provider_type: str, provider_identity: str +): + cache = GenericProviderCredentialsCache(tenant_id=tenant_id, identity_id=f"{provider_type}.{provider_identity}") + encrypt = ProviderConfigEncrypter(tenant_id=tenant_id, config=config, provider_config_cache=cache) + return encrypt, cache + + +def create_encrypter(tenant_id: str, config: list[BasicProviderConfig], cache: ProviderConfigCache): + return ProviderConfigEncrypter(tenant_id=tenant_id, config=config, provider_config_cache=cache), cache diff --git a/api/services/plugin/plugin_parameter_service.py b/api/services/plugin/plugin_parameter_service.py index 393213c0e2..01f1c5de7e 100644 --- a/api/services/plugin/plugin_parameter_service.py +++ b/api/services/plugin/plugin_parameter_service.py @@ -6,7 +6,7 @@ from sqlalchemy.orm import Session from core.plugin.entities.parameters import PluginParameterOption from core.plugin.impl.dynamic_select import DynamicSelectClient from core.tools.tool_manager import ToolManager -from core.tools.utils.configuration import ProviderConfigEncrypter +from core.tools.utils.encryption import ProviderConfigEncrypter from extensions.ext_database import db from models.tools import BuiltinToolProvider diff --git a/api/services/tools/api_tools_manage_service.py b/api/services/tools/api_tools_manage_service.py index ff84b4318b..84e9930633 100644 --- a/api/services/tools/api_tools_manage_service.py +++ b/api/services/tools/api_tools_manage_service.py @@ -18,7 +18,7 @@ from core.tools.entities.tool_entities import ( ) from core.tools.tool_label_manager import ToolLabelManager from core.tools.tool_manager import ToolManager -from core.tools.utils.configuration import ProviderConfigEncrypter, create_generic_encrypter +from core.tools.utils.encryption import ProviderConfigEncrypter, create_generic_encrypter from core.tools.utils.parser import ApiBasedToolSchemaParser from extensions.ext_database import db from models.tools import ApiToolProvider diff --git a/api/services/tools/builtin_tools_manage_service.py b/api/services/tools/builtin_tools_manage_service.py index 469a415ae8..58cff3af82 100644 --- a/api/services/tools/builtin_tools_manage_service.py +++ b/api/services/tools/builtin_tools_manage_service.py @@ -24,7 +24,7 @@ from core.tools.errors import ToolNotFoundError, ToolProviderCredentialValidatio from core.tools.plugin_tool.provider import PluginToolProviderController from core.tools.tool_label_manager import ToolLabelManager from core.tools.tool_manager import ToolManager -from core.tools.utils.configuration import create_encrypter +from core.tools.utils.encryption import create_encrypter from extensions.ext_database import db from extensions.ext_redis import redis_client from models.tools import BuiltinToolProvider, ToolOAuthSystemClient, ToolOAuthTenantClient diff --git a/api/services/tools/tools_transform_service.py b/api/services/tools/tools_transform_service.py index 2d35b769cd..2dea0875be 100644 --- a/api/services/tools/tools_transform_service.py +++ b/api/services/tools/tools_transform_service.py @@ -20,7 +20,7 @@ from core.tools.entities.tool_entities import ( ToolProviderType, ) from core.tools.plugin_tool.provider import PluginToolProviderController -from core.tools.utils.configuration import create_encrypter, create_generic_encrypter +from core.tools.utils.encryption import create_encrypter, create_generic_encrypter from core.tools.workflow_as_tool.provider import WorkflowToolProviderController from core.tools.workflow_as_tool.tool import WorkflowTool from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider From 0dc5bfb2c7652b2b0c3a4c3ba6a9081ac9158827 Mon Sep 17 00:00:00 2001 From: Harry Date: Fri, 4 Jul 2025 17:08:07 +0800 Subject: [PATCH 4/4] feat(oauth): refactor tool encryption utils --- api/core/helper/provider_cache.py | 21 ++++++++++----- .../plugin/backwards_invocation/encrypt.py | 12 ++++++--- api/core/tools/tool_manager.py | 19 +++++--------- api/core/tools/utils/encryption.py | 26 ++++++++++++------- .../plugin/plugin_parameter_service.py | 10 +++---- .../tools/api_tools_manage_service.py | 24 ++++++----------- .../tools/builtin_tools_manage_service.py | 10 +++---- api/services/tools/tools_transform_service.py | 14 ++++------ 8 files changed, 67 insertions(+), 69 deletions(-) diff --git a/api/core/helper/provider_cache.py b/api/core/helper/provider_cache.py index 3e70ea5341..48ec3be5c8 100644 --- a/api/core/helper/provider_cache.py +++ b/api/core/helper/provider_cache.py @@ -37,16 +37,23 @@ class ProviderCredentialsCache(ABC): redis_client.delete(self.cache_key) -class GenericProviderCredentialsCache(ProviderCredentialsCache): - """Cache for generic provider credentials""" +class SingletonProviderCredentialsCache(ProviderCredentialsCache): + """Cache for tool single provider credentials""" - def __init__(self, tenant_id: str, identity_id: str): - super().__init__(tenant_id=tenant_id, identity_id=identity_id) + def __init__(self, tenant_id: str, provider_type: str, provider_identity: str): + super().__init__( + tenant_id=tenant_id, + provider_type=provider_type, + provider_identity=provider_identity, + ) def _generate_cache_key(self, **kwargs) -> str: tenant_id = kwargs["tenant_id"] - identity_id = kwargs["identity_id"] - return f"generic_provider_credentials:tenant_id:{tenant_id}:id:{identity_id}" + provider_type = kwargs["provider_type"] + identity_name = kwargs["provider_identity"] + identity_id = f"{provider_type}.{identity_name}" + return f"{provider_type}_credentials:tenant_id:{tenant_id}:id:{identity_id}" + class ToolProviderCredentialsCache(ProviderCredentialsCache): """Cache for tool provider credentials""" @@ -58,7 +65,7 @@ class ToolProviderCredentialsCache(ProviderCredentialsCache): tenant_id = kwargs["tenant_id"] provider = kwargs["provider"] credential_id = kwargs["credential_id"] - return f"provider_credentials:tenant_id:{tenant_id}:provider:{provider}:credential_id:{credential_id}" + return f"tool_credentials:tenant_id:{tenant_id}:provider:{provider}:credential_id:{credential_id}" class NoOpProviderCredentialCache: diff --git a/api/core/plugin/backwards_invocation/encrypt.py b/api/core/plugin/backwards_invocation/encrypt.py index bc9d861111..213f5c726a 100644 --- a/api/core/plugin/backwards_invocation/encrypt.py +++ b/api/core/plugin/backwards_invocation/encrypt.py @@ -1,16 +1,20 @@ +from core.helper.provider_cache import SingletonProviderCredentialsCache from core.plugin.entities.request import RequestInvokeEncrypt -from core.tools.utils.encryption import create_generic_encrypter +from core.tools.utils.encryption import create_provider_encrypter from models.account import Tenant class PluginEncrypter: @classmethod def invoke_encrypt(cls, tenant: Tenant, payload: RequestInvokeEncrypt) -> dict: - encrypter, cache = create_generic_encrypter( + encrypter, cache = create_provider_encrypter( tenant_id=tenant.id, config=payload.config, - provider_type=payload.namespace, - provider_identity=payload.identity, + cache=SingletonProviderCredentialsCache( + tenant_id=tenant.id, + provider_type=payload.namespace, + provider_identity=payload.identity, + ), ) if payload.opt == "encrypt": diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index 5b09ca2651..9ed29da4e6 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -47,7 +47,7 @@ from core.tools.tool_label_manager import ToolLabelManager from core.tools.utils.configuration import ( ToolParameterConfigurationManager, ) -from core.tools.utils.encryption import ProviderConfigEncrypter, create_encrypter, create_generic_encrypter +from core.tools.utils.encryption import create_provider_encrypter, create_tool_provider_encrypter from core.tools.workflow_as_tool.tool import WorkflowTool from extensions.ext_database import db from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider @@ -222,7 +222,7 @@ class ToolManager: if builtin_provider is None: raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found") - encrypter, _ = create_encrypter( + encrypter, _ = create_provider_encrypter( tenant_id=tenant_id, config=[ x.to_basic_provider_config() @@ -248,11 +248,9 @@ class ToolManager: elif provider_type == ToolProviderType.API: api_provider, credentials = cls.get_api_provider_controller(tenant_id, provider_id) - encrypter, _ = create_generic_encrypter( + encrypter, _ = create_tool_provider_encrypter( tenant_id=tenant_id, - config=[x.to_basic_provider_config() for x in api_provider.get_credentials_schema()], - provider_type=api_provider.provider_type.value, - provider_identity=api_provider.entity.identity.name, + controller=api_provider, ) return cast( ApiTool, @@ -740,15 +738,12 @@ class ToolManager: ApiProviderAuthType.API_KEY if credentials["auth_type"] == "api_key" else ApiProviderAuthType.NONE, ) # init tool configuration - tool_configuration = ProviderConfigEncrypter.create_cached( + encrypter, _ = create_tool_provider_encrypter( tenant_id=tenant_id, - config=[x.to_basic_provider_config() for x in controller.get_credentials_schema()], - provider_type=controller.provider_type.value, - provider_identity=controller.entity.identity.name, + controller=controller, ) - decrypted_credentials = tool_configuration.decrypt(credentials) - masked_credentials = tool_configuration.mask_tool_credentials(decrypted_credentials) + masked_credentials = encrypter.mask_tool_credentials(encrypter.decrypt(credentials)) try: icon = json.loads(provider_obj.icon) diff --git a/api/core/tools/utils/encryption.py b/api/core/tools/utils/encryption.py index 4ceb3931ce..4aa5412a5e 100644 --- a/api/core/tools/utils/encryption.py +++ b/api/core/tools/utils/encryption.py @@ -3,7 +3,8 @@ from typing import Any, Optional, Protocol from core.entities.provider_entities import BasicProviderConfig from core.helper import encrypter -from core.helper.provider_cache import GenericProviderCredentialsCache +from core.helper.provider_cache import SingletonProviderCredentialsCache +from core.tools.__base.tool_provider import ToolProviderController class ProviderConfigCache(Protocol): @@ -123,13 +124,18 @@ class ProviderConfigEncrypter: return data -def create_generic_encrypter( - tenant_id: str, config: list[BasicProviderConfig], provider_type: str, provider_identity: str -): - cache = GenericProviderCredentialsCache(tenant_id=tenant_id, identity_id=f"{provider_type}.{provider_identity}") - encrypt = ProviderConfigEncrypter(tenant_id=tenant_id, config=config, provider_config_cache=cache) - return encrypt, cache - - -def create_encrypter(tenant_id: str, config: list[BasicProviderConfig], cache: ProviderConfigCache): +def create_provider_encrypter(tenant_id: str, config: list[BasicProviderConfig], cache: ProviderConfigCache): return ProviderConfigEncrypter(tenant_id=tenant_id, config=config, provider_config_cache=cache), cache + +def create_tool_provider_encrypter(tenant_id: str, controller: ToolProviderController): + cache = SingletonProviderCredentialsCache( + tenant_id=tenant_id, + provider_type=controller.provider_type.value, + provider_identity=controller.entity.identity.name, + ) + encrypt = ProviderConfigEncrypter( + tenant_id=tenant_id, + config=[x.to_basic_provider_config() for x in controller.get_credentials_schema()], + provider_config_cache=cache, + ) + return encrypt, cache diff --git a/api/services/plugin/plugin_parameter_service.py b/api/services/plugin/plugin_parameter_service.py index 01f1c5de7e..a1c5639e00 100644 --- a/api/services/plugin/plugin_parameter_service.py +++ b/api/services/plugin/plugin_parameter_service.py @@ -6,7 +6,7 @@ from sqlalchemy.orm import Session from core.plugin.entities.parameters import PluginParameterOption from core.plugin.impl.dynamic_select import DynamicSelectClient from core.tools.tool_manager import ToolManager -from core.tools.utils.encryption import ProviderConfigEncrypter +from core.tools.utils.encryption import create_tool_provider_encrypter from extensions.ext_database import db from models.tools import BuiltinToolProvider @@ -38,11 +38,9 @@ class PluginParameterService: case "tool": provider_controller = ToolManager.get_builtin_provider(provider, tenant_id) # init tool configuration - tool_configuration = ProviderConfigEncrypter( + encrypter, _ = create_tool_provider_encrypter( tenant_id=tenant_id, - config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()], - provider_type=provider_controller.provider_type.value, - provider_identity=provider_controller.entity.identity.name, + controller=provider_controller, ) # check if credentials are required @@ -63,7 +61,7 @@ class PluginParameterService: if db_record is None: raise ValueError(f"Builtin provider {provider} not found when fetching credentials") - credentials = tool_configuration.decrypt(db_record.credentials) + credentials = encrypter.decrypt(db_record.credentials) case _: raise ValueError(f"Invalid provider type: {provider_type}") diff --git a/api/services/tools/api_tools_manage_service.py b/api/services/tools/api_tools_manage_service.py index 84e9930633..80badf2335 100644 --- a/api/services/tools/api_tools_manage_service.py +++ b/api/services/tools/api_tools_manage_service.py @@ -18,7 +18,7 @@ from core.tools.entities.tool_entities import ( ) from core.tools.tool_label_manager import ToolLabelManager from core.tools.tool_manager import ToolManager -from core.tools.utils.encryption import ProviderConfigEncrypter, create_generic_encrypter +from core.tools.utils.encryption import create_tool_provider_encrypter from core.tools.utils.parser import ApiBasedToolSchemaParser from extensions.ext_database import db from models.tools import ApiToolProvider @@ -164,15 +164,11 @@ class ApiToolManageService: provider_controller.load_bundled_tools(tool_bundles) # encrypt credentials - tool_configuration = ProviderConfigEncrypter( + encrypter, _ = create_tool_provider_encrypter( tenant_id=tenant_id, - config=list(provider_controller.get_credentials_schema()), - provider_type=provider_controller.provider_type.value, - provider_identity=provider_controller.entity.identity.name, + controller=provider_controller, ) - - encrypted_credentials = tool_configuration.encrypt(credentials) - db_provider.credentials_str = json.dumps(encrypted_credentials) + db_provider.credentials_str = json.dumps(encrypter.encrypt(credentials)) db.session.add(db_provider) db.session.commit() @@ -297,11 +293,9 @@ class ApiToolManageService: provider_controller.load_bundled_tools(tool_bundles) # get original credentials if exists - encrypter, cache = create_generic_encrypter( + encrypter, cache = create_tool_provider_encrypter( tenant_id=tenant_id, - config=list(provider_controller.get_credentials_schema()), - provider_type=provider_controller.provider_type.value, - provider_identity=provider_controller.entity.identity.name, + controller=provider_controller, ) original_credentials = encrypter.decrypt(provider.credentials) @@ -416,11 +410,9 @@ class ApiToolManageService: # decrypt credentials if db_provider.id: - encrypter, _ = create_generic_encrypter( + encrypter, _ = create_tool_provider_encrypter( tenant_id=tenant_id, - config=list(provider_controller.get_credentials_schema()), - provider_type=provider_controller.provider_type.value, - provider_identity=provider_controller.entity.identity.name, + controller=provider_controller, ) decrypted_credentials = encrypter.decrypt(credentials) # check if the credential has changed, save the original credential diff --git a/api/services/tools/builtin_tools_manage_service.py b/api/services/tools/builtin_tools_manage_service.py index 58cff3af82..8e7b179ea7 100644 --- a/api/services/tools/builtin_tools_manage_service.py +++ b/api/services/tools/builtin_tools_manage_service.py @@ -24,7 +24,7 @@ from core.tools.errors import ToolNotFoundError, ToolProviderCredentialValidatio from core.tools.plugin_tool.provider import PluginToolProviderController from core.tools.tool_label_manager import ToolLabelManager from core.tools.tool_manager import ToolManager -from core.tools.utils.encryption import create_encrypter +from core.tools.utils.encryption import create_provider_encrypter from extensions.ext_database import db from extensions.ext_redis import redis_client from models.tools import BuiltinToolProvider, ToolOAuthSystemClient, ToolOAuthTenantClient @@ -225,7 +225,7 @@ class BuiltinToolManageService: provider: str, provider_controller: BuiltinToolProviderController, ): - encrypter, cache = create_encrypter( + encrypter, cache = create_provider_encrypter( tenant_id=tenant_id, config=[ x.to_basic_provider_config() @@ -396,7 +396,7 @@ class BuiltinToolManageService: """ tool_provider = ToolProviderID(provider) provider_controller = ToolManager.get_builtin_provider(provider, tenant_id) - encrypter, _ = create_encrypter( + encrypter, _ = create_provider_encrypter( tenant_id=tenant_id, config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()], cache=NoOpProviderCredentialCache(), @@ -608,7 +608,7 @@ class BuiltinToolManageService: session.add(custom_client_params) if client_params is not None: - encrypter, _ = create_encrypter( + encrypter, _ = create_provider_encrypter( tenant_id=tenant_id, config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()], cache=NoOpProviderCredentialCache(), @@ -647,7 +647,7 @@ class BuiltinToolManageService: if not isinstance(provider_controller, BuiltinToolProviderController): raise ValueError(f"Provider {provider} is not a builtin or plugin provider") - encrypter, _ = create_encrypter( + encrypter, _ = create_provider_encrypter( tenant_id=tenant_id, config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()], cache=NoOpProviderCredentialCache(), diff --git a/api/services/tools/tools_transform_service.py b/api/services/tools/tools_transform_service.py index 2dea0875be..cafcaecdf0 100644 --- a/api/services/tools/tools_transform_service.py +++ b/api/services/tools/tools_transform_service.py @@ -20,7 +20,7 @@ from core.tools.entities.tool_entities import ( ToolProviderType, ) from core.tools.plugin_tool.provider import PluginToolProviderController -from core.tools.utils.encryption import create_encrypter, create_generic_encrypter +from core.tools.utils.encryption import create_provider_encrypter, create_tool_provider_encrypter from core.tools.workflow_as_tool.provider import WorkflowToolProviderController from core.tools.workflow_as_tool.tool import WorkflowTool from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider @@ -113,9 +113,7 @@ class ToolTransformService: schema = { x.to_basic_provider_config().name: x for x in provider_controller.get_credentials_schema_by_type( - CredentialType.of(db_provider.credential_type) - if db_provider - else CredentialType.API_KEY + CredentialType.of(db_provider.credential_type) if db_provider else CredentialType.API_KEY ) } @@ -134,7 +132,7 @@ class ToolTransformService: credentials = db_provider.credentials # init tool configuration - encrypter, _ = create_encrypter( + encrypter, _ = create_provider_encrypter( tenant_id=db_provider.tenant_id, config=[ x.to_basic_provider_config() @@ -252,11 +250,9 @@ class ToolTransformService: if decrypt_credentials: # init tool configuration - encrypter, _ = create_generic_encrypter( + encrypter, _ = create_tool_provider_encrypter( tenant_id=db_provider.tenant_id, - config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()], - provider_type=provider_controller.provider_type.value, - provider_identity=provider_controller.entity.identity.name, + controller=provider_controller, ) # decrypt the credentials and mask the credentials