feat(oauth&mcp): refactor credential encrypter

pull/22036/head
Harry 11 months ago
parent c160a0e5e3
commit 478c156f7d

@ -2,7 +2,6 @@ import io
from urllib.parse import urlparse from urllib.parse import urlparse
from flask import make_response, redirect, request, send_file from flask import make_response, redirect, request, send_file
from flask import redirect, send_file
from flask_login import current_user from flask_login import current_user
from flask_restful import ( from flask_restful import (
Resource, Resource,
@ -18,11 +17,6 @@ from core.mcp.auth.auth_flow import auth, handle_callback
from core.mcp.auth.auth_provider import OAuthClientProvider from core.mcp.auth.auth_provider import OAuthClientProvider
from core.mcp.error import MCPAuthError from core.mcp.error import MCPAuthError
from core.mcp.mcp_client import MCPClient from core.mcp.mcp_client import MCPClient
from controllers.console.wraps import (
account_initialization_required,
enterprise_license_required,
setup_required,
)
from core.model_runtime.utils.encoders import jsonable_encoder from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.entities.plugin import ToolProviderID from core.plugin.entities.plugin import ToolProviderID
from core.plugin.impl.oauth import OAuthHandler from core.plugin.impl.oauth import OAuthHandler
@ -97,10 +91,7 @@ class ToolBuiltinProviderInfoApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, provider): def get(self, provider):
user = current_user tenant_id = current_user.current_tenant_id
user_id = user.id
tenant_id = user.current_tenant_id
return jsonable_encoder(BuiltinToolManageService.get_builtin_tool_provider_info(tenant_id, provider)) return jsonable_encoder(BuiltinToolManageService.get_builtin_tool_provider_info(tenant_id, provider))
@ -695,10 +686,7 @@ class ToolPluginOAuthApi(Resource):
raise Forbidden() raise Forbidden()
tenant_id = user.current_tenant_id tenant_id = user.current_tenant_id
oauth_client_params = BuiltinToolManageService.get_oauth_client( oauth_client_params = BuiltinToolManageService.get_oauth_client(tenant_id=tenant_id, provider=provider)
tenant_id=tenant_id,
provider=provider
)
if oauth_client_params is None: if oauth_client_params is None:
raise Forbidden("no oauth available client config found for this tool provider") raise Forbidden("no oauth available client config found for this tool provider")
@ -851,6 +839,7 @@ api.add_resource(ToolOAuthCallback, "/oauth/plugin/<path:provider>/tool/callback
api.add_resource(ToolOAuthCustomClient, "/workspaces/current/tool-provider/builtin/<path:provider>/oauth/custom-client") api.add_resource(ToolOAuthCustomClient, "/workspaces/current/tool-provider/builtin/<path:provider>/oauth/custom-client")
class ToolProviderMCPApi(Resource): class ToolProviderMCPApi(Resource):
@setup_required @setup_required
@login_required @login_required

@ -1,5 +1,5 @@
from core.plugin.entities.request import RequestInvokeEncrypt from core.plugin.entities.request import RequestInvokeEncrypt
from core.tools.utils.configuration import create_generic_encrypter from core.tools.utils.encryption import create_generic_encrypter
from models.account import Tenant from models.account import Tenant

@ -46,14 +46,12 @@ from core.tools.entities.tool_entities import (
ToolParameter, ToolParameter,
ToolProviderType, ToolProviderType,
) )
from core.tools.errors import ToolNotFoundError, ToolProviderNotFoundError from core.tools.errors import ToolProviderNotFoundError
from core.tools.tool_label_manager import ToolLabelManager from core.tools.tool_label_manager import ToolLabelManager
from core.tools.utils.configuration import ( from core.tools.utils.configuration import (
ProviderConfigEncrypter,
ToolParameterConfigurationManager, ToolParameterConfigurationManager,
create_encrypter,
create_generic_encrypter,
) )
from core.tools.utils.encryption import create_encrypter, create_generic_encrypter
from core.tools.workflow_as_tool.tool import WorkflowTool from core.tools.workflow_as_tool.tool import WorkflowTool
from extensions.ext_database import db from extensions.ext_database import db
from models.tools import ApiToolProvider, BuiltinToolProvider, MCPToolProvider, WorkflowToolProvider from models.tools import ApiToolProvider, BuiltinToolProvider, MCPToolProvider, WorkflowToolProvider
@ -762,15 +760,15 @@ class ToolManager:
ApiProviderAuthType.API_KEY if credentials["auth_type"] == "api_key" else ApiProviderAuthType.NONE, ApiProviderAuthType.API_KEY if credentials["auth_type"] == "api_key" else ApiProviderAuthType.NONE,
) )
# init tool configuration # init tool configuration
tool_configuration = create_encrypter( encrypter, _ = create_encrypter(
tenant_id=tenant_id, tenant_id=tenant_id,
config=[x.to_basic_provider_config() for x in controller.get_credentials_schema()], config=[x.to_basic_provider_config() for x in controller.get_credentials_schema()],
provider_type=controller.provider_type.value, cache=ToolProviderCredentialsCache(
provider_identity=controller.entity.identity.name, tenant_id=tenant_id, provider=provider, credential_id=provider_obj.id
),
) )
decrypted_credentials = tool_configuration.decrypt(credentials) masked_credentials = encrypter.mask_tool_credentials(encrypter.decrypt(credentials))
masked_credentials = tool_configuration.mask_tool_credentials(decrypted_credentials)
try: try:
icon = json.loads(provider_obj.icon) icon = json.loads(provider_obj.icon)

