feat(oauth): refactor tool encryption utils

pull/22036/head
Harry 11 months ago
parent eaefa1b7e6
commit 0dc5bfb2c7

@ -37,16 +37,23 @@ class ProviderCredentialsCache(ABC):
redis_client.delete(self.cache_key) redis_client.delete(self.cache_key)
class GenericProviderCredentialsCache(ProviderCredentialsCache): class SingletonProviderCredentialsCache(ProviderCredentialsCache):
"""Cache for generic provider credentials""" """Cache for tool single provider credentials"""
def __init__(self, tenant_id: str, identity_id: str): def __init__(self, tenant_id: str, provider_type: str, provider_identity: str):
super().__init__(tenant_id=tenant_id, identity_id=identity_id) super().__init__(
tenant_id=tenant_id,
provider_type=provider_type,
provider_identity=provider_identity,
)
def _generate_cache_key(self, **kwargs) -> str: def _generate_cache_key(self, **kwargs) -> str:
tenant_id = kwargs["tenant_id"] tenant_id = kwargs["tenant_id"]
identity_id = kwargs["identity_id"] provider_type = kwargs["provider_type"]
return f"generic_provider_credentials:tenant_id:{tenant_id}:id:{identity_id}" identity_name = kwargs["provider_identity"]
identity_id = f"{provider_type}.{identity_name}"
return f"{provider_type}_credentials:tenant_id:{tenant_id}:id:{identity_id}"
class ToolProviderCredentialsCache(ProviderCredentialsCache): class ToolProviderCredentialsCache(ProviderCredentialsCache):
"""Cache for tool provider credentials""" """Cache for tool provider credentials"""
@ -58,7 +65,7 @@ class ToolProviderCredentialsCache(ProviderCredentialsCache):
tenant_id = kwargs["tenant_id"] tenant_id = kwargs["tenant_id"]
provider = kwargs["provider"] provider = kwargs["provider"]
credential_id = kwargs["credential_id"] credential_id = kwargs["credential_id"]
return f"provider_credentials:tenant_id:{tenant_id}:provider:{provider}:credential_id:{credential_id}" return f"tool_credentials:tenant_id:{tenant_id}:provider:{provider}:credential_id:{credential_id}"
class NoOpProviderCredentialCache: class NoOpProviderCredentialCache:

@ -1,16 +1,20 @@
from core.helper.provider_cache import SingletonProviderCredentialsCache
from core.plugin.entities.request import RequestInvokeEncrypt from core.plugin.entities.request import RequestInvokeEncrypt
from core.tools.utils.encryption import create_generic_encrypter from core.tools.utils.encryption import create_provider_encrypter
from models.account import Tenant from models.account import Tenant
class PluginEncrypter: class PluginEncrypter:
@classmethod @classmethod
def invoke_encrypt(cls, tenant: Tenant, payload: RequestInvokeEncrypt) -> dict: def invoke_encrypt(cls, tenant: Tenant, payload: RequestInvokeEncrypt) -> dict:
encrypter, cache = create_generic_encrypter( encrypter, cache = create_provider_encrypter(
tenant_id=tenant.id, tenant_id=tenant.id,
config=payload.config, config=payload.config,
provider_type=payload.namespace, cache=SingletonProviderCredentialsCache(
provider_identity=payload.identity, tenant_id=tenant.id,
provider_type=payload.namespace,
provider_identity=payload.identity,
),
) )
if payload.opt == "encrypt": if payload.opt == "encrypt":

