Merge branch 'feat/tool-plugin-oauth' into deploy/dev

# Conflicts:
#	api/core/plugin/backwards_invocation/encrypt.py
#	api/core/tools/tool_manager.py
#	api/core/tools/utils/encryption.py
#	api/services/plugin/plugin_parameter_service.py
#	api/services/tools/api_tools_manage_service.py
#	api/services/tools/builtin_tools_manage_service.py
#	api/services/tools/tools_transform_service.py
pull/22036/head
Harry 11 months ago
commit 5298e06763

@ -37,16 +37,23 @@ class ProviderCredentialsCache(ABC):
redis_client.delete(self.cache_key) redis_client.delete(self.cache_key)
class GenericProviderCredentialsCache(ProviderCredentialsCache): class SingletonProviderCredentialsCache(ProviderCredentialsCache):
"""Cache for generic provider credentials""" """Cache for tool single provider credentials"""
def __init__(self, tenant_id: str, identity_id: str): def __init__(self, tenant_id: str, provider_type: str, provider_identity: str):
super().__init__(tenant_id=tenant_id, identity_id=identity_id) super().__init__(
tenant_id=tenant_id,
provider_type=provider_type,
provider_identity=provider_identity,
)
def _generate_cache_key(self, **kwargs) -> str: def _generate_cache_key(self, **kwargs) -> str:
tenant_id = kwargs["tenant_id"] tenant_id = kwargs["tenant_id"]
identity_id = kwargs["identity_id"] provider_type = kwargs["provider_type"]
return f"generic_provider_credentials:tenant_id:{tenant_id}:id:{identity_id}" identity_name = kwargs["provider_identity"]
identity_id = f"{provider_type}.{identity_name}"
return f"{provider_type}_credentials:tenant_id:{tenant_id}:id:{identity_id}"
class ToolProviderCredentialsCache(ProviderCredentialsCache): class ToolProviderCredentialsCache(ProviderCredentialsCache):
"""Cache for tool provider credentials""" """Cache for tool provider credentials"""
@ -58,7 +65,7 @@ class ToolProviderCredentialsCache(ProviderCredentialsCache):
tenant_id = kwargs["tenant_id"] tenant_id = kwargs["tenant_id"]
provider = kwargs["provider"] provider = kwargs["provider"]
credential_id = kwargs["credential_id"] credential_id = kwargs["credential_id"]
return f"provider_credentials:tenant_id:{tenant_id}:provider:{provider}:credential_id:{credential_id}" return f"tool_credentials:tenant_id:{tenant_id}:provider:{provider}:credential_id:{credential_id}"
class NoOpProviderCredentialCache: class NoOpProviderCredentialCache:

@ -1,16 +1,20 @@
from core.helper.provider_cache import SingletonProviderCredentialsCache
from core.plugin.entities.request import RequestInvokeEncrypt from core.plugin.entities.request import RequestInvokeEncrypt
from core.tools.utils.encryption import create_generic_encrypter from core.tools.utils.encryption import create_provider_encrypter
from models.account import Tenant from models.account import Tenant
class PluginEncrypter: class PluginEncrypter:
@classmethod @classmethod
def invoke_encrypt(cls, tenant: Tenant, payload: RequestInvokeEncrypt) -> dict: def invoke_encrypt(cls, tenant: Tenant, payload: RequestInvokeEncrypt) -> dict:
encrypter, cache = create_generic_encrypter( encrypter, cache = create_provider_encrypter(
tenant_id=tenant.id, tenant_id=tenant.id,
config=payload.config, config=payload.config,
provider_type=payload.namespace, cache=SingletonProviderCredentialsCache(
provider_identity=payload.identity, tenant_id=tenant.id,
provider_type=payload.namespace,
provider_identity=payload.identity,
),
) )
if payload.opt == "encrypt": if payload.opt == "encrypt":

