feat(oauth): refactor tool encryption utils

pull/22036/head
Harry 7 months ago
parent eaefa1b7e6
commit 0dc5bfb2c7

@ -37,16 +37,23 @@ class ProviderCredentialsCache(ABC):
redis_client.delete(self.cache_key)
class GenericProviderCredentialsCache(ProviderCredentialsCache):
"""Cache for generic provider credentials"""
class SingletonProviderCredentialsCache(ProviderCredentialsCache):
"""Cache for tool single provider credentials"""
def __init__(self, tenant_id: str, identity_id: str):
super().__init__(tenant_id=tenant_id, identity_id=identity_id)
def __init__(self, tenant_id: str, provider_type: str, provider_identity: str):
super().__init__(
tenant_id=tenant_id,
provider_type=provider_type,
provider_identity=provider_identity,
)
def _generate_cache_key(self, **kwargs) -> str:
tenant_id = kwargs["tenant_id"]
identity_id = kwargs["identity_id"]
return f"generic_provider_credentials:tenant_id:{tenant_id}:id:{identity_id}"
provider_type = kwargs["provider_type"]
identity_name = kwargs["provider_identity"]
identity_id = f"{provider_type}.{identity_name}"
return f"{provider_type}_credentials:tenant_id:{tenant_id}:id:{identity_id}"
class ToolProviderCredentialsCache(ProviderCredentialsCache):
"""Cache for tool provider credentials"""
@ -58,7 +65,7 @@ class ToolProviderCredentialsCache(ProviderCredentialsCache):
tenant_id = kwargs["tenant_id"]
provider = kwargs["provider"]
credential_id = kwargs["credential_id"]
return f"provider_credentials:tenant_id:{tenant_id}:provider:{provider}:credential_id:{credential_id}"
return f"tool_credentials:tenant_id:{tenant_id}:provider:{provider}:credential_id:{credential_id}"
class NoOpProviderCredentialCache:

@ -1,16 +1,20 @@
from core.helper.provider_cache import SingletonProviderCredentialsCache
from core.plugin.entities.request import RequestInvokeEncrypt
from core.tools.utils.encryption import create_generic_encrypter
from core.tools.utils.encryption import create_provider_encrypter
from models.account import Tenant
class PluginEncrypter:
@classmethod
def invoke_encrypt(cls, tenant: Tenant, payload: RequestInvokeEncrypt) -> dict:
encrypter, cache = create_generic_encrypter(
encrypter, cache = create_provider_encrypter(
tenant_id=tenant.id,
config=payload.config,
provider_type=payload.namespace,
provider_identity=payload.identity,
cache=SingletonProviderCredentialsCache(
tenant_id=tenant.id,
provider_type=payload.namespace,
provider_identity=payload.identity,
),
)
if payload.opt == "encrypt":

@ -47,7 +47,7 @@ from core.tools.tool_label_manager import ToolLabelManager
from core.tools.utils.configuration import (
ToolParameterConfigurationManager,
)
from core.tools.utils.encryption import ProviderConfigEncrypter, create_encrypter, create_generic_encrypter
from core.tools.utils.encryption import create_provider_encrypter, create_tool_provider_encrypter
from core.tools.workflow_as_tool.tool import WorkflowTool
from extensions.ext_database import db
from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider
@ -222,7 +222,7 @@ class ToolManager:
if builtin_provider is None:
raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found")
encrypter, _ = create_encrypter(
encrypter, _ = create_provider_encrypter(
tenant_id=tenant_id,
config=[
x.to_basic_provider_config()
@ -248,11 +248,9 @@ class ToolManager:
elif provider_type == ToolProviderType.API:
api_provider, credentials = cls.get_api_provider_controller(tenant_id, provider_id)
encrypter, _ = create_generic_encrypter(
encrypter, _ = create_tool_provider_encrypter(
tenant_id=tenant_id,
config=[x.to_basic_provider_config() for x in api_provider.get_credentials_schema()],
provider_type=api_provider.provider_type.value,
provider_identity=api_provider.entity.identity.name,
controller=api_provider,
)
return cast(
ApiTool,
@ -740,15 +738,12 @@ class ToolManager:
ApiProviderAuthType.API_KEY if credentials["auth_type"] == "api_key" else ApiProviderAuthType.NONE,
)
# init tool configuration
tool_configuration = ProviderConfigEncrypter.create_cached(
encrypter, _ = create_tool_provider_encrypter(
tenant_id=tenant_id,
config=[x.to_basic_provider_config() for x in controller.get_credentials_schema()],
provider_type=controller.provider_type.value,
provider_identity=controller.entity.identity.name,
controller=controller,
)
decrypted_credentials = tool_configuration.decrypt(credentials)
masked_credentials = tool_configuration.mask_tool_credentials(decrypted_credentials)
masked_credentials = encrypter.mask_tool_credentials(encrypter.decrypt(credentials))
try:
icon = json.loads(provider_obj.icon)

@ -3,7 +3,8 @@ from typing import Any, Optional, Protocol
from core.entities.provider_entities import BasicProviderConfig
from core.helper import encrypter
from core.helper.provider_cache import GenericProviderCredentialsCache
from core.helper.provider_cache import SingletonProviderCredentialsCache
from core.tools.__base.tool_provider import ToolProviderController
class ProviderConfigCache(Protocol):
@ -123,13 +124,18 @@ class ProviderConfigEncrypter:
return data
def create_generic_encrypter(
tenant_id: str, config: list[BasicProviderConfig], provider_type: str, provider_identity: str
):
cache = GenericProviderCredentialsCache(tenant_id=tenant_id, identity_id=f"{provider_type}.{provider_identity}")
encrypt = ProviderConfigEncrypter(tenant_id=tenant_id, config=config, provider_config_cache=cache)
return encrypt, cache
def create_encrypter(tenant_id: str, config: list[BasicProviderConfig], cache: ProviderConfigCache):
def create_provider_encrypter(tenant_id: str, config: list[BasicProviderConfig], cache: ProviderConfigCache):
return ProviderConfigEncrypter(tenant_id=tenant_id, config=config, provider_config_cache=cache), cache
def create_tool_provider_encrypter(tenant_id: str, controller: ToolProviderController):
cache = SingletonProviderCredentialsCache(
tenant_id=tenant_id,
provider_type=controller.provider_type.value,
provider_identity=controller.entity.identity.name,
)
encrypt = ProviderConfigEncrypter(
tenant_id=tenant_id,
config=[x.to_basic_provider_config() for x in controller.get_credentials_schema()],
provider_config_cache=cache,
)
return encrypt, cache

@ -6,7 +6,7 @@ from sqlalchemy.orm import Session
from core.plugin.entities.parameters import PluginParameterOption
from core.plugin.impl.dynamic_select import DynamicSelectClient
from core.tools.tool_manager import ToolManager
from core.tools.utils.encryption import ProviderConfigEncrypter
from core.tools.utils.encryption import create_tool_provider_encrypter
from extensions.ext_database import db
from models.tools import BuiltinToolProvider
@ -38,11 +38,9 @@ class PluginParameterService:
case "tool":
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
# init tool configuration
tool_configuration = ProviderConfigEncrypter(
encrypter, _ = create_tool_provider_encrypter(
tenant_id=tenant_id,
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.entity.identity.name,
controller=provider_controller,
)
# check if credentials are required
@ -63,7 +61,7 @@ class PluginParameterService:
if db_record is None:
raise ValueError(f"Builtin provider {provider} not found when fetching credentials")
credentials = tool_configuration.decrypt(db_record.credentials)
credentials = encrypter.decrypt(db_record.credentials)
case _:
raise ValueError(f"Invalid provider type: {provider_type}")

@ -18,7 +18,7 @@ from core.tools.entities.tool_entities import (
)
from core.tools.tool_label_manager import ToolLabelManager
from core.tools.tool_manager import ToolManager
from core.tools.utils.encryption import ProviderConfigEncrypter, create_generic_encrypter
from core.tools.utils.encryption import create_tool_provider_encrypter
from core.tools.utils.parser import ApiBasedToolSchemaParser
from extensions.ext_database import db
from models.tools import ApiToolProvider
@ -164,15 +164,11 @@ class ApiToolManageService:
provider_controller.load_bundled_tools(tool_bundles)
# encrypt credentials
tool_configuration = ProviderConfigEncrypter(
encrypter, _ = create_tool_provider_encrypter(
tenant_id=tenant_id,
config=list(provider_controller.get_credentials_schema()),
provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.entity.identity.name,
controller=provider_controller,
)
encrypted_credentials = tool_configuration.encrypt(credentials)
db_provider.credentials_str = json.dumps(encrypted_credentials)
db_provider.credentials_str = json.dumps(encrypter.encrypt(credentials))
db.session.add(db_provider)
db.session.commit()
@ -297,11 +293,9 @@ class ApiToolManageService:
provider_controller.load_bundled_tools(tool_bundles)
# get original credentials if exists
encrypter, cache = create_generic_encrypter(
encrypter, cache = create_tool_provider_encrypter(
tenant_id=tenant_id,
config=list(provider_controller.get_credentials_schema()),
provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.entity.identity.name,
controller=provider_controller,
)
original_credentials = encrypter.decrypt(provider.credentials)
@ -416,11 +410,9 @@ class ApiToolManageService:
# decrypt credentials
if db_provider.id:
encrypter, _ = create_generic_encrypter(
encrypter, _ = create_tool_provider_encrypter(
tenant_id=tenant_id,
config=list(provider_controller.get_credentials_schema()),
provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.entity.identity.name,
controller=provider_controller,
)
decrypted_credentials = encrypter.decrypt(credentials)
# check if the credential has changed, save the original credential

@ -24,7 +24,7 @@ from core.tools.errors import ToolNotFoundError, ToolProviderCredentialValidatio
from core.tools.plugin_tool.provider import PluginToolProviderController
from core.tools.tool_label_manager import ToolLabelManager
from core.tools.tool_manager import ToolManager
from core.tools.utils.encryption import create_encrypter
from core.tools.utils.encryption import create_provider_encrypter
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.tools import BuiltinToolProvider, ToolOAuthSystemClient, ToolOAuthTenantClient
@ -225,7 +225,7 @@ class BuiltinToolManageService:
provider: str,
provider_controller: BuiltinToolProviderController,
):
encrypter, cache = create_encrypter(
encrypter, cache = create_provider_encrypter(
tenant_id=tenant_id,
config=[
x.to_basic_provider_config()
@ -396,7 +396,7 @@ class BuiltinToolManageService:
"""
tool_provider = ToolProviderID(provider)
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
encrypter, _ = create_encrypter(
encrypter, _ = create_provider_encrypter(
tenant_id=tenant_id,
config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()],
cache=NoOpProviderCredentialCache(),
@ -608,7 +608,7 @@ class BuiltinToolManageService:
session.add(custom_client_params)
if client_params is not None:
encrypter, _ = create_encrypter(
encrypter, _ = create_provider_encrypter(
tenant_id=tenant_id,
config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()],
cache=NoOpProviderCredentialCache(),
@ -647,7 +647,7 @@ class BuiltinToolManageService:
if not isinstance(provider_controller, BuiltinToolProviderController):
raise ValueError(f"Provider {provider} is not a builtin or plugin provider")
encrypter, _ = create_encrypter(
encrypter, _ = create_provider_encrypter(
tenant_id=tenant_id,
config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()],
cache=NoOpProviderCredentialCache(),

@ -20,7 +20,7 @@ from core.tools.entities.tool_entities import (
ToolProviderType,
)
from core.tools.plugin_tool.provider import PluginToolProviderController
from core.tools.utils.encryption import create_encrypter, create_generic_encrypter
from core.tools.utils.encryption import create_provider_encrypter, create_tool_provider_encrypter
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
from core.tools.workflow_as_tool.tool import WorkflowTool
from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider
@ -113,9 +113,7 @@ class ToolTransformService:
schema = {
x.to_basic_provider_config().name: x
for x in provider_controller.get_credentials_schema_by_type(
CredentialType.of(db_provider.credential_type)
if db_provider
else CredentialType.API_KEY
CredentialType.of(db_provider.credential_type) if db_provider else CredentialType.API_KEY
)
}
@ -134,7 +132,7 @@ class ToolTransformService:
credentials = db_provider.credentials
# init tool configuration
encrypter, _ = create_encrypter(
encrypter, _ = create_provider_encrypter(
tenant_id=db_provider.tenant_id,
config=[
x.to_basic_provider_config()
@ -252,11 +250,9 @@ class ToolTransformService:
if decrypt_credentials:
# init tool configuration
encrypter, _ = create_generic_encrypter(
encrypter, _ = create_tool_provider_encrypter(
tenant_id=db_provider.tenant_id,
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.entity.identity.name,
controller=provider_controller,
)
# decrypt the credentials and mask the credentials

Loading…
Cancel
Save