@ -47,7 +47,7 @@ from core.tools.tool_label_manager import ToolLabelManager
from core.tools.utils.configuration import ( from core.tools.utils.configuration import (
ToolParameterConfigurationManager, ToolParameterConfigurationManager,
) )
from core.tools.utils.encryption import ProviderConfigEncrypter, create_encrypter, create_generic_encrypter from core.tools.utils.encryption import create_provider_encrypter, create_tool_provider_encrypter
from core.tools.workflow_as_tool.tool import WorkflowTool from core.tools.workflow_as_tool.tool import WorkflowTool
from extensions.ext_database import db from extensions.ext_database import db
from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider
@ -222,7 +222,7 @@ class ToolManager:
if builtin_provider is None: if builtin_provider is None:
raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found") raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found")
encrypter, _ = create_encrypter( encrypter, _ = create_provider_encrypter(
tenant_id=tenant_id, tenant_id=tenant_id,
config=[ config=[
x.to_basic_provider_config() x.to_basic_provider_config()
@ -248,11 +248,9 @@ class ToolManager:
elif provider_type == ToolProviderType.API: elif provider_type == ToolProviderType.API:
api_provider, credentials = cls.get_api_provider_controller(tenant_id, provider_id) api_provider, credentials = cls.get_api_provider_controller(tenant_id, provider_id)
encrypter, _ = create_generic_encrypter( encrypter, _ = create_tool_provider_encrypter(
tenant_id=tenant_id, tenant_id=tenant_id,
config=[x.to_basic_provider_config() for x in api_provider.get_credentials_schema()], controller=api_provider,
provider_type=api_provider.provider_type.value,
provider_identity=api_provider.entity.identity.name,
) )
return cast( return cast(
ApiTool, ApiTool,
@ -740,15 +738,12 @@ class ToolManager:
ApiProviderAuthType.API_KEY if credentials["auth_type"] == "api_key" else ApiProviderAuthType.NONE, ApiProviderAuthType.API_KEY if credentials["auth_type"] == "api_key" else ApiProviderAuthType.NONE,
) )
# init tool configuration # init tool configuration
tool_configuration = ProviderConfigEncrypter.create_cached( encrypter, _ = create_tool_provider_encrypter(
tenant_id=tenant_id, tenant_id=tenant_id,
config=[x.to_basic_provider_config() for x in controller.get_credentials_schema()], controller=controller,
provider_type=controller.provider_type.value,
provider_identity=controller.entity.identity.name,
) )
decrypted_credentials = tool_configuration.decrypt(credentials) masked_credentials = encrypter.mask_tool_credentials(encrypter.decrypt(credentials))
masked_credentials = tool_configuration.mask_tool_credentials(decrypted_credentials)
try: try:
icon = json.loads(provider_obj.icon) icon = json.loads(provider_obj.icon)

@ -3,7 +3,8 @@ from typing import Any, Optional, Protocol
from core.entities.provider_entities import BasicProviderConfig from core.entities.provider_entities import BasicProviderConfig
from core.helper import encrypter from core.helper import encrypter
from core.helper.provider_cache import GenericProviderCredentialsCache from core.helper.provider_cache import SingletonProviderCredentialsCache
from core.tools.__base.tool_provider import ToolProviderController
class ProviderConfigCache(Protocol): class ProviderConfigCache(Protocol):
@ -123,13 +124,18 @@ class ProviderConfigEncrypter:
return data return data
def create_generic_encrypter( def create_provider_encrypter(tenant_id: str, config: list[BasicProviderConfig], cache: ProviderConfigCache):
tenant_id: str, config: list[BasicProviderConfig], provider_type: str, provider_identity: str
):
cache = GenericProviderCredentialsCache(tenant_id=tenant_id, identity_id=f"{provider_type}.{provider_identity}")
encrypt = ProviderConfigEncrypter(tenant_id=tenant_id, config=config, provider_config_cache=cache)
return encrypt, cache
def create_encrypter(tenant_id: str, config: list[BasicProviderConfig], cache: ProviderConfigCache):
return ProviderConfigEncrypter(tenant_id=tenant_id, config=config, provider_config_cache=cache), cache return ProviderConfigEncrypter(tenant_id=tenant_id, config=config, provider_config_cache=cache), cache
def create_tool_provider_encrypter(tenant_id: str, controller: ToolProviderController):
cache = SingletonProviderCredentialsCache(
tenant_id=tenant_id,
provider_type=controller.provider_type.value,
provider_identity=controller.entity.identity.name,
)
encrypt = ProviderConfigEncrypter(
tenant_id=tenant_id,
config=[x.to_basic_provider_config() for x in controller.get_credentials_schema()],
provider_config_cache=cache,
)
return encrypt, cache

@ -6,7 +6,7 @@ from sqlalchemy.orm import Session
from core.plugin.entities.parameters import PluginParameterOption from core.plugin.entities.parameters import PluginParameterOption
from core.plugin.impl.dynamic_select import DynamicSelectClient from core.plugin.impl.dynamic_select import DynamicSelectClient
from core.tools.tool_manager import ToolManager from core.tools.tool_manager import ToolManager
from core.tools.utils.encryption import ProviderConfigEncrypter from core.tools.utils.encryption import create_tool_provider_encrypter
from extensions.ext_database import db from extensions.ext_database import db
from models.tools import BuiltinToolProvider from models.tools import BuiltinToolProvider
@ -38,11 +38,9 @@ class PluginParameterService:
case "tool": case "tool":
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id) provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
# init tool configuration # init tool configuration
tool_configuration = ProviderConfigEncrypter( encrypter, _ = create_tool_provider_encrypter(
tenant_id=tenant_id, tenant_id=tenant_id,
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()], controller=provider_controller,
provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.entity.identity.name,
) )
# check if credentials are required # check if credentials are required
@ -63,7 +61,7 @@ class PluginParameterService:
if db_record is None: if db_record is None:
raise ValueError(f"Builtin provider {provider} not found when fetching credentials") raise ValueError(f"Builtin provider {provider} not found when fetching credentials")
credentials = tool_configuration.decrypt(db_record.credentials) credentials = encrypter.decrypt(db_record.credentials)
case _: case _:
raise ValueError(f"Invalid provider type: {provider_type}") raise ValueError(f"Invalid provider type: {provider_type}")

@ -18,7 +18,7 @@ from core.tools.entities.tool_entities import (
) )
from core.tools.tool_label_manager import ToolLabelManager from core.tools.tool_label_manager import ToolLabelManager
from core.tools.tool_manager import ToolManager from core.tools.tool_manager import ToolManager
from core.tools.utils.encryption import ProviderConfigEncrypter, create_generic_encrypter from core.tools.utils.encryption import create_tool_provider_encrypter
from core.tools.utils.parser import ApiBasedToolSchemaParser from core.tools.utils.parser import ApiBasedToolSchemaParser
from extensions.ext_database import db from extensions.ext_database import db
from models.tools import ApiToolProvider from models.tools import ApiToolProvider
@ -164,15 +164,11 @@ class ApiToolManageService:
provider_controller.load_bundled_tools(tool_bundles) provider_controller.load_bundled_tools(tool_bundles)
# encrypt credentials # encrypt credentials
tool_configuration = ProviderConfigEncrypter( encrypter, _ = create_tool_provider_encrypter(
tenant_id=tenant_id, tenant_id=tenant_id,
config=list(provider_controller.get_credentials_schema()), controller=provider_controller,
provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.entity.identity.name,
) )
db_provider.credentials_str = json.dumps(encrypter.encrypt(credentials))
encrypted_credentials = tool_configuration.encrypt(credentials)
db_provider.credentials_str = json.dumps(encrypted_credentials)
db.session.add(db_provider) db.session.add(db_provider)
db.session.commit() db.session.commit()
@ -297,11 +293,9 @@ class ApiToolManageService:
provider_controller.load_bundled_tools(tool_bundles) provider_controller.load_bundled_tools(tool_bundles)
# get original credentials if exists # get original credentials if exists
encrypter, cache = create_generic_encrypter( encrypter, cache = create_tool_provider_encrypter(
tenant_id=tenant_id, tenant_id=tenant_id,
config=list(provider_controller.get_credentials_schema()), controller=provider_controller,
provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.entity.identity.name,
) )
original_credentials = encrypter.decrypt(provider.credentials) original_credentials = encrypter.decrypt(provider.credentials)
@ -416,11 +410,9 @@ class ApiToolManageService:
# decrypt credentials # decrypt credentials
if db_provider.id: if db_provider.id:
encrypter, _ = create_generic_encrypter( encrypter, _ = create_tool_provider_encrypter(
tenant_id=tenant_id, tenant_id=tenant_id,
config=list(provider_controller.get_credentials_schema()), controller=provider_controller,
provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.entity.identity.name,
) )
decrypted_credentials = encrypter.decrypt(credentials) decrypted_credentials = encrypter.decrypt(credentials)
# check if the credential has changed, save the original credential # check if the credential has changed, save the original credential

@ -24,7 +24,7 @@ from core.tools.errors import ToolNotFoundError, ToolProviderCredentialValidatio
from core.tools.plugin_tool.provider import PluginToolProviderController from core.tools.plugin_tool.provider import PluginToolProviderController
from core.tools.tool_label_manager import ToolLabelManager from core.tools.tool_label_manager import ToolLabelManager
from core.tools.tool_manager import ToolManager from core.tools.tool_manager import ToolManager
from core.tools.utils.encryption import create_encrypter from core.tools.utils.encryption import create_provider_encrypter
from extensions.ext_database import db from extensions.ext_database import db
from extensions.ext_redis import redis_client from extensions.ext_redis import redis_client
from models.tools import BuiltinToolProvider, ToolOAuthSystemClient, ToolOAuthTenantClient from models.tools import BuiltinToolProvider, ToolOAuthSystemClient, ToolOAuthTenantClient
@ -225,7 +225,7 @@ class BuiltinToolManageService:
provider: str, provider: str,
provider_controller: BuiltinToolProviderController, provider_controller: BuiltinToolProviderController,
): ):
encrypter, cache = create_encrypter( encrypter, cache = create_provider_encrypter(
tenant_id=tenant_id, tenant_id=tenant_id,
config=[ config=[
x.to_basic_provider_config() x.to_basic_provider_config()
@ -396,7 +396,7 @@ class BuiltinToolManageService:
""" """
tool_provider = ToolProviderID(provider) tool_provider = ToolProviderID(provider)
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id) provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
encrypter, _ = create_encrypter( encrypter, _ = create_provider_encrypter(
tenant_id=tenant_id, tenant_id=tenant_id,
config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()], config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()],
cache=NoOpProviderCredentialCache(), cache=NoOpProviderCredentialCache(),
@ -608,7 +608,7 @@ class BuiltinToolManageService:
session.add(custom_client_params) session.add(custom_client_params)
if client_params is not None: if client_params is not None:
encrypter, _ = create_encrypter( encrypter, _ = create_provider_encrypter(
tenant_id=tenant_id, tenant_id=tenant_id,
config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()], config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()],
cache=NoOpProviderCredentialCache(), cache=NoOpProviderCredentialCache(),
@ -647,7 +647,7 @@ class BuiltinToolManageService:
if not isinstance(provider_controller, BuiltinToolProviderController): if not isinstance(provider_controller, BuiltinToolProviderController):
raise ValueError(f"Provider {provider} is not a builtin or plugin provider") raise ValueError(f"Provider {provider} is not a builtin or plugin provider")
encrypter, _ = create_encrypter( encrypter, _ = create_provider_encrypter(
tenant_id=tenant_id, tenant_id=tenant_id,
config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()], config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()],
cache=NoOpProviderCredentialCache(), cache=NoOpProviderCredentialCache(),

@ -20,7 +20,7 @@ from core.tools.entities.tool_entities import (
ToolProviderType, ToolProviderType,
) )
from core.tools.plugin_tool.provider import PluginToolProviderController from core.tools.plugin_tool.provider import PluginToolProviderController
from core.tools.utils.encryption import create_encrypter, create_generic_encrypter from core.tools.utils.encryption import create_provider_encrypter, create_tool_provider_encrypter
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
from core.tools.workflow_as_tool.tool import WorkflowTool from core.tools.workflow_as_tool.tool import WorkflowTool
from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider
@ -113,9 +113,7 @@ class ToolTransformService:
schema = { schema = {
x.to_basic_provider_config().name: x x.to_basic_provider_config().name: x
for x in provider_controller.get_credentials_schema_by_type( for x in provider_controller.get_credentials_schema_by_type(
CredentialType.of(db_provider.credential_type) CredentialType.of(db_provider.credential_type) if db_provider else CredentialType.API_KEY
if db_provider
else CredentialType.API_KEY
) )
} }
@ -134,7 +132,7 @@ class ToolTransformService:
credentials = db_provider.credentials credentials = db_provider.credentials
# init tool configuration # init tool configuration
encrypter, _ = create_encrypter( encrypter, _ = create_provider_encrypter(
tenant_id=db_provider.tenant_id, tenant_id=db_provider.tenant_id,
config=[ config=[
x.to_basic_provider_config() x.to_basic_provider_config()
@ -252,11 +250,9 @@ class ToolTransformService:
if decrypt_credentials: if decrypt_credentials:
# init tool configuration # init tool configuration
encrypter, _ = create_generic_encrypter( encrypter, _ = create_tool_provider_encrypter(
tenant_id=db_provider.tenant_id, tenant_id=db_provider.tenant_id,
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()], controller=provider_controller,
provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.entity.identity.name,
) )
# decrypt the credentials and mask the credentials # decrypt the credentials and mask the credentials

Loading…
Cancel
Save