@ -1,9 +1,7 @@
from copy import deepcopy from copy import deepcopy
from typing import Any, Optional, Protocol from typing import Any
from core.entities.provider_entities import BasicProviderConfig
from core.helper import encrypter from core.helper import encrypter
from core.helper.provider_cache import GenericProviderCredentialsCache
from core.helper.tool_parameter_cache import ToolParameterCache, ToolParameterCacheType from core.helper.tool_parameter_cache import ToolParameterCache, ToolParameterCacheType
from core.tools.__base.tool import Tool from core.tools.__base.tool import Tool
from core.tools.entities.tool_entities import ( from core.tools.entities.tool_entities import (
@ -12,140 +10,6 @@ from core.tools.entities.tool_entities import (
) )
class ProviderConfigCache(Protocol):
"""
Interface for provider configuration cache operations
"""
def get(self) -> Optional[dict]:
"""Get cached provider configuration"""
...
def set(self, config: dict[str, Any]) -> None:
"""Cache provider configuration"""
...
def delete(self) -> None:
"""Delete cached provider configuration"""
...
class ProviderConfigEncrypter:
tenant_id: str
config: list[BasicProviderConfig]
provider_config_cache: ProviderConfigCache
def __init__(
self,
tenant_id: str,
config: list[BasicProviderConfig],
provider_config_cache: ProviderConfigCache,
):
self.tenant_id = tenant_id
self.config = config
self.provider_config_cache = provider_config_cache
def _deep_copy(self, data: dict[str, str]) -> dict[str, str]:
"""
deep copy data
"""
return deepcopy(data)
def encrypt(self, data: dict[str, str]) -> dict[str, str]:
"""
encrypt tool credentials with tenant id
return a deep copy of credentials with encrypted values
"""
data = self._deep_copy(data)
# get fields need to be decrypted
fields = dict[str, BasicProviderConfig]()
for credential in self.config:
fields[credential.name] = credential
for field_name, field in fields.items():
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
if field_name in data:
encrypted = encrypter.encrypt_token(self.tenant_id, data[field_name] or "")
data[field_name] = encrypted
return data
def mask_tool_credentials(self, data: dict[str, Any]) -> dict[str, Any]:
"""
mask tool credentials
return a deep copy of credentials with masked values
"""
data = self._deep_copy(data)
# get fields need to be decrypted
fields = dict[str, BasicProviderConfig]()
for credential in self.config:
fields[credential.name] = credential
for field_name, field in fields.items():
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
if field_name in data:
if len(data[field_name]) > 6:
data[field_name] = (
data[field_name][:2] + "*" * (len(data[field_name]) - 4) + data[field_name][-2:]
)
else:
data[field_name] = "*" * len(data[field_name])
return data
def decrypt(self, data: dict[str, str], use_cache: bool = True) -> dict[str, Any]:
"""
decrypt tool credentials with tenant id
return a deep copy of credentials with decrypted values
"""
if use_cache:
cached_credentials = self.provider_config_cache.get()
if cached_credentials:
return cached_credentials
data = self._deep_copy(data)
# get fields need to be decrypted
fields = dict[str, BasicProviderConfig]()
for credential in self.config:
fields[credential.name] = credential
for field_name, field in fields.items():
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
if field_name in data:
try:
# if the value is None or empty string, skip decrypt
if not data[field_name]:
continue
data[field_name] = encrypter.decrypt_token(self.tenant_id, data[field_name])
except Exception:
pass
if use_cache:
self.provider_config_cache.set(data)
return data
def create_encrypter(
tenant_id: str, config: list[BasicProviderConfig], cache: ProviderConfigCache
):
return ProviderConfigEncrypter(
tenant_id=tenant_id, config=config, provider_config_cache=cache
), cache
def create_generic_encrypter(
tenant_id: str, config: list[BasicProviderConfig], provider_type: str, provider_identity: str
):
cache = GenericProviderCredentialsCache(tenant_id=tenant_id, identity_id=f"{provider_type}.{provider_identity}")
encrypt = ProviderConfigEncrypter(tenant_id=tenant_id, config=config, provider_config_cache=cache)
return encrypt, cache
class ToolParameterConfigurationManager: class ToolParameterConfigurationManager:
""" """
Tool parameter configuration manager Tool parameter configuration manager

@ -0,0 +1,134 @@
from copy import deepcopy
from typing import Any, Optional, Protocol
from core.entities.provider_entities import BasicProviderConfig
from core.helper import encrypter
from core.helper.provider_cache import GenericProviderCredentialsCache
class ProviderConfigCache(Protocol):
"""
Interface for provider configuration cache operations
"""
def get(self) -> Optional[dict]:
"""Get cached provider configuration"""
...
def set(self, config: dict[str, Any]) -> None:
"""Cache provider configuration"""
...
def delete(self) -> None:
"""Delete cached provider configuration"""
...
class ProviderConfigEncrypter:
tenant_id: str
config: list[BasicProviderConfig]
provider_config_cache: ProviderConfigCache
def __init__(
self,
tenant_id: str,
config: list[BasicProviderConfig],
provider_config_cache: ProviderConfigCache,
):
self.tenant_id = tenant_id
self.config = config
self.provider_config_cache = provider_config_cache
def _deep_copy(self, data: dict[str, str]) -> dict[str, str]:
"""
deep copy data
"""
return deepcopy(data)
def encrypt(self, data: dict[str, str]) -> dict[str, str]:
"""
encrypt tool credentials with tenant id
return a deep copy of credentials with encrypted values
"""
data = self._deep_copy(data)
# get fields need to be decrypted
fields = dict[str, BasicProviderConfig]()
for credential in self.config:
fields[credential.name] = credential
for field_name, field in fields.items():
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
if field_name in data:
encrypted = encrypter.encrypt_token(self.tenant_id, data[field_name] or "")
data[field_name] = encrypted
return data
def mask_tool_credentials(self, data: dict[str, Any]) -> dict[str, Any]:
"""
mask tool credentials
return a deep copy of credentials with masked values
"""
data = self._deep_copy(data)
# get fields need to be decrypted
fields = dict[str, BasicProviderConfig]()
for credential in self.config:
fields[credential.name] = credential
for field_name, field in fields.items():
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
if field_name in data:
if len(data[field_name]) > 6:
data[field_name] = (
data[field_name][:2] + "*" * (len(data[field_name]) - 4) + data[field_name][-2:]
)
else:
data[field_name] = "*" * len(data[field_name])
return data
def decrypt(self, data: dict[str, str]) -> dict[str, Any]:
"""
decrypt tool credentials with tenant id
return a deep copy of credentials with decrypted values
"""
cached_credentials = self.provider_config_cache.get()
if cached_credentials:
return cached_credentials
data = self._deep_copy(data)
# get fields need to be decrypted
fields = dict[str, BasicProviderConfig]()
for credential in self.config:
fields[credential.name] = credential
for field_name, field in fields.items():
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
if field_name in data:
try:
# if the value is None or empty string, skip decrypt
if not data[field_name]:
continue
data[field_name] = encrypter.decrypt_token(self.tenant_id, data[field_name])
except Exception:
pass
self.provider_config_cache.set(data)
return data
def create_encrypter(tenant_id: str, config: list[BasicProviderConfig], cache: ProviderConfigCache):
return ProviderConfigEncrypter(tenant_id=tenant_id, config=config, provider_config_cache=cache), cache
def create_generic_encrypter(
tenant_id: str, config: list[BasicProviderConfig], provider_type: str, provider_identity: str
):
cache = GenericProviderCredentialsCache(tenant_id=tenant_id, identity_id=f"{provider_type}.{provider_identity}")
encrypt = ProviderConfigEncrypter(tenant_id=tenant_id, config=config, provider_config_cache=cache)
return encrypt, cache

@ -10,10 +10,12 @@ from sqlalchemy.orm import Mapped, mapped_column
from core.file import helpers as file_helpers from core.file import helpers as file_helpers
from core.helper import encrypter from core.helper import encrypter
from core.helper.provider_cache import NoOpProviderCredentialCache
from core.mcp.types import Tool from core.mcp.types import Tool
from core.tools.entities.common_entities import I18nObject from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_bundle import ApiToolBundle from core.tools.entities.tool_bundle import ApiToolBundle
from core.tools.entities.tool_entities import ApiProviderSchemaType, WorkflowToolParameterConfiguration from core.tools.entities.tool_entities import ApiProviderSchemaType, WorkflowToolParameterConfiguration
from core.tools.utils.encryption import create_encrypter
from models.base import Base from models.base import Base
from .engine import db from .engine import db
@ -327,17 +329,14 @@ class MCPToolProvider(Base):
@property @property
def decrypted_credentials(self) -> dict: def decrypted_credentials(self) -> dict:
from core.tools.mcp_tool.provider import MCPToolProviderController from core.tools.mcp_tool.provider import MCPToolProviderController
from core.tools.utils.configuration import ProviderConfigEncrypter
provider_controller = MCPToolProviderController._from_db(self) provider_controller = MCPToolProviderController._from_db(self)
tool_configuration = ProviderConfigEncrypter( return create_encrypter(
tenant_id=self.tenant_id, tenant_id=self.tenant_id,
config=list(provider_controller.get_credentials_schema()), config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
provider_type=provider_controller.provider_type.value, cache=NoOpProviderCredentialCache(),
provider_identity=provider_controller.provider_id, )[0].decrypt(self.credentials)
)
return tool_configuration.decrypt(self.credentials, use_cache=False)
class ToolModelInvoke(Base): class ToolModelInvoke(Base):

@ -6,7 +6,7 @@ from sqlalchemy.orm import Session
from core.plugin.entities.parameters import PluginParameterOption from core.plugin.entities.parameters import PluginParameterOption
from core.plugin.impl.dynamic_select import DynamicSelectClient from core.plugin.impl.dynamic_select import DynamicSelectClient
from core.tools.tool_manager import ToolManager from core.tools.tool_manager import ToolManager
from core.tools.utils.configuration import ProviderConfigEncrypter from core.tools.utils.encryption import ProviderConfigEncrypter
from extensions.ext_database import db from extensions.ext_database import db
from models.tools import BuiltinToolProvider from models.tools import BuiltinToolProvider

@ -18,7 +18,7 @@ from core.tools.entities.tool_entities import (
) )
from core.tools.tool_label_manager import ToolLabelManager from core.tools.tool_label_manager import ToolLabelManager
from core.tools.tool_manager import ToolManager from core.tools.tool_manager import ToolManager
from core.tools.utils.configuration import ProviderConfigEncrypter, create_generic_encrypter from core.tools.utils.encryption import ProviderConfigEncrypter, create_generic_encrypter
from core.tools.utils.parser import ApiBasedToolSchemaParser from core.tools.utils.parser import ApiBasedToolSchemaParser
from extensions.ext_database import db from extensions.ext_database import db
from models.tools import ApiToolProvider from models.tools import ApiToolProvider

