diff --git a/api/core/helper/provider_cache.py b/api/core/helper/provider_cache.py index 3e70ea5341..48ec3be5c8 100644 --- a/api/core/helper/provider_cache.py +++ b/api/core/helper/provider_cache.py @@ -37,16 +37,23 @@ class ProviderCredentialsCache(ABC): redis_client.delete(self.cache_key) -class GenericProviderCredentialsCache(ProviderCredentialsCache): - """Cache for generic provider credentials""" +class SingletonProviderCredentialsCache(ProviderCredentialsCache): + """Cache for tool single provider credentials""" - def __init__(self, tenant_id: str, identity_id: str): - super().__init__(tenant_id=tenant_id, identity_id=identity_id) + def __init__(self, tenant_id: str, provider_type: str, provider_identity: str): + super().__init__( + tenant_id=tenant_id, + provider_type=provider_type, + provider_identity=provider_identity, + ) def _generate_cache_key(self, **kwargs) -> str: tenant_id = kwargs["tenant_id"] - identity_id = kwargs["identity_id"] - return f"generic_provider_credentials:tenant_id:{tenant_id}:id:{identity_id}" + provider_type = kwargs["provider_type"] + identity_name = kwargs["provider_identity"] + identity_id = f"{provider_type}.{identity_name}" + return f"{provider_type}_credentials:tenant_id:{tenant_id}:id:{identity_id}" + class ToolProviderCredentialsCache(ProviderCredentialsCache): """Cache for tool provider credentials""" @@ -58,7 +65,7 @@ class ToolProviderCredentialsCache(ProviderCredentialsCache): tenant_id = kwargs["tenant_id"] provider = kwargs["provider"] credential_id = kwargs["credential_id"] - return f"provider_credentials:tenant_id:{tenant_id}:provider:{provider}:credential_id:{credential_id}" + return f"tool_credentials:tenant_id:{tenant_id}:provider:{provider}:credential_id:{credential_id}" class NoOpProviderCredentialCache: diff --git a/api/core/plugin/backwards_invocation/encrypt.py b/api/core/plugin/backwards_invocation/encrypt.py index bc9d861111..213f5c726a 100644 --- a/api/core/plugin/backwards_invocation/encrypt.py +++ b/api/core/plugin/backwards_invocation/encrypt.py @@ -1,16 +1,20 @@ +from core.helper.provider_cache import SingletonProviderCredentialsCache from core.plugin.entities.request import RequestInvokeEncrypt -from core.tools.utils.encryption import create_generic_encrypter +from core.tools.utils.encryption import create_provider_encrypter from models.account import Tenant class PluginEncrypter: @classmethod def invoke_encrypt(cls, tenant: Tenant, payload: RequestInvokeEncrypt) -> dict: - encrypter, cache = create_generic_encrypter( + encrypter, cache = create_provider_encrypter( tenant_id=tenant.id, config=payload.config, - provider_type=payload.namespace, - provider_identity=payload.identity, + cache=SingletonProviderCredentialsCache( + tenant_id=tenant.id, + provider_type=payload.namespace, + provider_identity=payload.identity, + ), ) if payload.opt == "encrypt": diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index 5b09ca2651..9ed29da4e6 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -47,7 +47,7 @@ from core.tools.tool_label_manager import ToolLabelManager from core.tools.utils.configuration import ( ToolParameterConfigurationManager, ) -from core.tools.utils.encryption import ProviderConfigEncrypter, create_encrypter, create_generic_encrypter +from core.tools.utils.encryption import create_provider_encrypter, create_tool_provider_encrypter from core.tools.workflow_as_tool.tool import WorkflowTool from extensions.ext_database import db from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider @@ -222,7 +222,7 @@ class ToolManager: if builtin_provider is None: raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found") - encrypter, _ = create_encrypter( + encrypter, _ = create_provider_encrypter( tenant_id=tenant_id, config=[ x.to_basic_provider_config() @@ -248,11 +248,9 @@ class ToolManager: elif provider_type == ToolProviderType.API: api_provider, credentials = cls.get_api_provider_controller(tenant_id, provider_id) - encrypter, _ = create_generic_encrypter( + encrypter, _ = create_tool_provider_encrypter( tenant_id=tenant_id, - config=[x.to_basic_provider_config() for x in api_provider.get_credentials_schema()], - provider_type=api_provider.provider_type.value, - provider_identity=api_provider.entity.identity.name, + controller=api_provider, ) return cast( ApiTool, @@ -740,15 +738,12 @@ class ToolManager: ApiProviderAuthType.API_KEY if credentials["auth_type"] == "api_key" else ApiProviderAuthType.NONE, ) # init tool configuration - tool_configuration = ProviderConfigEncrypter.create_cached( + encrypter, _ = create_tool_provider_encrypter( tenant_id=tenant_id, - config=[x.to_basic_provider_config() for x in controller.get_credentials_schema()], - provider_type=controller.provider_type.value, - provider_identity=controller.entity.identity.name, + controller=controller, ) - decrypted_credentials = tool_configuration.decrypt(credentials) - masked_credentials = tool_configuration.mask_tool_credentials(decrypted_credentials) + masked_credentials = encrypter.mask_tool_credentials(encrypter.decrypt(credentials)) try: icon = json.loads(provider_obj.icon) diff --git a/api/core/tools/utils/encryption.py b/api/core/tools/utils/encryption.py index 4ceb3931ce..4aa5412a5e 100644 --- a/api/core/tools/utils/encryption.py +++ b/api/core/tools/utils/encryption.py @@ -3,7 +3,8 @@ from typing import Any, Optional, Protocol from core.entities.provider_entities import BasicProviderConfig from core.helper import encrypter -from core.helper.provider_cache import GenericProviderCredentialsCache +from core.helper.provider_cache import SingletonProviderCredentialsCache +from core.tools.__base.tool_provider import ToolProviderController class ProviderConfigCache(Protocol): @@ -123,13 +124,18 @@ class ProviderConfigEncrypter: return data -def create_generic_encrypter( - tenant_id: str, config: list[BasicProviderConfig], provider_type: str, provider_identity: str -): - cache = GenericProviderCredentialsCache(tenant_id=tenant_id, identity_id=f"{provider_type}.{provider_identity}") - encrypt = ProviderConfigEncrypter(tenant_id=tenant_id, config=config, provider_config_cache=cache) - return encrypt, cache - - -def create_encrypter(tenant_id: str, config: list[BasicProviderConfig], cache: ProviderConfigCache): +def create_provider_encrypter(tenant_id: str, config: list[BasicProviderConfig], cache: ProviderConfigCache): return ProviderConfigEncrypter(tenant_id=tenant_id, config=config, provider_config_cache=cache), cache + +def create_tool_provider_encrypter(tenant_id: str, controller: ToolProviderController): + cache = SingletonProviderCredentialsCache( + tenant_id=tenant_id, + provider_type=controller.provider_type.value, + provider_identity=controller.entity.identity.name, + ) + encrypt = ProviderConfigEncrypter( + tenant_id=tenant_id, + config=[x.to_basic_provider_config() for x in controller.get_credentials_schema()], + provider_config_cache=cache, + ) + return encrypt, cache diff --git a/api/services/plugin/plugin_parameter_service.py b/api/services/plugin/plugin_parameter_service.py index 01f1c5de7e..a1c5639e00 100644 --- a/api/services/plugin/plugin_parameter_service.py +++ b/api/services/plugin/plugin_parameter_service.py @@ -6,7 +6,7 @@ from sqlalchemy.orm import Session from core.plugin.entities.parameters import PluginParameterOption from core.plugin.impl.dynamic_select import DynamicSelectClient from core.tools.tool_manager import ToolManager -from core.tools.utils.encryption import ProviderConfigEncrypter +from core.tools.utils.encryption import create_tool_provider_encrypter from extensions.ext_database import db from models.tools import BuiltinToolProvider @@ -38,11 +38,9 @@ class PluginParameterService: case "tool": provider_controller = ToolManager.get_builtin_provider(provider, tenant_id) # init tool configuration - tool_configuration = ProviderConfigEncrypter( + encrypter, _ = create_tool_provider_encrypter( tenant_id=tenant_id, - config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()], - provider_type=provider_controller.provider_type.value, - provider_identity=provider_controller.entity.identity.name, + controller=provider_controller, ) # check if credentials are required @@ -63,7 +61,7 @@ class PluginParameterService: if db_record is None: raise ValueError(f"Builtin provider {provider} not found when fetching credentials") - credentials = tool_configuration.decrypt(db_record.credentials) + credentials = encrypter.decrypt(db_record.credentials) case _: raise ValueError(f"Invalid provider type: {provider_type}") diff --git a/api/services/tools/api_tools_manage_service.py b/api/services/tools/api_tools_manage_service.py index 84e9930633..80badf2335 100644 --- a/api/services/tools/api_tools_manage_service.py +++ b/api/services/tools/api_tools_manage_service.py @@ -18,7 +18,7 @@ from core.tools.entities.tool_entities import ( ) from core.tools.tool_label_manager import ToolLabelManager from core.tools.tool_manager import ToolManager -from core.tools.utils.encryption import ProviderConfigEncrypter, create_generic_encrypter +from core.tools.utils.encryption import create_tool_provider_encrypter from core.tools.utils.parser import ApiBasedToolSchemaParser from extensions.ext_database import db from models.tools import ApiToolProvider @@ -164,15 +164,11 @@ class ApiToolManageService: provider_controller.load_bundled_tools(tool_bundles) # encrypt credentials - tool_configuration = ProviderConfigEncrypter( + encrypter, _ = create_tool_provider_encrypter( tenant_id=tenant_id, - config=list(provider_controller.get_credentials_schema()), - provider_type=provider_controller.provider_type.value, - provider_identity=provider_controller.entity.identity.name, + controller=provider_controller, ) - - encrypted_credentials = tool_configuration.encrypt(credentials) - db_provider.credentials_str = json.dumps(encrypted_credentials) + db_provider.credentials_str = json.dumps(encrypter.encrypt(credentials)) db.session.add(db_provider) db.session.commit() @@ -297,11 +293,9 @@ class ApiToolManageService: provider_controller.load_bundled_tools(tool_bundles) # get original credentials if exists - encrypter, cache = create_generic_encrypter( + encrypter, cache = create_tool_provider_encrypter( tenant_id=tenant_id, - config=list(provider_controller.get_credentials_schema()), - provider_type=provider_controller.provider_type.value, - provider_identity=provider_controller.entity.identity.name, + controller=provider_controller, ) original_credentials = encrypter.decrypt(provider.credentials) @@ -416,11 +410,9 @@ class ApiToolManageService: # decrypt credentials if db_provider.id: - encrypter, _ = create_generic_encrypter( + encrypter, _ = create_tool_provider_encrypter( tenant_id=tenant_id, - config=list(provider_controller.get_credentials_schema()), - provider_type=provider_controller.provider_type.value, - provider_identity=provider_controller.entity.identity.name, + controller=provider_controller, ) decrypted_credentials = encrypter.decrypt(credentials) # check if the credential has changed, save the original credential diff --git a/api/services/tools/builtin_tools_manage_service.py b/api/services/tools/builtin_tools_manage_service.py index 58cff3af82..8e7b179ea7 100644 --- a/api/services/tools/builtin_tools_manage_service.py +++ b/api/services/tools/builtin_tools_manage_service.py @@ -24,7 +24,7 @@ from core.tools.errors import ToolNotFoundError, ToolProviderCredentialValidatio from core.tools.plugin_tool.provider import PluginToolProviderController from core.tools.tool_label_manager import ToolLabelManager from core.tools.tool_manager import ToolManager -from core.tools.utils.encryption import create_encrypter +from core.tools.utils.encryption import create_provider_encrypter from extensions.ext_database import db from extensions.ext_redis import redis_client from models.tools import BuiltinToolProvider, ToolOAuthSystemClient, ToolOAuthTenantClient @@ -225,7 +225,7 @@ class BuiltinToolManageService: provider: str, provider_controller: BuiltinToolProviderController, ): - encrypter, cache = create_encrypter( + encrypter, cache = create_provider_encrypter( tenant_id=tenant_id, config=[ x.to_basic_provider_config() @@ -396,7 +396,7 @@ class BuiltinToolManageService: """ tool_provider = ToolProviderID(provider) provider_controller = ToolManager.get_builtin_provider(provider, tenant_id) - encrypter, _ = create_encrypter( + encrypter, _ = create_provider_encrypter( tenant_id=tenant_id, config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()], cache=NoOpProviderCredentialCache(), @@ -608,7 +608,7 @@ class BuiltinToolManageService: session.add(custom_client_params) if client_params is not None: - encrypter, _ = create_encrypter( + encrypter, _ = create_provider_encrypter( tenant_id=tenant_id, config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()], cache=NoOpProviderCredentialCache(), @@ -647,7 +647,7 @@ class BuiltinToolManageService: if not isinstance(provider_controller, BuiltinToolProviderController): raise ValueError(f"Provider {provider} is not a builtin or plugin provider") - encrypter, _ = create_encrypter( + encrypter, _ = create_provider_encrypter( tenant_id=tenant_id, config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()], cache=NoOpProviderCredentialCache(), diff --git a/api/services/tools/tools_transform_service.py b/api/services/tools/tools_transform_service.py index 2dea0875be..cafcaecdf0 100644 --- a/api/services/tools/tools_transform_service.py +++ b/api/services/tools/tools_transform_service.py @@ -20,7 +20,7 @@ from core.tools.entities.tool_entities import ( ToolProviderType, ) from core.tools.plugin_tool.provider import PluginToolProviderController -from core.tools.utils.encryption import create_encrypter, create_generic_encrypter +from core.tools.utils.encryption import create_provider_encrypter, create_tool_provider_encrypter from core.tools.workflow_as_tool.provider import WorkflowToolProviderController from core.tools.workflow_as_tool.tool import WorkflowTool from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider @@ -113,9 +113,7 @@ class ToolTransformService: schema = { x.to_basic_provider_config().name: x for x in provider_controller.get_credentials_schema_by_type( - CredentialType.of(db_provider.credential_type) - if db_provider - else CredentialType.API_KEY + CredentialType.of(db_provider.credential_type) if db_provider else CredentialType.API_KEY ) } @@ -134,7 +132,7 @@ class ToolTransformService: credentials = db_provider.credentials # init tool configuration - encrypter, _ = create_encrypter( + encrypter, _ = create_provider_encrypter( tenant_id=db_provider.tenant_id, config=[ x.to_basic_provider_config() @@ -252,11 +250,9 @@ class ToolTransformService: if decrypt_credentials: # init tool configuration - encrypter, _ = create_generic_encrypter( + encrypter, _ = create_tool_provider_encrypter( tenant_id=db_provider.tenant_id, - config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()], - provider_type=provider_controller.provider_type.value, - provider_identity=provider_controller.entity.identity.name, + controller=provider_controller, ) # decrypt the credentials and mask the credentials