@ -4,6 +4,7 @@ import mimetypes
from collections.abc import Generator from collections.abc import Generator
from os import listdir, path from os import listdir, path
from threading import Lock from threading import Lock
from typing import TYPE_CHECKING, Any, Optional, Union, cast
from typing import TYPE_CHECKING, Any, Literal, Optional, Union, cast from typing import TYPE_CHECKING, Any, Literal, Optional, Union, cast
from yarl import URL from yarl import URL
@ -51,7 +52,7 @@ from core.tools.tool_label_manager import ToolLabelManager
from core.tools.utils.configuration import ( from core.tools.utils.configuration import (
ToolParameterConfigurationManager, ToolParameterConfigurationManager,
) )
from core.tools.utils.encryption import create_encrypter, create_generic_encrypter from core.tools.utils.encryption import create_provider_encrypter, create_tool_provider_encrypter
from core.tools.workflow_as_tool.tool import WorkflowTool from core.tools.workflow_as_tool.tool import WorkflowTool
from extensions.ext_database import db from extensions.ext_database import db
from models.tools import ApiToolProvider, BuiltinToolProvider, MCPToolProvider, WorkflowToolProvider from models.tools import ApiToolProvider, BuiltinToolProvider, MCPToolProvider, WorkflowToolProvider
@ -226,7 +227,7 @@ class ToolManager:
if builtin_provider is None: if builtin_provider is None:
raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found") raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found")
encrypter, _ = create_encrypter( encrypter, _ = create_provider_encrypter(
tenant_id=tenant_id, tenant_id=tenant_id,
config=[ config=[
x.to_basic_provider_config() x.to_basic_provider_config()
@ -252,11 +253,9 @@ class ToolManager:
elif provider_type == ToolProviderType.API: elif provider_type == ToolProviderType.API:
api_provider, credentials = cls.get_api_provider_controller(tenant_id, provider_id) api_provider, credentials = cls.get_api_provider_controller(tenant_id, provider_id)
encrypter, _ = create_generic_encrypter( encrypter, _ = create_tool_provider_encrypter(
tenant_id=tenant_id, tenant_id=tenant_id,
config=[x.to_basic_provider_config() for x in api_provider.get_credentials_schema()], controller=api_provider,
provider_type=api_provider.provider_type.value,
provider_identity=api_provider.entity.identity.name,
) )
return cast( return cast(
ApiTool, ApiTool,
@ -760,12 +759,9 @@ class ToolManager:
ApiProviderAuthType.API_KEY if credentials["auth_type"] == "api_key" else ApiProviderAuthType.NONE, ApiProviderAuthType.API_KEY if credentials["auth_type"] == "api_key" else ApiProviderAuthType.NONE,
) )
# init tool configuration # init tool configuration
encrypter, _ = create_encrypter( encrypter, _ = create_tool_provider_encrypter(
tenant_id=tenant_id, tenant_id=tenant_id,
config=[x.to_basic_provider_config() for x in controller.get_credentials_schema()], controller=controller,
cache=ToolProviderCredentialsCache(
tenant_id=tenant_id, provider=provider, credential_id=provider_obj.id
),
) )
masked_credentials = encrypter.mask_tool_credentials(encrypter.decrypt(credentials)) masked_credentials = encrypter.mask_tool_credentials(encrypter.decrypt(credentials))

@ -3,7 +3,8 @@ from typing import Any, Optional, Protocol
from core.entities.provider_entities import BasicProviderConfig from core.entities.provider_entities import BasicProviderConfig
from core.helper import encrypter from core.helper import encrypter
from core.helper.provider_cache import GenericProviderCredentialsCache from core.helper.provider_cache import SingletonProviderCredentialsCache
from core.tools.__base.tool_provider import ToolProviderController
class ProviderConfigCache(Protocol): class ProviderConfigCache(Protocol):
@ -118,17 +119,23 @@ class ProviderConfigEncrypter:
data[field_name] = encrypter.decrypt_token(self.tenant_id, data[field_name]) data[field_name] = encrypter.decrypt_token(self.tenant_id, data[field_name])
except Exception: except Exception:
pass pass
self.provider_config_cache.set(data) self.provider_config_cache.set(data)
return data return data
def create_encrypter(tenant_id: str, config: list[BasicProviderConfig], cache: ProviderConfigCache): def create_provider_encrypter(tenant_id: str, config: list[BasicProviderConfig], cache: ProviderConfigCache):
return ProviderConfigEncrypter(tenant_id=tenant_id, config=config, provider_config_cache=cache), cache return ProviderConfigEncrypter(tenant_id=tenant_id, config=config, provider_config_cache=cache), cache
def create_tool_provider_encrypter(tenant_id: str, controller: ToolProviderController):
def create_generic_encrypter( cache = SingletonProviderCredentialsCache(
tenant_id: str, config: list[BasicProviderConfig], provider_type: str, provider_identity: str tenant_id=tenant_id,
): provider_type=controller.provider_type.value,
cache = GenericProviderCredentialsCache(tenant_id=tenant_id, identity_id=f"{provider_type}.{provider_identity}") provider_identity=controller.entity.identity.name,
encrypt = ProviderConfigEncrypter(tenant_id=tenant_id, config=config, provider_config_cache=cache) )
encrypt = ProviderConfigEncrypter(
tenant_id=tenant_id,
config=[x.to_basic_provider_config() for x in controller.get_credentials_schema()],
provider_config_cache=cache,
)
return encrypt, cache return encrypt, cache

@ -6,7 +6,7 @@ from sqlalchemy.orm import Session
from core.plugin.entities.parameters import PluginParameterOption from core.plugin.entities.parameters import PluginParameterOption
from core.plugin.impl.dynamic_select import DynamicSelectClient from core.plugin.impl.dynamic_select import DynamicSelectClient
from core.tools.tool_manager import ToolManager from core.tools.tool_manager import ToolManager
from core.tools.utils.encryption import ProviderConfigEncrypter from core.tools.utils.encryption import create_tool_provider_encrypter
from extensions.ext_database import db from extensions.ext_database import db
from models.tools import BuiltinToolProvider from models.tools import BuiltinToolProvider
@ -38,11 +38,9 @@ class PluginParameterService:
case "tool": case "tool":
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id) provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
# init tool configuration # init tool configuration
tool_configuration = ProviderConfigEncrypter( encrypter, _ = create_tool_provider_encrypter(
tenant_id=tenant_id, tenant_id=tenant_id,
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()], controller=provider_controller,
provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.entity.identity.name,
) )
# check if credentials are required # check if credentials are required
@ -63,7 +61,7 @@ class PluginParameterService:
if db_record is None: if db_record is None:
raise ValueError(f"Builtin provider {provider} not found when fetching credentials") raise ValueError(f"Builtin provider {provider} not found when fetching credentials")
credentials = tool_configuration.decrypt(db_record.credentials) credentials = encrypter.decrypt(db_record.credentials)
case _: case _:
raise ValueError(f"Invalid provider type: {provider_type}") raise ValueError(f"Invalid provider type: {provider_type}")

@ -18,7 +18,7 @@ from core.tools.entities.tool_entities import (
) )
from core.tools.tool_label_manager import ToolLabelManager from core.tools.tool_label_manager import ToolLabelManager
from core.tools.tool_manager import ToolManager from core.tools.tool_manager import ToolManager
from core.tools.utils.encryption import ProviderConfigEncrypter, create_generic_encrypter from core.tools.utils.encryption import create_tool_provider_encrypter
from core.tools.utils.parser import ApiBasedToolSchemaParser from core.tools.utils.parser import ApiBasedToolSchemaParser
from extensions.ext_database import db from extensions.ext_database import db
from models.tools import ApiToolProvider from models.tools import ApiToolProvider
@ -164,15 +164,11 @@ class ApiToolManageService:
provider_controller.load_bundled_tools(tool_bundles) provider_controller.load_bundled_tools(tool_bundles)
# encrypt credentials # encrypt credentials
tool_configuration = ProviderConfigEncrypter( encrypter, _ = create_tool_provider_encrypter(
tenant_id=tenant_id, tenant_id=tenant_id,
config=list(provider_controller.get_credentials_schema()), controller=provider_controller,
provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.entity.identity.name,
) )
db_provider.credentials_str = json.dumps(encrypter.encrypt(credentials))
encrypted_credentials = tool_configuration.encrypt(credentials)
db_provider.credentials_str = json.dumps(encrypted_credentials)
db.session.add(db_provider) db.session.add(db_provider)
db.session.commit() db.session.commit()
@ -297,11 +293,9 @@ class ApiToolManageService:
provider_controller.load_bundled_tools(tool_bundles) provider_controller.load_bundled_tools(tool_bundles)
# get original credentials if exists # get original credentials if exists
encrypter, cache = create_generic_encrypter( encrypter, cache = create_tool_provider_encrypter(
tenant_id=tenant_id, tenant_id=tenant_id,
config=list(provider_controller.get_credentials_schema()), controller=provider_controller,
provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.entity.identity.name,
) )
original_credentials = encrypter.decrypt(provider.credentials) original_credentials = encrypter.decrypt(provider.credentials)
@ -416,11 +410,9 @@ class ApiToolManageService:
# decrypt credentials # decrypt credentials
if db_provider.id: if db_provider.id:
encrypter, _ = create_generic_encrypter( encrypter, _ = create_tool_provider_encrypter(
tenant_id=tenant_id, tenant_id=tenant_id,
config=list(provider_controller.get_credentials_schema()), controller=provider_controller,
provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.entity.identity.name,
) )
decrypted_credentials = encrypter.decrypt(credentials) decrypted_credentials = encrypter.decrypt(credentials)
# check if the credential has changed, save the original credential # check if the credential has changed, save the original credential

@ -24,7 +24,7 @@ from core.tools.errors import ToolNotFoundError, ToolProviderCredentialValidatio
from core.tools.plugin_tool.provider import PluginToolProviderController from core.tools.plugin_tool.provider import PluginToolProviderController
from core.tools.tool_label_manager import ToolLabelManager from core.tools.tool_label_manager import ToolLabelManager
from core.tools.tool_manager import ToolManager from core.tools.tool_manager import ToolManager
from core.tools.utils.encryption import create_encrypter from core.tools.utils.encryption import create_provider_encrypter
from extensions.ext_database import db from extensions.ext_database import db
from extensions.ext_redis import redis_client from extensions.ext_redis import redis_client
from models.tools import BuiltinToolProvider, ToolOAuthSystemClient, ToolOAuthTenantClient from models.tools import BuiltinToolProvider, ToolOAuthSystemClient, ToolOAuthTenantClient
@ -225,7 +225,7 @@ class BuiltinToolManageService:
provider: str, provider: str,
provider_controller: BuiltinToolProviderController, provider_controller: BuiltinToolProviderController,
): ):
encrypter, cache = create_encrypter( encrypter, cache = create_provider_encrypter(
tenant_id=tenant_id, tenant_id=tenant_id,
config=[ config=[
x.to_basic_provider_config() x.to_basic_provider_config()
@ -396,7 +396,7 @@ class BuiltinToolManageService:
""" """
tool_provider = ToolProviderID(provider) tool_provider = ToolProviderID(provider)
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id) provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
encrypter, _ = create_encrypter( encrypter, _ = create_provider_encrypter(
tenant_id=tenant_id, tenant_id=tenant_id,
config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()], config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()],
cache=NoOpProviderCredentialCache(), cache=NoOpProviderCredentialCache(),
@ -608,7 +608,7 @@ class BuiltinToolManageService:
session.add(custom_client_params) session.add(custom_client_params)
if client_params is not None: if client_params is not None:
encrypter, _ = create_encrypter( encrypter, _ = create_provider_encrypter(
tenant_id=tenant_id, tenant_id=tenant_id,
config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()], config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()],
cache=NoOpProviderCredentialCache(), cache=NoOpProviderCredentialCache(),
@ -647,7 +647,7 @@ class BuiltinToolManageService:
if not isinstance(provider_controller, BuiltinToolProviderController): if not isinstance(provider_controller, BuiltinToolProviderController):
raise ValueError(f"Provider {provider} is not a builtin or plugin provider") raise ValueError(f"Provider {provider} is not a builtin or plugin provider")
encrypter, _ = create_encrypter( encrypter, _ = create_provider_encrypter(
tenant_id=tenant_id, tenant_id=tenant_id,
config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()], config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()],
cache=NoOpProviderCredentialCache(), cache=NoOpProviderCredentialCache(),

@ -21,7 +21,7 @@ from core.tools.entities.tool_entities import (
ToolProviderType, ToolProviderType,
) )
from core.tools.plugin_tool.provider import PluginToolProviderController from core.tools.plugin_tool.provider import PluginToolProviderController
from core.tools.utils.encryption import create_encrypter, create_generic_encrypter from core.tools.utils.encryption import create_provider_encrypter, create_tool_provider_encrypter
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
from core.tools.workflow_as_tool.tool import WorkflowTool from core.tools.workflow_as_tool.tool import WorkflowTool
from models.tools import ApiToolProvider, BuiltinToolProvider, MCPToolProvider, WorkflowToolProvider from models.tools import ApiToolProvider, BuiltinToolProvider, MCPToolProvider, WorkflowToolProvider
@ -115,9 +115,7 @@ class ToolTransformService:
schema = { schema = {
x.to_basic_provider_config().name: x x.to_basic_provider_config().name: x
for x in provider_controller.get_credentials_schema_by_type( for x in provider_controller.get_credentials_schema_by_type(
CredentialType.of(db_provider.credential_type) CredentialType.of(db_provider.credential_type) if db_provider else CredentialType.API_KEY
if db_provider
else CredentialType.API_KEY
) )
} }
@ -136,7 +134,7 @@ class ToolTransformService:
credentials = db_provider.credentials credentials = db_provider.credentials
# init tool configuration # init tool configuration
encrypter, _ = create_encrypter( encrypter, _ = create_provider_encrypter(
tenant_id=db_provider.tenant_id, tenant_id=db_provider.tenant_id,
config=[ config=[
x.to_basic_provider_config() x.to_basic_provider_config()
@ -289,11 +287,9 @@ class ToolTransformService:
if decrypt_credentials: if decrypt_credentials:
# init tool configuration # init tool configuration
encrypter, _ = create_generic_encrypter( encrypter, _ = create_tool_provider_encrypter(
tenant_id=db_provider.tenant_id, tenant_id=db_provider.tenant_id,
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()], controller=provider_controller,
provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.entity.identity.name,
) )
# decrypt the credentials and mask the credentials # decrypt the credentials and mask the credentials

Loading…
Cancel
Save