@ -24,7 +24,7 @@ from core.tools.errors import ToolNotFoundError, ToolProviderCredentialValidatio
from core.tools.plugin_tool.provider import PluginToolProviderController from core.tools.plugin_tool.provider import PluginToolProviderController
from core.tools.tool_label_manager import ToolLabelManager from core.tools.tool_label_manager import ToolLabelManager
from core.tools.tool_manager import ToolManager from core.tools.tool_manager import ToolManager
from core.tools.utils.configuration import create_encrypter from core.tools.utils.encryption import create_encrypter
from extensions.ext_database import db from extensions.ext_database import db
from extensions.ext_redis import redis_client from extensions.ext_redis import redis_client
from models.tools import BuiltinToolProvider, ToolOAuthSystemClient, ToolOAuthTenantClient from models.tools import BuiltinToolProvider, ToolOAuthSystemClient, ToolOAuthTenantClient

@ -12,7 +12,7 @@ from core.tools.entities.api_entities import ToolProviderApiEntity
from core.tools.entities.common_entities import I18nObject from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolProviderType from core.tools.entities.tool_entities import ToolProviderType
from core.tools.mcp_tool.provider import MCPToolProviderController from core.tools.mcp_tool.provider import MCPToolProviderController
from core.tools.utils.configuration import ProviderConfigEncrypter from core.tools.utils.encryption import ProviderConfigEncrypter
from extensions.ext_database import db from extensions.ext_database import db
from models.tools import MCPToolProvider from models.tools import MCPToolProvider
from services.tools.tools_transform_service import ToolTransformService from services.tools.tools_transform_service import ToolTransformService

@ -5,8 +5,8 @@ from typing import Any, Optional, Union, cast
from yarl import URL from yarl import URL
from configs import dify_config from configs import dify_config
from core.mcp.types import Tool as MCPTool
from core.helper.provider_cache import ToolProviderCredentialsCache from core.helper.provider_cache import ToolProviderCredentialsCache
from core.mcp.types import Tool as MCPTool
from core.tools.__base.tool import Tool from core.tools.__base.tool import Tool
from core.tools.__base.tool_runtime import ToolRuntime from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.builtin_tool.provider import BuiltinToolProviderController from core.tools.builtin_tool.provider import BuiltinToolProviderController
@ -21,7 +21,7 @@ from core.tools.entities.tool_entities import (
ToolProviderType, ToolProviderType,
) )
from core.tools.plugin_tool.provider import PluginToolProviderController from core.tools.plugin_tool.provider import PluginToolProviderController
from core.tools.utils.configuration import create_encrypter, create_generic_encrypter from core.tools.utils.encryption import create_encrypter, create_generic_encrypter
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
from core.tools.workflow_as_tool.tool import WorkflowTool from core.tools.workflow_as_tool.tool import WorkflowTool
from models.tools import ApiToolProvider, BuiltinToolProvider, MCPToolProvider, WorkflowToolProvider from models.tools import ApiToolProvider, BuiltinToolProvider, MCPToolProvider, WorkflowToolProvider

Loading…
Cancel
Save