Merge branch 'feat/tool-plugin-oauth' into deploy/dev

# Conflicts:
#	api/controllers/console/workspace/tool_providers.py
#	api/core/tools/entities/api_entities.py
#	api/core/tools/tool_manager.py
#	api/core/tools/utils/configuration.py
#	api/services/tools/tools_transform_service.py
pull/22036/head
Harry 11 months ago
commit c160a0e5e3

@ -1,9 +1,13 @@
import io import io
from urllib.parse import urlparse from urllib.parse import urlparse
from flask import make_response, redirect, request, send_file
from flask import redirect, send_file from flask import redirect, send_file
from flask_login import current_user from flask_login import current_user
from flask_restful import Resource, reqparse from flask_restful import (
Resource,
reqparse,
)
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden from werkzeug.exceptions import Forbidden
@ -14,10 +18,19 @@ from core.mcp.auth.auth_flow import auth, handle_callback
from core.mcp.auth.auth_provider import OAuthClientProvider from core.mcp.auth.auth_provider import OAuthClientProvider
from core.mcp.error import MCPAuthError from core.mcp.error import MCPAuthError
from core.mcp.mcp_client import MCPClient from core.mcp.mcp_client import MCPClient
from controllers.console.wraps import (
account_initialization_required,
enterprise_license_required,
setup_required,
)
from core.model_runtime.utils.encoders import jsonable_encoder from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.entities.plugin import ToolProviderID
from core.plugin.impl.oauth import OAuthHandler
from core.tools.entities.tool_entities import CredentialType
from extensions.ext_database import db from extensions.ext_database import db
from libs.helper import alphanumeric, uuid_value from libs.helper import alphanumeric, uuid_value
from libs.login import login_required from libs.login import login_required
from services.plugin.oauth_service import OAuthProxyService
from services.tools.api_tools_manage_service import ApiToolManageService from services.tools.api_tools_manage_service import ApiToolManageService
from services.tools.builtin_tools_manage_service import BuiltinToolManageService from services.tools.builtin_tools_manage_service import BuiltinToolManageService
from services.tools.mcp_tools_mange_service import MCPToolManageService from services.tools.mcp_tools_mange_service import MCPToolManageService
@ -89,7 +102,7 @@ class ToolBuiltinProviderInfoApi(Resource):
user_id = user.id user_id = user.id
tenant_id = user.current_tenant_id tenant_id = user.current_tenant_id
return jsonable_encoder(BuiltinToolManageService.get_builtin_tool_provider_info(user_id, tenant_id, provider)) return jsonable_encoder(BuiltinToolManageService.get_builtin_tool_provider_info(tenant_id, provider))
class ToolBuiltinProviderDeleteApi(Resource): class ToolBuiltinProviderDeleteApi(Resource):
@ -98,17 +111,47 @@ class ToolBuiltinProviderDeleteApi(Resource):
@account_initialization_required @account_initialization_required
def post(self, provider): def post(self, provider):
user = current_user user = current_user
if not user.is_admin_or_owner: if not user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
user_id = user.id
tenant_id = user.current_tenant_id tenant_id = user.current_tenant_id
req = reqparse.RequestParser()
req.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
args = req.parse_args()
return BuiltinToolManageService.delete_builtin_tool_provider( return BuiltinToolManageService.delete_builtin_tool_provider(
user_id,
tenant_id, tenant_id,
provider, provider,
args["credential_id"],
)
class ToolBuiltinProviderAddApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, provider):
user = current_user
user_id = user.id
tenant_id = user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
parser.add_argument("name", type=str, required=False, nullable=False, location="json")
parser.add_argument("type", type=str, required=True, nullable=False, location="json")
args = parser.parse_args()
if args["type"] not in CredentialType.values():
raise ValueError(f"Invalid credential type: {args['type']}")
return BuiltinToolManageService.add_builtin_tool_provider(
user_id=user_id,
tenant_id=tenant_id,
provider=provider,
credentials=args["credentials"],
name=args["name"],
api_type=CredentialType.of(args["type"]),
) )
@ -126,17 +169,20 @@ class ToolBuiltinProviderUpdateApi(Resource):
tenant_id = user.current_tenant_id tenant_id = user.current_tenant_id
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
parser.add_argument("name", type=str, required=True, nullable=False, location="json")
args = parser.parse_args() args = parser.parse_args()
with Session(db.engine) as session: with Session(db.engine) as session:
result = BuiltinToolManageService.update_builtin_tool_provider( result = BuiltinToolManageService.update_builtin_tool_provider(
session=session,
user_id=user_id, user_id=user_id,
tenant_id=tenant_id, tenant_id=tenant_id,
provider_name=provider, provider=provider,
credentials=args["credentials"], credentials=args["credentials"],
credential_id=args["credential_id"],
name=args["name"],
) )
session.commit() session.commit()
return result return result
@ -149,9 +195,11 @@ class ToolBuiltinProviderGetCredentialsApi(Resource):
def get(self, provider): def get(self, provider):
tenant_id = current_user.current_tenant_id tenant_id = current_user.current_tenant_id
return BuiltinToolManageService.get_builtin_tool_provider_credentials( return jsonable_encoder(
tenant_id=tenant_id, BuiltinToolManageService.get_builtin_tool_provider_credentials(
provider_name=provider, tenant_id=tenant_id,
provider_name=provider,
)
) )
@ -344,12 +392,15 @@ class ToolBuiltinProviderCredentialsSchemaApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, provider): def get(self, provider, credential_type):
user = current_user user = current_user
tenant_id = user.current_tenant_id tenant_id = user.current_tenant_id
return BuiltinToolManageService.list_builtin_provider_credentials_schema(provider, tenant_id) return jsonable_encoder(
BuiltinToolManageService.list_builtin_provider_credentials_schema(
provider, CredentialType.of(credential_type), tenant_id
)
)
class ToolApiProviderSchemaApi(Resource): class ToolApiProviderSchemaApi(Resource):
@ -586,15 +637,12 @@ class ToolApiListApi(Resource):
@account_initialization_required @account_initialization_required
def get(self): def get(self):
user = current_user user = current_user
user_id = user.id
tenant_id = user.current_tenant_id tenant_id = user.current_tenant_id
return jsonable_encoder( return jsonable_encoder(
[ [
provider.to_dict() provider.to_dict()
for provider in ApiToolManageService.list_api_tools( for provider in ApiToolManageService.list_api_tools(
user_id,
tenant_id, tenant_id,
) )
] ]
@ -631,6 +679,178 @@ class ToolLabelsApi(Resource):
return jsonable_encoder(ToolLabelsService.list_tool_labels()) return jsonable_encoder(ToolLabelsService.list_tool_labels())
class ToolPluginOAuthApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider):
tool_provider = ToolProviderID(provider)
plugin_id = tool_provider.plugin_id
provider_name = tool_provider.provider_name
# todo check permission
user = current_user
if not user.is_admin_or_owner:
raise Forbidden()
tenant_id = user.current_tenant_id
oauth_client_params = BuiltinToolManageService.get_oauth_client(
tenant_id=tenant_id,
provider=provider
)
if oauth_client_params is None:
raise Forbidden("no oauth available client config found for this tool provider")
oauth_handler = OAuthHandler()
context_id = OAuthProxyService.create_proxy_context(
user_id=current_user.id, tenant_id=tenant_id, plugin_id=plugin_id, provider=provider_name
)
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider}/tool/callback"
authorization_url_response = oauth_handler.get_authorization_url(
tenant_id=tenant_id,
user_id=user.id,
plugin_id=plugin_id,
provider=provider_name,
redirect_uri=redirect_uri,
system_credentials=oauth_client_params,
)
response = make_response(jsonable_encoder(authorization_url_response))
response.set_cookie(
"context_id",
context_id,
httponly=True,
samesite="Lax",
max_age=OAuthProxyService.__MAX_AGE__,
)
return response
class ToolOAuthCallback(Resource):
@setup_required
def get(self, provider):
context_id = request.cookies.get("context_id")
if not context_id:
raise Forbidden("context_id not found")
context = OAuthProxyService.use_proxy_context(context_id)
if context is None:
raise Forbidden("Invalid context_id")
tool_provider = ToolProviderID(provider)
plugin_id = tool_provider.plugin_id
provider_name = tool_provider.provider_name
user_id, tenant_id = context.get("user_id"), context.get("tenant_id")
oauth_handler = OAuthHandler()
oauth_client_params = BuiltinToolManageService.get_oauth_client(tenant_id, provider)
if oauth_client_params is None:
raise Forbidden("no oauth available client config found for this tool provider")
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider}/tool/callback"
credentials = oauth_handler.get_credentials(
tenant_id=tenant_id,
user_id=user_id,
plugin_id=plugin_id,
provider=provider_name,
redirect_uri=redirect_uri,
system_credentials=oauth_client_params,
request=request,
).credentials
if not credentials:
raise Exception("the plugin credentials failed")
# add credentials to database
BuiltinToolManageService.add_builtin_tool_provider(
user_id=user_id,
tenant_id=tenant_id,
provider=provider,
credentials=dict(credentials),
api_type=CredentialType.OAUTH2,
)
return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth/plugin/{provider}/tool/success")
class ToolBuiltinProviderSetDefaultApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, provider):
parser = reqparse.RequestParser()
parser.add_argument("id", type=str, required=True, nullable=False, location="json")
args = parser.parse_args()
return BuiltinToolManageService.set_default_provider(
tenant_id=current_user.current_tenant_id, user_id=current_user.id, provider=provider, id=args["id"]
)
class ToolOAuthCustomClient(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, provider):
parser = reqparse.RequestParser()
parser.add_argument("client_params", type=dict, required=False, nullable=True, location="json")
parser.add_argument("enable_oauth_custom_client", type=bool, required=False, nullable=True, location="json")
args = parser.parse_args()
user = current_user
if not user.is_admin_or_owner:
raise Forbidden()
return BuiltinToolManageService.save_custom_oauth_client_params(
tenant_id=user.current_tenant_id,
provider=provider,
client_params=args.get("client_params", {}),
enable_oauth_custom_client=args.get("enable_oauth_custom_client", True),
)
@setup_required
@login_required
@account_initialization_required
def get(self, provider):
return jsonable_encoder(
BuiltinToolManageService.get_custom_oauth_client_params(
tenant_id=current_user.current_tenant_id, provider=provider
)
)
class ToolBuiltinProviderGetOauthClientSchemaApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider):
return jsonable_encoder(
BuiltinToolManageService.get_builtin_tool_provider_oauth_client_schema(
tenant_id=current_user.current_tenant_id, provider_name=provider
)
)
class ToolBuiltinProviderGetCredentialInfoApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider):
tenant_id = current_user.current_tenant_id
return jsonable_encoder(
BuiltinToolManageService.get_builtin_tool_provider_credential_info(
tenant_id=tenant_id,
provider=provider,
)
)
# tool oauth
api.add_resource(ToolPluginOAuthApi, "/oauth/plugin/<path:provider>/tool/authorization-url")
api.add_resource(ToolOAuthCallback, "/oauth/plugin/<path:provider>/tool/callback")
api.add_resource(ToolOAuthCustomClient, "/workspaces/current/tool-provider/builtin/<path:provider>/oauth/custom-client")
class ToolProviderMCPApi(Resource): class ToolProviderMCPApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -791,14 +1011,25 @@ api.add_resource(ToolProviderListApi, "/workspaces/current/tool-providers")
# builtin tool provider # builtin tool provider
api.add_resource(ToolBuiltinProviderListToolsApi, "/workspaces/current/tool-provider/builtin/<path:provider>/tools") api.add_resource(ToolBuiltinProviderListToolsApi, "/workspaces/current/tool-provider/builtin/<path:provider>/tools")
api.add_resource(ToolBuiltinProviderInfoApi, "/workspaces/current/tool-provider/builtin/<path:provider>/info") api.add_resource(ToolBuiltinProviderInfoApi, "/workspaces/current/tool-provider/builtin/<path:provider>/info")
api.add_resource(ToolBuiltinProviderAddApi, "/workspaces/current/tool-provider/builtin/<path:provider>/add")
api.add_resource(ToolBuiltinProviderDeleteApi, "/workspaces/current/tool-provider/builtin/<path:provider>/delete") api.add_resource(ToolBuiltinProviderDeleteApi, "/workspaces/current/tool-provider/builtin/<path:provider>/delete")
api.add_resource(ToolBuiltinProviderUpdateApi, "/workspaces/current/tool-provider/builtin/<path:provider>/update") api.add_resource(ToolBuiltinProviderUpdateApi, "/workspaces/current/tool-provider/builtin/<path:provider>/update")
api.add_resource(
ToolBuiltinProviderSetDefaultApi, "/workspaces/current/tool-provider/builtin/<path:provider>/default-credential"
)
api.add_resource(
ToolBuiltinProviderGetCredentialInfoApi, "/workspaces/current/tool-provider/builtin/<path:provider>/credential/info"
)
api.add_resource( api.add_resource(
ToolBuiltinProviderGetCredentialsApi, "/workspaces/current/tool-provider/builtin/<path:provider>/credentials" ToolBuiltinProviderGetCredentialsApi, "/workspaces/current/tool-provider/builtin/<path:provider>/credentials"
) )
api.add_resource( api.add_resource(
ToolBuiltinProviderCredentialsSchemaApi, ToolBuiltinProviderCredentialsSchemaApi,
"/workspaces/current/tool-provider/builtin/<path:provider>/credentials_schema", "/workspaces/current/tool-provider/builtin/<path:provider>/credential/schema/<path:credential_type>",
)
api.add_resource(
ToolBuiltinProviderGetOauthClientSchemaApi,
"/workspaces/current/tool-provider/builtin/<path:provider>/oauth/client-schema",
) )
api.add_resource(ToolBuiltinProviderIconApi, "/workspaces/current/tool-provider/builtin/<path:provider>/icon") api.add_resource(ToolBuiltinProviderIconApi, "/workspaces/current/tool-provider/builtin/<path:provider>/icon")

@ -0,0 +1,77 @@
import json
from abc import ABC, abstractmethod
from json import JSONDecodeError
from typing import Any, Optional
from extensions.ext_redis import redis_client
class ProviderCredentialsCache(ABC):
"""Base class for provider credentials cache"""
def __init__(self, **kwargs):
self.cache_key = self._generate_cache_key(**kwargs)
@abstractmethod
def _generate_cache_key(self, **kwargs) -> str:
"""Generate cache key based on subclass implementation"""
pass
def get(self) -> Optional[dict]:
"""Get cached provider credentials"""
cached_credentials = redis_client.get(self.cache_key)
if cached_credentials:
try:
cached_credentials = cached_credentials.decode("utf-8")
return dict(json.loads(cached_credentials))
except JSONDecodeError:
return None
return None
def set(self, config: dict[str, Any]) -> None:
"""Cache provider credentials"""
redis_client.setex(self.cache_key, 86400, json.dumps(config))
def delete(self) -> None:
"""Delete cached provider credentials"""
redis_client.delete(self.cache_key)
class GenericProviderCredentialsCache(ProviderCredentialsCache):
"""Cache for generic provider credentials"""
def __init__(self, tenant_id: str, identity_id: str):
super().__init__(tenant_id=tenant_id, identity_id=identity_id)
def _generate_cache_key(self, **kwargs) -> str:
tenant_id = kwargs["tenant_id"]
identity_id = kwargs["identity_id"]
return f"generic_provider_credentials:tenant_id:{tenant_id}:id:{identity_id}"
class ToolProviderCredentialsCache(ProviderCredentialsCache):
"""Cache for tool provider credentials"""
def __init__(self, tenant_id: str, provider: str, credential_id: str):
super().__init__(tenant_id=tenant_id, provider=provider, credential_id=credential_id)
def _generate_cache_key(self, **kwargs) -> str:
tenant_id = kwargs["tenant_id"]
provider = kwargs["provider"]
credential_id = kwargs["credential_id"]
return f"provider_credentials:tenant_id:{tenant_id}:provider:{provider}:credential_id:{credential_id}"
class NoOpProviderCredentialCache:
"""No-op provider credential cache"""
def get(self) -> Optional[dict]:
"""Get cached provider credentials"""
return None
def set(self, config: dict[str, Any]) -> None:
"""Cache provider credentials"""
pass
def delete(self) -> None:
"""Delete cached provider credentials"""
pass

@ -1,51 +0,0 @@
import json
from enum import Enum
from json import JSONDecodeError
from typing import Optional
from extensions.ext_redis import redis_client
class ToolProviderCredentialsCacheType(Enum):
PROVIDER = "tool_provider"
ENDPOINT = "endpoint"
class ToolProviderCredentialsCache:
def __init__(self, tenant_id: str, identity_id: str, cache_type: ToolProviderCredentialsCacheType):
self.cache_key = f"{cache_type.value}_credentials:tenant_id:{tenant_id}:id:{identity_id}"
def get(self) -> Optional[dict]:
"""
Get cached model provider credentials.
:return:
"""
cached_provider_credentials = redis_client.get(self.cache_key)
if cached_provider_credentials:
try:
cached_provider_credentials = cached_provider_credentials.decode("utf-8")
cached_provider_credentials = json.loads(cached_provider_credentials)
except JSONDecodeError:
return None
return dict(cached_provider_credentials)
else:
return None
def set(self, credentials: dict) -> None:
"""
Cache model provider credentials.
:param credentials: provider credentials
:return:
"""
redis_client.setex(self.cache_key, 86400, json.dumps(credentials))
def delete(self) -> None:
"""
Delete cached model provider credentials.
:return:
"""
redis_client.delete(self.cache_key)

@ -1,12 +1,12 @@
from core.plugin.entities.request import RequestInvokeEncrypt from core.plugin.entities.request import RequestInvokeEncrypt
from core.tools.utils.configuration import ProviderConfigEncrypter from core.tools.utils.configuration import create_generic_encrypter
from models.account import Tenant from models.account import Tenant
class PluginEncrypter: class PluginEncrypter:
@classmethod @classmethod
def invoke_encrypt(cls, tenant: Tenant, payload: RequestInvokeEncrypt) -> dict: def invoke_encrypt(cls, tenant: Tenant, payload: RequestInvokeEncrypt) -> dict:
encrypter = ProviderConfigEncrypter( encrypter, cache = create_generic_encrypter(
tenant_id=tenant.id, tenant_id=tenant.id,
config=payload.config, config=payload.config,
provider_type=payload.namespace, provider_type=payload.namespace,
@ -22,7 +22,7 @@ class PluginEncrypter:
"data": encrypter.decrypt(payload.data), "data": encrypter.decrypt(payload.data),
} }
elif payload.opt == "clear": elif payload.opt == "clear":
encrypter.delete_tool_credentials_cache() cache.delete()
return { return {
"data": {}, "data": {},
} }

@ -15,6 +15,7 @@ class OAuthHandler(BasePluginClient):
user_id: str, user_id: str,
plugin_id: str, plugin_id: str,
provider: str, provider: str,
redirect_uri: str,
system_credentials: Mapping[str, Any], system_credentials: Mapping[str, Any],
) -> PluginOAuthAuthorizationUrlResponse: ) -> PluginOAuthAuthorizationUrlResponse:
response = self._request_with_plugin_daemon_response_stream( response = self._request_with_plugin_daemon_response_stream(
@ -25,6 +26,7 @@ class OAuthHandler(BasePluginClient):
"user_id": user_id, "user_id": user_id,
"data": { "data": {
"provider": provider, "provider": provider,
"redirect_uri": redirect_uri,
"system_credentials": system_credentials, "system_credentials": system_credentials,
}, },
}, },
@ -43,6 +45,7 @@ class OAuthHandler(BasePluginClient):
user_id: str, user_id: str,
plugin_id: str, plugin_id: str,
provider: str, provider: str,
redirect_uri: str,
system_credentials: Mapping[str, Any], system_credentials: Mapping[str, Any],
request: Request, request: Request,
) -> PluginOAuthCredentialsResponse: ) -> PluginOAuthCredentialsResponse:
@ -61,6 +64,7 @@ class OAuthHandler(BasePluginClient):
"user_id": user_id, "user_id": user_id,
"data": { "data": {
"provider": provider, "provider": provider,
"redirect_uri": redirect_uri,
"system_credentials": system_credentials, "system_credentials": system_credentials,
# for json serialization # for json serialization
"raw_http_request": binascii.hexlify(raw_request_bytes).decode(), "raw_http_request": binascii.hexlify(raw_request_bytes).decode(),

@ -6,7 +6,7 @@ from pydantic import BaseModel
from core.plugin.entities.plugin import GenericProviderID, ToolProviderID from core.plugin.entities.plugin import GenericProviderID, ToolProviderID
from core.plugin.entities.plugin_daemon import PluginBasicBooleanResponse, PluginToolProviderEntity from core.plugin.entities.plugin_daemon import PluginBasicBooleanResponse, PluginToolProviderEntity
from core.plugin.impl.base import BasePluginClient from core.plugin.impl.base import BasePluginClient
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter from core.tools.entities.tool_entities import CredentialType, ToolInvokeMessage, ToolParameter
class PluginToolManager(BasePluginClient): class PluginToolManager(BasePluginClient):
@ -78,6 +78,7 @@ class PluginToolManager(BasePluginClient):
tool_provider: str, tool_provider: str,
tool_name: str, tool_name: str,
credentials: dict[str, Any], credentials: dict[str, Any],
credential_type: CredentialType,
tool_parameters: dict[str, Any], tool_parameters: dict[str, Any],
conversation_id: Optional[str] = None, conversation_id: Optional[str] = None,
app_id: Optional[str] = None, app_id: Optional[str] = None,
@ -102,6 +103,7 @@ class PluginToolManager(BasePluginClient):
"provider": tool_provider_id.provider_name, "provider": tool_provider_id.provider_name,
"tool": tool_name, "tool": tool_name,
"credentials": credentials, "credentials": credentials,
"credential_type": credential_type,
"tool_parameters": tool_parameters, "tool_parameters": tool_parameters,
}, },
}, },

@ -1010,6 +1010,9 @@ class DatasetRetrieval:
def _process_metadata_filter_func( def _process_metadata_filter_func(
self, sequence: int, condition: str, metadata_name: str, value: Optional[Any], filters: list self, sequence: int, condition: str, metadata_name: str, value: Optional[Any], filters: list
): ):
if value is None:
return
key = f"{metadata_name}_{sequence}" key = f"{metadata_name}_{sequence}"
key_value = f"{metadata_name}_{sequence}_value" key_value = f"{metadata_name}_{sequence}_value"
match condition: match condition:

@ -4,7 +4,7 @@ from openai import BaseModel
from pydantic import Field from pydantic import Field
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from core.tools.entities.tool_entities import ToolInvokeFrom from core.tools.entities.tool_entities import CredentialType, ToolInvokeFrom
class ToolRuntime(BaseModel): class ToolRuntime(BaseModel):
@ -17,6 +17,7 @@ class ToolRuntime(BaseModel):
invoke_from: Optional[InvokeFrom] = None invoke_from: Optional[InvokeFrom] = None
tool_invoke_from: Optional[ToolInvokeFrom] = None tool_invoke_from: Optional[ToolInvokeFrom] = None
credentials: dict[str, Any] = Field(default_factory=dict) credentials: dict[str, Any] = Field(default_factory=dict)
credential_type: Optional[CredentialType] = CredentialType.API_KEY
runtime_parameters: dict[str, Any] = Field(default_factory=dict) runtime_parameters: dict[str, Any] = Field(default_factory=dict)

@ -7,7 +7,13 @@ from core.helper.module_import_helper import load_single_subclass_from_source
from core.tools.__base.tool_provider import ToolProviderController from core.tools.__base.tool_provider import ToolProviderController
from core.tools.__base.tool_runtime import ToolRuntime from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.builtin_tool.tool import BuiltinTool from core.tools.builtin_tool.tool import BuiltinTool
from core.tools.entities.tool_entities import ToolEntity, ToolProviderEntity, ToolProviderType from core.tools.entities.tool_entities import (
CredentialType,
OAuthSchema,
ToolEntity,
ToolProviderEntity,
ToolProviderType,
)
from core.tools.entities.values import ToolLabelEnum, default_tool_label_dict from core.tools.entities.values import ToolLabelEnum, default_tool_label_dict
from core.tools.errors import ( from core.tools.errors import (
ToolProviderNotFoundError, ToolProviderNotFoundError,
@ -39,10 +45,18 @@ class BuiltinToolProviderController(ToolProviderController):
credential_dict = provider_yaml.get("credentials_for_provider", {}).get(credential, {}) credential_dict = provider_yaml.get("credentials_for_provider", {}).get(credential, {})
credentials_schema.append(credential_dict) credentials_schema.append(credential_dict)
oauth_schema = None
if provider_yaml.get("oauth_schema", None) is not None:
oauth_schema = OAuthSchema(
client_schema=provider_yaml.get("oauth_schema", {}).get("client_schema", []),
credentials_schema=provider_yaml.get("oauth_schema", {}).get("credentials_schema", []),
)
super().__init__( super().__init__(
entity=ToolProviderEntity( entity=ToolProviderEntity(
identity=provider_yaml["identity"], identity=provider_yaml["identity"],
credentials_schema=credentials_schema, credentials_schema=credentials_schema,
oauth_schema=oauth_schema,
), ),
) )
@ -97,10 +111,39 @@ class BuiltinToolProviderController(ToolProviderController):
:return: the credentials schema :return: the credentials schema
""" """
if not self.entity.credentials_schema: return self.get_credentials_schema_by_type(CredentialType.API_KEY.value)
return []
def get_credentials_schema_by_type(self, credential_type: str) -> list[ProviderConfig]:
"""
returns the credentials schema of the provider
:param credential_type: the type of the credential
:return: the credentials schema of the provider
"""
if credential_type == CredentialType.OAUTH2.value:
return self.entity.oauth_schema.credentials_schema.copy() if self.entity.oauth_schema else []
if credential_type == CredentialType.API_KEY.value:
return self.entity.credentials_schema.copy() if self.entity.credentials_schema else []
raise ValueError(f"Invalid credential type: {credential_type}")
def get_oauth_client_schema(self) -> list[ProviderConfig]:
"""
returns the oauth client schema of the provider
return self.entity.credentials_schema.copy() :return: the oauth client schema
"""
return self.entity.oauth_schema.client_schema.copy() if self.entity.oauth_schema else []
def get_supported_credential_types(self) -> list[str]:
"""
returns the credential support type of the provider
"""
types = []
if self.entity.credentials_schema is not None:
types.append(CredentialType.API_KEY.value)
if self.entity.oauth_schema is not None:
types.append(CredentialType.OAUTH2.value)
return types
def get_tools(self) -> list[BuiltinTool]: def get_tools(self) -> list[BuiltinTool]:
""" """
@ -123,7 +166,11 @@ class BuiltinToolProviderController(ToolProviderController):
:return: whether the provider needs credentials :return: whether the provider needs credentials
""" """
return self.entity.credentials_schema is not None and len(self.entity.credentials_schema) != 0 return (
self.entity.credentials_schema is not None
and len(self.entity.credentials_schema) != 0
or (self.entity.oauth_schema is not None and len(self.entity.oauth_schema.credentials_schema) != 0)
)
@property @property
def provider_type(self) -> ToolProviderType: def provider_type(self) -> ToolProviderType:

@ -6,7 +6,7 @@ from pydantic import BaseModel, Field, field_validator
from core.model_runtime.utils.encoders import jsonable_encoder from core.model_runtime.utils.encoders import jsonable_encoder
from core.tools.__base.tool import ToolParameter from core.tools.__base.tool import ToolParameter
from core.tools.entities.common_entities import I18nObject from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolProviderType from core.tools.entities.tool_entities import CredentialType, ToolProviderType
class ToolApiEntity(BaseModel): class ToolApiEntity(BaseModel):
@ -85,3 +85,22 @@ class ToolProviderApiEntity(BaseModel):
def optional_field(self, key: str, value: Any) -> dict: def optional_field(self, key: str, value: Any) -> dict:
"""Return dict with key-value if value is truthy, empty dict otherwise.""" """Return dict with key-value if value is truthy, empty dict otherwise."""
return {key: value} if value else {} return {key: value} if value else {}
class ToolProviderCredentialApiEntity(BaseModel):
id: str = Field(description="The unique id of the credential")
name: str = Field(description="The name of the credential")
provider: str = Field(description="The provider of the credential")
credential_type: CredentialType = Field(description="The type of the credential")
is_default: bool = Field(
default=False, description="Whether the credential is the default credential for the provider in the workspace"
)
credentials: dict = Field(description="The credentials of the provider")
class ToolProviderCredentialInfoApiEntity(BaseModel):
supported_credential_types: list[str] = Field(description="The supported credential types of the provider")
is_oauth_custom_client_enabled: bool = Field(
default=False, description="Whether the OAuth custom client is enabled for the provider"
)
credentials: list[ToolProviderCredentialApiEntity] = Field(description="The credentials of the provider")

@ -353,10 +353,18 @@ class ToolEntity(BaseModel):
return v or [] return v or []
class OAuthSchema(BaseModel):
client_schema: list[ProviderConfig] = Field(default_factory=list, description="The schema of the OAuth client")
credentials_schema: list[ProviderConfig] = Field(
default_factory=list, description="The schema of the OAuth credentials"
)
class ToolProviderEntity(BaseModel): class ToolProviderEntity(BaseModel):
identity: ToolProviderIdentity identity: ToolProviderIdentity
plugin_id: Optional[str] = None plugin_id: Optional[str] = None
credentials_schema: list[ProviderConfig] = Field(default_factory=list) credentials_schema: list[ProviderConfig] = Field(default_factory=list)
oauth_schema: Optional[OAuthSchema] = None
class ToolProviderEntityWithPlugin(ToolProviderEntity): class ToolProviderEntityWithPlugin(ToolProviderEntity):
@ -443,3 +451,36 @@ class ToolSelector(BaseModel):
def to_plugin_parameter(self) -> dict[str, Any]: def to_plugin_parameter(self) -> dict[str, Any]:
return self.model_dump() return self.model_dump()
class CredentialType(enum.StrEnum):
API_KEY = "api-key"
OAUTH2 = "oauth2"
def get_name(self):
if self == CredentialType.API_KEY:
return "API KEY"
elif self == CredentialType.OAUTH2:
return "AUTH"
else:
return self.value.replace("-", " ").upper()
def is_editable(self):
return self == CredentialType.API_KEY
def is_validate_allowed(self):
return self == CredentialType.API_KEY
@classmethod
def values(cls):
return [item.value for item in cls]
@classmethod
def of(cls, credential_type: str) -> "CredentialType":
type_name = credential_type.lower()
if type_name == "api-key":
return cls.API_KEY
elif type_name == "oauth2":
return cls.OAUTH2
else:
raise ValueError(f"Invalid credential type: {credential_type}")

@ -44,6 +44,7 @@ class PluginTool(Tool):
tool_provider=self.entity.identity.provider, tool_provider=self.entity.identity.provider,
tool_name=self.entity.identity.name, tool_name=self.entity.identity.name,
credentials=self.runtime.credentials, credentials=self.runtime.credentials,
credential_type=self.runtime.credential_type,
tool_parameters=tool_parameters, tool_parameters=tool_parameters,
conversation_id=conversation_id, conversation_id=conversation_id,
app_id=app_id, app_id=app_id,

@ -9,6 +9,7 @@ from typing import TYPE_CHECKING, Any, Literal, Optional, Union, cast
from yarl import URL from yarl import URL
import contexts import contexts
from core.helper.provider_cache import ToolProviderCredentialsCache
from core.plugin.entities.plugin import ToolProviderID from core.plugin.entities.plugin import ToolProviderID
from core.plugin.impl.tool import PluginToolManager from core.plugin.impl.tool import PluginToolManager
from core.tools.__base.tool_provider import ToolProviderController from core.tools.__base.tool_provider import ToolProviderController
@ -24,7 +25,6 @@ from services.tools.mcp_tools_mange_service import MCPToolManageService
if TYPE_CHECKING: if TYPE_CHECKING:
from core.workflow.nodes.tool.entities import ToolEntity from core.workflow.nodes.tool.entities import ToolEntity
from configs import dify_config from configs import dify_config
from core.agent.entities import AgentToolEntity from core.agent.entities import AgentToolEntity
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
@ -41,6 +41,7 @@ from core.tools.entities.api_entities import ToolProviderApiEntity, ToolProvider
from core.tools.entities.common_entities import I18nObject from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ( from core.tools.entities.tool_entities import (
ApiProviderAuthType, ApiProviderAuthType,
CredentialType,
ToolInvokeFrom, ToolInvokeFrom,
ToolParameter, ToolParameter,
ToolProviderType, ToolProviderType,
@ -50,6 +51,8 @@ from core.tools.tool_label_manager import ToolLabelManager
from core.tools.utils.configuration import ( from core.tools.utils.configuration import (
ProviderConfigEncrypter, ProviderConfigEncrypter,
ToolParameterConfigurationManager, ToolParameterConfigurationManager,
create_encrypter,
create_generic_encrypter,
) )
from core.tools.workflow_as_tool.tool import WorkflowTool from core.tools.workflow_as_tool.tool import WorkflowTool
from extensions.ext_database import db from extensions.ext_database import db
@ -68,8 +71,11 @@ class ToolManager:
@classmethod @classmethod
def get_hardcoded_provider(cls, provider: str) -> BuiltinToolProviderController: def get_hardcoded_provider(cls, provider: str) -> BuiltinToolProviderController:
""" """
get the hardcoded provider get the hardcoded provider
""" """
if len(cls._hardcoded_providers) == 0: if len(cls._hardcoded_providers) == 0:
# init the builtin providers # init the builtin providers
cls.load_hardcoded_providers_cache() cls.load_hardcoded_providers_cache()
@ -113,7 +119,12 @@ class ToolManager:
contexts.plugin_tool_providers.set({}) contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(Lock()) contexts.plugin_tool_providers_lock.set(Lock())
plugin_tool_providers = contexts.plugin_tool_providers.get()
if provider in plugin_tool_providers:
return plugin_tool_providers[provider]
with contexts.plugin_tool_providers_lock.get(): with contexts.plugin_tool_providers_lock.get():
# double check
plugin_tool_providers = contexts.plugin_tool_providers.get() plugin_tool_providers = contexts.plugin_tool_providers.get()
if provider in plugin_tool_providers: if provider in plugin_tool_providers:
return plugin_tool_providers[provider] return plugin_tool_providers[provider]
@ -131,25 +142,7 @@ class ToolManager:
) )
plugin_tool_providers[provider] = controller plugin_tool_providers[provider] = controller
return controller
return controller
@classmethod
def get_builtin_tool(cls, provider: str, tool_name: str, tenant_id: str) -> BuiltinTool | PluginTool | None:
"""
get the builtin tool
:param provider: the name of the provider
:param tool_name: the name of the tool
:param tenant_id: the id of the tenant
:return: the provider, the tool
"""
provider_controller = cls.get_builtin_provider(provider, tenant_id)
tool = provider_controller.get_tool(tool_name)
if tool is None:
raise ToolNotFoundError(f"tool {tool_name} not found")
return tool
@classmethod @classmethod
def get_tool_runtime( def get_tool_runtime(
@ -160,6 +153,7 @@ class ToolManager:
tenant_id: str, tenant_id: str,
invoke_from: InvokeFrom = InvokeFrom.DEBUGGER, invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT, tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT,
credential_id: Optional[str] = None,
) -> Union[BuiltinTool, PluginTool, ApiTool, WorkflowTool, MCPTool]: ) -> Union[BuiltinTool, PluginTool, ApiTool, WorkflowTool, MCPTool]:
""" """
get the tool runtime get the tool runtime
@ -170,6 +164,7 @@ class ToolManager:
:param tenant_id: the tenant id :param tenant_id: the tenant id
:param invoke_from: invoke from :param invoke_from: invoke from
:param tool_invoke_from: the tool invoke from :param tool_invoke_from: the tool invoke from
:param credential_id: the credential id
:return: the tool :return: the tool
""" """
@ -197,45 +192,59 @@ class ToolManager:
if isinstance(provider_controller, PluginToolProviderController): if isinstance(provider_controller, PluginToolProviderController):
provider_id_entity = ToolProviderID(provider_id) provider_id_entity = ToolProviderID(provider_id)
# get credentials # get credentials
builtin_provider: BuiltinToolProvider | None = ( if credential_id:
db.session.query(BuiltinToolProvider) builtin_provider = (
.filter( db.session.query(BuiltinToolProvider)
BuiltinToolProvider.tenant_id == tenant_id, .filter(
(BuiltinToolProvider.provider == str(provider_id_entity)) BuiltinToolProvider.tenant_id == tenant_id,
| (BuiltinToolProvider.provider == provider_id_entity.provider_name), BuiltinToolProvider.id == credential_id,
)
.first()
)
if builtin_provider is None:
raise ToolProviderNotFoundError(f"builtin provider {credential_id} not found")
else:
builtin_provider = (
db.session.query(BuiltinToolProvider)
.filter(
BuiltinToolProvider.tenant_id == tenant_id,
(BuiltinToolProvider.provider == str(provider_id_entity))
| (BuiltinToolProvider.provider == provider_id_entity.provider_name),
)
.order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc())
.first()
) )
.first()
)
if builtin_provider is None: if builtin_provider is None:
raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found") raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found")
else: else:
builtin_provider = ( builtin_provider = (
db.session.query(BuiltinToolProvider) db.session.query(BuiltinToolProvider)
.filter(BuiltinToolProvider.tenant_id == tenant_id, (BuiltinToolProvider.provider == provider_id)) .filter(BuiltinToolProvider.tenant_id == tenant_id, (BuiltinToolProvider.provider == provider_id))
.order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc())
.first() .first()
) )
if builtin_provider is None: if builtin_provider is None:
raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found") raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found")
# decrypt the credentials encrypter, _ = create_encrypter(
credentials = builtin_provider.credentials
tool_configuration = ProviderConfigEncrypter(
tenant_id=tenant_id, tenant_id=tenant_id,
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()], config=[
provider_type=provider_controller.provider_type.value, x.to_basic_provider_config()
provider_identity=provider_controller.entity.identity.name, for x in provider_controller.get_credentials_schema_by_type(builtin_provider.credential_type)
],
cache=ToolProviderCredentialsCache(
tenant_id=tenant_id, provider=provider_id, credential_id=builtin_provider.id
),
) )
decrypted_credentials = tool_configuration.decrypt(credentials)
return cast( return cast(
BuiltinTool, BuiltinTool,
builtin_tool.fork_tool_runtime( builtin_tool.fork_tool_runtime(
runtime=ToolRuntime( runtime=ToolRuntime(
tenant_id=tenant_id, tenant_id=tenant_id,
credentials=decrypted_credentials, credentials=encrypter.decrypt(builtin_provider.credentials),
credential_type=CredentialType.of(builtin_provider.credential_type),
runtime_parameters={}, runtime_parameters={},
invoke_from=invoke_from, invoke_from=invoke_from,
tool_invoke_from=tool_invoke_from, tool_invoke_from=tool_invoke_from,
@ -245,22 +254,18 @@ class ToolManager:
elif provider_type == ToolProviderType.API: elif provider_type == ToolProviderType.API:
api_provider, credentials = cls.get_api_provider_controller(tenant_id, provider_id) api_provider, credentials = cls.get_api_provider_controller(tenant_id, provider_id)
encrypter, _ = create_generic_encrypter(
# decrypt the credentials
tool_configuration = ProviderConfigEncrypter(
tenant_id=tenant_id, tenant_id=tenant_id,
config=[x.to_basic_provider_config() for x in api_provider.get_credentials_schema()], config=[x.to_basic_provider_config() for x in api_provider.get_credentials_schema()],
provider_type=api_provider.provider_type.value, provider_type=api_provider.provider_type.value,
provider_identity=api_provider.entity.identity.name, provider_identity=api_provider.entity.identity.name,
) )
decrypted_credentials = tool_configuration.decrypt(credentials)
return cast( return cast(
ApiTool, ApiTool,
api_provider.get_tool(tool_name).fork_tool_runtime( api_provider.get_tool(tool_name).fork_tool_runtime(
runtime=ToolRuntime( runtime=ToolRuntime(
tenant_id=tenant_id, tenant_id=tenant_id,
credentials=decrypted_credentials, credentials=encrypter.decrypt(credentials),
invoke_from=invoke_from, invoke_from=invoke_from,
tool_invoke_from=tool_invoke_from, tool_invoke_from=tool_invoke_from,
) )
@ -362,6 +367,7 @@ class ToolManager:
tenant_id=tenant_id, tenant_id=tenant_id,
invoke_from=invoke_from, invoke_from=invoke_from,
tool_invoke_from=ToolInvokeFrom.WORKFLOW, tool_invoke_from=ToolInvokeFrom.WORKFLOW,
credential_id=workflow_tool.credential_id,
) )
parameters = tool_runtime.get_merged_runtime_parameters() parameters = tool_runtime.get_merged_runtime_parameters()
@ -551,6 +557,22 @@ class ToolManager:
return cls._builtin_tools_labels[tool_name] return cls._builtin_tools_labels[tool_name]
@classmethod
def list_default_builtin_providers(cls, tenant_id: str) -> list[BuiltinToolProvider]:
"""
list all the builtin providers
"""
# according to multi credentials, select the one with is_default=True first, then created_at oldest
# for compatibility with old version
sql = """
SELECT DISTINCT ON (tenant_id, provider) id
FROM tool_builtin_providers
WHERE tenant_id = :tenant_id
ORDER BY tenant_id, provider, is_default DESC, created_at DESC
"""
ids = [row.id for row in db.session.execute(db.text(sql), {"tenant_id": tenant_id}).all()]
return db.session.query(BuiltinToolProvider).filter(BuiltinToolProvider.id.in_(ids)).all()
@classmethod @classmethod
def list_providers_from_api( def list_providers_from_api(
cls, user_id: str, tenant_id: str, typ: ToolProviderTypeApiLiteral cls, user_id: str, tenant_id: str, typ: ToolProviderTypeApiLiteral
@ -565,21 +587,13 @@ class ToolManager:
with db.session.no_autoflush: with db.session.no_autoflush:
if "builtin" in filters: if "builtin" in filters:
# get builtin providers
builtin_providers = cls.list_builtin_providers(tenant_id) builtin_providers = cls.list_builtin_providers(tenant_id)
# get db builtin providers # key: provider name, value: provider
db_builtin_providers: list[BuiltinToolProvider] = ( db_builtin_providers = {
db.session.query(BuiltinToolProvider).filter(BuiltinToolProvider.tenant_id == tenant_id).all() str(ToolProviderID(provider.provider)): provider
) for provider in cls.list_default_builtin_providers(tenant_id)
}
# rewrite db_builtin_providers
for db_provider in db_builtin_providers:
tool_provider_id = str(ToolProviderID(db_provider.provider))
db_provider.provider = tool_provider_id
def find_db_builtin_provider(provider):
return next((x for x in db_builtin_providers if x.provider == provider), None)
# append builtin providers # append builtin providers
for provider in builtin_providers: for provider in builtin_providers:
@ -591,10 +605,9 @@ class ToolManager:
name_func=lambda x: x.identity.name, name_func=lambda x: x.identity.name,
): ):
continue continue
user_provider = ToolTransformService.builtin_provider_to_user_provider( user_provider = ToolTransformService.builtin_provider_to_user_provider(
provider_controller=provider, provider_controller=provider,
db_provider=find_db_builtin_provider(provider.entity.identity.name), db_provider=db_builtin_providers.get(provider.entity.identity.name),
decrypt_credentials=False, decrypt_credentials=False,
) )
@ -604,7 +617,6 @@ class ToolManager:
result_providers[f"builtin_provider.{user_provider.name}"] = user_provider result_providers[f"builtin_provider.{user_provider.name}"] = user_provider
# get db api providers # get db api providers
if "api" in filters: if "api" in filters:
db_api_providers: list[ApiToolProvider] = ( db_api_providers: list[ApiToolProvider] = (
db.session.query(ApiToolProvider).filter(ApiToolProvider.tenant_id == tenant_id).all() db.session.query(ApiToolProvider).filter(ApiToolProvider.tenant_id == tenant_id).all()
@ -750,7 +762,7 @@ class ToolManager:
ApiProviderAuthType.API_KEY if credentials["auth_type"] == "api_key" else ApiProviderAuthType.NONE, ApiProviderAuthType.API_KEY if credentials["auth_type"] == "api_key" else ApiProviderAuthType.NONE,
) )
# init tool configuration # init tool configuration
tool_configuration = ProviderConfigEncrypter( tool_configuration = create_encrypter(
tenant_id=tenant_id, tenant_id=tenant_id,
config=[x.to_basic_provider_config() for x in controller.get_credentials_schema()], config=[x.to_basic_provider_config() for x in controller.get_credentials_schema()],
provider_type=controller.provider_type.value, provider_type=controller.provider_type.value,

@ -1,12 +1,10 @@
from copy import deepcopy from copy import deepcopy
from typing import Any from typing import Any, Optional, Protocol
from pydantic import BaseModel
from core.entities.provider_entities import BasicProviderConfig from core.entities.provider_entities import BasicProviderConfig
from core.helper import encrypter from core.helper import encrypter
from core.helper.provider_cache import GenericProviderCredentialsCache
from core.helper.tool_parameter_cache import ToolParameterCache, ToolParameterCacheType from core.helper.tool_parameter_cache import ToolParameterCache, ToolParameterCacheType
from core.helper.tool_provider_cache import ToolProviderCredentialsCache, ToolProviderCredentialsCacheType
from core.tools.__base.tool import Tool from core.tools.__base.tool import Tool
from core.tools.entities.tool_entities import ( from core.tools.entities.tool_entities import (
ToolParameter, ToolParameter,
@ -14,11 +12,38 @@ from core.tools.entities.tool_entities import (
) )
class ProviderConfigEncrypter(BaseModel): class ProviderConfigCache(Protocol):
"""
Interface for provider configuration cache operations
"""
def get(self) -> Optional[dict]:
"""Get cached provider configuration"""
...
def set(self, config: dict[str, Any]) -> None:
"""Cache provider configuration"""
...
def delete(self) -> None:
"""Delete cached provider configuration"""
...
class ProviderConfigEncrypter:
tenant_id: str tenant_id: str
config: list[BasicProviderConfig] config: list[BasicProviderConfig]
provider_type: str provider_config_cache: ProviderConfigCache
provider_identity: str
def __init__(
self,
tenant_id: str,
config: list[BasicProviderConfig],
provider_config_cache: ProviderConfigCache,
):
self.tenant_id = tenant_id
self.config = config
self.provider_config_cache = provider_config_cache
def _deep_copy(self, data: dict[str, str]) -> dict[str, str]: def _deep_copy(self, data: dict[str, str]) -> dict[str, str]:
""" """
@ -72,21 +97,17 @@ class ProviderConfigEncrypter(BaseModel):
return data return data
def decrypt(self, data: dict[str, str], use_cache: bool = True) -> dict[str, str]: def decrypt(self, data: dict[str, str], use_cache: bool = True) -> dict[str, Any]:
""" """
decrypt tool credentials with tenant id decrypt tool credentials with tenant id
return a deep copy of credentials with decrypted values return a deep copy of credentials with decrypted values
""" """
if use_cache: if use_cache:
cache = ToolProviderCredentialsCache( cached_credentials = self.provider_config_cache.get()
tenant_id=self.tenant_id,
identity_id=f"{self.provider_type}.{self.provider_identity}",
cache_type=ToolProviderCredentialsCacheType.PROVIDER,
)
cached_credentials = cache.get()
if cached_credentials: if cached_credentials:
return cached_credentials return cached_credentials
data = self._deep_copy(data) data = self._deep_copy(data)
# get fields need to be decrypted # get fields need to be decrypted
fields = dict[str, BasicProviderConfig]() fields = dict[str, BasicProviderConfig]()
@ -104,18 +125,25 @@ class ProviderConfigEncrypter(BaseModel):
data[field_name] = encrypter.decrypt_token(self.tenant_id, data[field_name]) data[field_name] = encrypter.decrypt_token(self.tenant_id, data[field_name])
except Exception: except Exception:
pass pass
if use_cache: if use_cache:
cache.set(data) self.provider_config_cache.set(data)
return data return data
def delete_tool_credentials_cache(self):
cache = ToolProviderCredentialsCache( def create_encrypter(
tenant_id=self.tenant_id, tenant_id: str, config: list[BasicProviderConfig], cache: ProviderConfigCache
identity_id=f"{self.provider_type}.{self.provider_identity}", ):
cache_type=ToolProviderCredentialsCacheType.PROVIDER, return ProviderConfigEncrypter(
) tenant_id=tenant_id, config=config, provider_config_cache=cache
cache.delete() ), cache
def create_generic_encrypter(
tenant_id: str, config: list[BasicProviderConfig], provider_type: str, provider_identity: str
):
cache = GenericProviderCredentialsCache(tenant_id=tenant_id, identity_id=f"{provider_type}.{provider_identity}")
encrypt = ProviderConfigEncrypter(tenant_id=tenant_id, config=config, provider_config_cache=cache)
return encrypt, cache
class ToolParameterConfigurationManager: class ToolParameterConfigurationManager:

@ -490,6 +490,9 @@ class KnowledgeRetrievalNode(LLMNode):
def _process_metadata_filter_func( def _process_metadata_filter_func(
self, sequence: int, condition: str, metadata_name: str, value: Optional[Any], filters: list self, sequence: int, condition: str, metadata_name: str, value: Optional[Any], filters: list
): ):
if value is None:
return
key = f"{metadata_name}_{sequence}" key = f"{metadata_name}_{sequence}"
key_value = f"{metadata_name}_{sequence}_value" key_value = f"{metadata_name}_{sequence}_value"
match condition: match condition:

@ -14,6 +14,7 @@ class ToolEntity(BaseModel):
tool_name: str tool_name: str
tool_label: str # redundancy tool_label: str # redundancy
tool_configurations: dict[str, Any] tool_configurations: dict[str, Any]
credential_id: str | None = None
plugin_unique_identifier: str | None = None # redundancy plugin_unique_identifier: str | None = None # redundancy
@field_validator("tool_configurations", mode="before") @field_validator("tool_configurations", mode="before")

@ -0,0 +1,65 @@
"""tool oauth
Revision ID: 71f5020c6470
Revises: 4474872b0ee6
Create Date: 2025-06-24 17:05:43.118647
"""
from alembic import op
import models as models
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '71f5020c6470'
down_revision = '4474872b0ee6'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('tool_oauth_system_clients',
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
sa.Column('plugin_id', sa.String(length=512), nullable=False),
sa.Column('provider', sa.String(length=255), nullable=False),
sa.Column('encrypted_oauth_params', sa.Text(), nullable=False),
sa.PrimaryKeyConstraint('id', name='tool_oauth_system_client_pkey'),
sa.UniqueConstraint('plugin_id', 'provider', name='tool_oauth_system_client_plugin_id_provider_idx')
)
op.create_table('tool_oauth_tenant_clients',
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
sa.Column('plugin_id', sa.String(length=512), nullable=False),
sa.Column('provider', sa.String(length=255), nullable=False),
sa.Column('enabled', sa.Boolean(), server_default=sa.text('true'), nullable=False),
sa.Column('encrypted_oauth_params', sa.Text(), nullable=False),
sa.PrimaryKeyConstraint('id', name='tool_oauth_tenant_client_pkey'),
sa.UniqueConstraint('tenant_id', 'plugin_id', 'provider', name='unique_tool_oauth_tenant_client')
)
with op.batch_alter_table('tool_api_providers', schema=None) as batch_op:
batch_op.drop_constraint(batch_op.f('unique_api_tool_provider'), type_='unique')
with op.batch_alter_table('tool_builtin_providers', schema=None) as batch_op:
batch_op.add_column(sa.Column('name', sa.String(length=256), server_default=sa.text("'API KEY 1'::character varying"), nullable=False))
batch_op.add_column(sa.Column('is_default', sa.Boolean(), server_default=sa.text('false'), nullable=False))
batch_op.add_column(sa.Column('credential_type', sa.String(length=32), server_default=sa.text("'api_key'::character varying"), nullable=False))
batch_op.drop_constraint(batch_op.f('unique_builtin_tool_provider'), type_='unique')
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('tool_builtin_providers', schema=None) as batch_op:
batch_op.create_unique_constraint(batch_op.f('unique_builtin_tool_provider'), ['tenant_id', 'provider'])
batch_op.drop_column('credential_type')
batch_op.drop_column('is_default')
batch_op.drop_column('name')
with op.batch_alter_table('tool_api_providers', schema=None) as batch_op:
batch_op.create_unique_constraint(batch_op.f('unique_api_tool_provider'), ['name', 'tenant_id'])
op.drop_table('tool_oauth_tenant_clients')
op.drop_table('tool_oauth_system_clients')
# ### end Alembic commands ###

@ -0,0 +1,25 @@
"""merge tool oauth and remove sequence number branches
Revision ID: 46d46b3f389c
Revises: 0ab65e1cc7fa, 71f5020c6470
Create Date: 2025-06-25 11:01:55.215896
"""
from alembic import op
import models as models
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '46d46b3f389c'
down_revision = ('0ab65e1cc7fa', '71f5020c6470')
branch_labels = None
depends_on = None
def upgrade():
pass
def downgrade():
pass

@ -21,6 +21,47 @@ from .model import Account, App, Tenant
from .types import StringUUID from .types import StringUUID
# system level tool oauth client params (client_id, client_secret, etc.)
class ToolOAuthSystemClient(Base):
__tablename__ = "tool_oauth_system_clients"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="tool_oauth_system_client_pkey"),
db.UniqueConstraint("plugin_id", "provider", name="tool_oauth_system_client_plugin_id_provider_idx"),
)
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
plugin_id: Mapped[str] = mapped_column(db.String(512), nullable=False)
provider: Mapped[str] = mapped_column(db.String(255), nullable=False)
# oauth params of the tool provider
encrypted_oauth_params: Mapped[str] = mapped_column(db.Text, nullable=False)
@property
def oauth_params(self) -> dict:
return cast(dict, json.loads(self.encrypted_oauth_params))
# tenant level tool oauth client params (client_id, client_secret, etc.)
class ToolOAuthTenantClient(Base):
__tablename__ = "tool_oauth_tenant_clients"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="tool_oauth_tenant_client_pkey"),
db.UniqueConstraint("tenant_id", "plugin_id", "provider", name="unique_tool_oauth_tenant_client"),
)
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
# tenant id
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
plugin_id: Mapped[str] = mapped_column(db.String(512), nullable=False)
provider: Mapped[str] = mapped_column(db.String(255), nullable=False)
enabled: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("true"))
# oauth params of the tool provider
encrypted_oauth_params: Mapped[str] = mapped_column(db.Text, nullable=False)
@property
def oauth_params(self) -> dict:
return cast(dict, json.loads(self.encrypted_oauth_params))
class BuiltinToolProvider(Base): class BuiltinToolProvider(Base):
""" """
This table stores the tool provider information for built-in tools for each tenant. This table stores the tool provider information for built-in tools for each tenant.
@ -29,12 +70,14 @@ class BuiltinToolProvider(Base):
__tablename__ = "tool_builtin_providers" __tablename__ = "tool_builtin_providers"
__table_args__ = ( __table_args__ = (
db.PrimaryKeyConstraint("id", name="tool_builtin_provider_pkey"), db.PrimaryKeyConstraint("id", name="tool_builtin_provider_pkey"),
# one tenant can only have one tool provider with the same name db.UniqueConstraint("tenant_id", "provider", "name", name="unique_builtin_tool_provider"),
db.UniqueConstraint("tenant_id", "provider", name="unique_builtin_tool_provider"),
) )
# id of the tool provider # id of the tool provider
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
name: Mapped[str] = mapped_column(
db.String(256), nullable=False, server_default=db.text("'API KEY 1'::character varying")
)
# id of the tenant # id of the tenant
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=True) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=True)
# who created this tool provider # who created this tool provider
@ -49,6 +92,11 @@ class BuiltinToolProvider(Base):
updated_at: Mapped[datetime] = mapped_column( updated_at: Mapped[datetime] = mapped_column(
db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)") db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")
) )
is_default: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("false"))
# credential type, e.g., "api-key", "oauth2"
credential_type: Mapped[str] = mapped_column(
db.String(32), nullable=False, server_default=db.text("'api-key'::character varying")
)
@property @property
def credentials(self) -> dict: def credentials(self) -> dict:
@ -61,14 +109,11 @@ class ApiToolProvider(Base):
""" """
__tablename__ = "tool_api_providers" __tablename__ = "tool_api_providers"
__table_args__ = ( __table_args__ = (db.PrimaryKeyConstraint("id", name="tool_api_provider_pkey"),)
db.PrimaryKeyConstraint("id", name="tool_api_provider_pkey"),
db.UniqueConstraint("name", "tenant_id", name="unique_api_tool_provider"),
)
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
# name of the api provider # name of the api provider
name = db.Column(db.String(255), nullable=False) name = db.Column(db.String(255), nullable=False, server_default=db.text("'API KEY 1'::character varying"))
# icon # icon
icon = db.Column(db.String(255), nullable=False) icon = db.Column(db.String(255), nullable=False)
# original schema # original schema

@ -582,6 +582,11 @@ class AppDslService:
cls.encrypt_dataset_id(dataset_id=dataset_id, tenant_id=app_model.tenant_id) cls.encrypt_dataset_id(dataset_id=dataset_id, tenant_id=app_model.tenant_id)
for dataset_id in dataset_ids for dataset_id in dataset_ids
] ]
# filter credential id from tool node
if node.get("data", {}).get("type", "") == NodeType.TOOL.value:
node["data"]["credential_id"] = None
export_data["workflow"] = workflow_dict export_data["workflow"] = workflow_dict
dependencies = cls._extract_dependencies_from_workflow(workflow) dependencies = cls._extract_dependencies_from_workflow(workflow)
export_data["dependencies"] = [ export_data["dependencies"] = [

@ -18,7 +18,7 @@ from core.tools.entities.tool_entities import (
) )
from core.tools.tool_label_manager import ToolLabelManager from core.tools.tool_label_manager import ToolLabelManager
from core.tools.tool_manager import ToolManager from core.tools.tool_manager import ToolManager
from core.tools.utils.configuration import ProviderConfigEncrypter from core.tools.utils.configuration import ProviderConfigEncrypter, create_generic_encrypter
from core.tools.utils.parser import ApiBasedToolSchemaParser from core.tools.utils.parser import ApiBasedToolSchemaParser
from extensions.ext_database import db from extensions.ext_database import db
from models.tools import ApiToolProvider from models.tools import ApiToolProvider
@ -297,28 +297,28 @@ class ApiToolManageService:
provider_controller.load_bundled_tools(tool_bundles) provider_controller.load_bundled_tools(tool_bundles)
# get original credentials if exists # get original credentials if exists
tool_configuration = ProviderConfigEncrypter( encrypter, cache = create_generic_encrypter(
tenant_id=tenant_id, tenant_id=tenant_id,
config=list(provider_controller.get_credentials_schema()), config=list(provider_controller.get_credentials_schema()),
provider_type=provider_controller.provider_type.value, provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.entity.identity.name, provider_identity=provider_controller.entity.identity.name,
) )
original_credentials = tool_configuration.decrypt(provider.credentials) original_credentials = encrypter.decrypt(provider.credentials)
masked_credentials = tool_configuration.mask_tool_credentials(original_credentials) masked_credentials = encrypter.mask_tool_credentials(original_credentials)
# check if the credential has changed, save the original credential # check if the credential has changed, save the original credential
for name, value in credentials.items(): for name, value in credentials.items():
if name in masked_credentials and value == masked_credentials[name]: if name in masked_credentials and value == masked_credentials[name]:
credentials[name] = original_credentials[name] credentials[name] = original_credentials[name]
credentials = tool_configuration.encrypt(credentials) credentials = encrypter.encrypt(credentials)
provider.credentials_str = json.dumps(credentials) provider.credentials_str = json.dumps(credentials)
db.session.add(provider) db.session.add(provider)
db.session.commit() db.session.commit()
# delete cache # delete cache
tool_configuration.delete_tool_credentials_cache() cache.delete()
# update labels # update labels
ToolLabelManager.update_tool_labels(provider_controller, labels) ToolLabelManager.update_tool_labels(provider_controller, labels)
@ -416,15 +416,15 @@ class ApiToolManageService:
# decrypt credentials # decrypt credentials
if db_provider.id: if db_provider.id:
tool_configuration = ProviderConfigEncrypter( encrypter, _ = create_generic_encrypter(
tenant_id=tenant_id, tenant_id=tenant_id,
config=list(provider_controller.get_credentials_schema()), config=list(provider_controller.get_credentials_schema()),
provider_type=provider_controller.provider_type.value, provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.entity.identity.name, provider_identity=provider_controller.entity.identity.name,
) )
decrypted_credentials = tool_configuration.decrypt(credentials) decrypted_credentials = encrypter.decrypt(credentials)
# check if the credential has changed, save the original credential # check if the credential has changed, save the original credential
masked_credentials = tool_configuration.mask_tool_credentials(decrypted_credentials) masked_credentials = encrypter.mask_tool_credentials(decrypted_credentials)
for name, value in credentials.items(): for name, value in credentials.items():
if name in masked_credentials and value == masked_credentials[name]: if name in masked_credentials and value == masked_credentials[name]:
credentials[name] = decrypted_credentials[name] credentials[name] = decrypted_credentials[name]
@ -446,7 +446,7 @@ class ApiToolManageService:
return {"result": result or "empty response"} return {"result": result or "empty response"}
@staticmethod @staticmethod
def list_api_tools(user_id: str, tenant_id: str) -> list[ToolProviderApiEntity]: def list_api_tools(tenant_id: str) -> list[ToolProviderApiEntity]:
""" """
list api tools list api tools
""" """
@ -474,7 +474,7 @@ class ApiToolManageService:
for tool in tools or []: for tool in tools or []:
user_provider.tools.append( user_provider.tools.append(
ToolTransformService.convert_tool_entity_to_api_entity( ToolTransformService.convert_tool_entity_to_api_entity(
tenant_id=tenant_id, tool=tool, credentials=user_provider.original_credentials, labels=labels tenant_id=tenant_id, tool=tool, labels=labels
) )
) )

@ -1,28 +1,54 @@
import json import json
import logging import logging
import re
from pathlib import Path from pathlib import Path
from typing import Any, Optional
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from configs import dify_config from configs import dify_config
from core.helper.position_helper import is_filtered from core.helper.position_helper import is_filtered
from core.model_runtime.utils.encoders import jsonable_encoder from core.helper.provider_cache import NoOpProviderCredentialCache, ToolProviderCredentialsCache
from core.plugin.entities.plugin import ToolProviderID from core.plugin.entities.plugin import ToolProviderID
from core.plugin.impl.exc import PluginDaemonClientSideError from core.plugin.impl.exc import PluginDaemonClientSideError
from core.tools.builtin_tool.provider import BuiltinToolProviderController
from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort
from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity from core.tools.entities.api_entities import (
ToolApiEntity,
ToolProviderApiEntity,
ToolProviderCredentialApiEntity,
ToolProviderCredentialInfoApiEntity,
)
from core.tools.entities.tool_entities import CredentialType
from core.tools.errors import ToolNotFoundError, ToolProviderCredentialValidationError, ToolProviderNotFoundError from core.tools.errors import ToolNotFoundError, ToolProviderCredentialValidationError, ToolProviderNotFoundError
from core.tools.plugin_tool.provider import PluginToolProviderController
from core.tools.tool_label_manager import ToolLabelManager from core.tools.tool_label_manager import ToolLabelManager
from core.tools.tool_manager import ToolManager from core.tools.tool_manager import ToolManager
from core.tools.utils.configuration import ProviderConfigEncrypter from core.tools.utils.configuration import create_encrypter
from extensions.ext_database import db from extensions.ext_database import db
from models.tools import BuiltinToolProvider from extensions.ext_redis import redis_client
from models.tools import BuiltinToolProvider, ToolOAuthSystemClient, ToolOAuthTenantClient
from services.tools.tools_transform_service import ToolTransformService from services.tools.tools_transform_service import ToolTransformService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class BuiltinToolManageService: class BuiltinToolManageService:
__MAX_BUILTIN_TOOL_PROVIDER_COUNT__ = 100
@staticmethod
def get_builtin_tool_provider_oauth_client_schema(tenant_id: str, provider_name: str):
"""
get builtin tool provider oauth client schema
"""
provider = ToolManager.get_builtin_provider(provider_name, tenant_id)
return {
"schema": provider.get_oauth_client_schema(),
"is_oauth_custom_client_enabled": BuiltinToolManageService.is_oauth_custom_client_enabled(
tenant_id, provider_name
),
}
@staticmethod @staticmethod
def list_builtin_tool_provider_tools(tenant_id: str, provider: str) -> list[ToolApiEntity]: def list_builtin_tool_provider_tools(tenant_id: str, provider: str) -> list[ToolApiEntity]:
""" """
@ -36,27 +62,11 @@ class BuiltinToolManageService:
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id) provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
tools = provider_controller.get_tools() tools = provider_controller.get_tools()
tool_provider_configurations = ProviderConfigEncrypter(
tenant_id=tenant_id,
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.entity.identity.name,
)
# check if user has added the provider
builtin_provider = BuiltinToolManageService._fetch_builtin_provider(provider, tenant_id)
credentials = {}
if builtin_provider is not None:
# get credentials
credentials = builtin_provider.credentials
credentials = tool_provider_configurations.decrypt(credentials)
result: list[ToolApiEntity] = [] result: list[ToolApiEntity] = []
for tool in tools or []: for tool in tools or []:
result.append( result.append(
ToolTransformService.convert_tool_entity_to_api_entity( ToolTransformService.convert_tool_entity_to_api_entity(
tool=tool, tool=tool,
credentials=credentials,
tenant_id=tenant_id, tenant_id=tenant_id,
labels=ToolLabelManager.get_tool_labels(provider_controller), labels=ToolLabelManager.get_tool_labels(provider_controller),
) )
@ -65,25 +75,15 @@ class BuiltinToolManageService:
return result return result
@staticmethod @staticmethod
def get_builtin_tool_provider_info(user_id: str, tenant_id: str, provider: str): def get_builtin_tool_provider_info(tenant_id: str, provider: str):
""" """
get builtin tool provider info get builtin tool provider info
""" """
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id) provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
tool_provider_configurations = ProviderConfigEncrypter(
tenant_id=tenant_id,
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.entity.identity.name,
)
# check if user has added the provider # check if user has added the provider
builtin_provider = BuiltinToolManageService._fetch_builtin_provider(provider, tenant_id) builtin_provider = BuiltinToolManageService.get_builtin_provider(provider, tenant_id)
if builtin_provider is None:
credentials = {} raise ValueError(f"you have not added provider {provider}")
if builtin_provider is not None:
# get credentials
credentials = builtin_provider.credentials
credentials = tool_provider_configurations.decrypt(credentials)
entity = ToolTransformService.builtin_provider_to_user_provider( entity = ToolTransformService.builtin_provider_to_user_provider(
provider_controller=provider_controller, provider_controller=provider_controller,
@ -92,55 +92,67 @@ class BuiltinToolManageService:
) )
entity.original_credentials = {} entity.original_credentials = {}
return entity return entity
@staticmethod @staticmethod
def list_builtin_provider_credentials_schema(provider_name: str, tenant_id: str): def list_builtin_provider_credentials_schema(
provider_name: str, credential_type: CredentialType, tenant_id: str
):
""" """
list builtin provider credentials schema list builtin provider credentials schema
:param credential_type: credential type
:param provider_name: the name of the provider :param provider_name: the name of the provider
:param tenant_id: the id of the tenant :param tenant_id: the id of the tenant
:return: the list of tool providers :return: the list of tool providers
""" """
provider = ToolManager.get_builtin_provider(provider_name, tenant_id) provider = ToolManager.get_builtin_provider(provider_name, tenant_id)
return jsonable_encoder(provider.get_credentials_schema()) return provider.get_credentials_schema_by_type(credential_type)
@staticmethod @staticmethod
def update_builtin_tool_provider( def update_builtin_tool_provider(
session: Session, user_id: str, tenant_id: str, provider_name: str, credentials: dict user_id: str, tenant_id: str, provider: str, credentials: dict, credential_id: str, name: str | None = None
): ):
""" """
update builtin tool provider update builtin tool provider
""" """
# get if the provider exists # get if the provider exists
provider = BuiltinToolManageService._fetch_builtin_provider(provider_name, tenant_id) db_provider = BuiltinToolManageService.get_builtin_provider_by_id(tenant_id, credential_id)
if db_provider is None:
raise ValueError(f"you have not added provider {provider}")
try: try:
# get provider if CredentialType.of(db_provider.credential_type).is_editable():
provider_controller = ToolManager.get_builtin_provider(provider_name, tenant_id) provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
if not provider_controller.need_credentials: if not provider_controller.need_credentials:
raise ValueError(f"provider {provider_name} does not need credentials") raise ValueError(f"provider {provider} does not need credentials")
tool_configuration = ProviderConfigEncrypter(
tenant_id=tenant_id, encrypter, cache = BuiltinToolManageService.create_tool_encrypter(
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()], tenant_id, db_provider, provider, provider_controller
provider_type=provider_controller.provider_type.value, )
provider_identity=provider_controller.entity.identity.name,
) # Decrypt and restore original credentials for masked values
original_credentials = encrypter.decrypt(db_provider.credentials)
masked_credentials = encrypter.mask_tool_credentials(original_credentials)
# get original credentials if exists
if provider is not None:
original_credentials = tool_configuration.decrypt(provider.credentials)
masked_credentials = tool_configuration.mask_tool_credentials(original_credentials)
# check if the credential has changed, save the original credential # check if the credential has changed, save the original credential
for name, value in credentials.items(): for key, value in credentials.items():
if name in masked_credentials and value == masked_credentials[name]: if key in masked_credentials and value == masked_credentials[key]:
credentials[name] = original_credentials[name] credentials[key] = original_credentials[key]
# validate credentials
provider_controller.validate_credentials(user_id, credentials) provider_controller.validate_credentials(user_id, credentials)
# encrypt credentials
credentials = tool_configuration.encrypt(credentials) # encrypt credentials
db_provider.encrypted_credentials = json.dumps(encrypter.encrypt(credentials))
cache.delete()
# update name if provided
if name is not None and db_provider.name != name:
db_provider.name = name
db.session.commit()
except ( except (
PluginDaemonClientSideError, PluginDaemonClientSideError,
ToolProviderNotFoundError, ToolProviderNotFoundError,
@ -149,71 +161,272 @@ class BuiltinToolManageService:
) as e: ) as e:
raise ValueError(str(e)) raise ValueError(str(e))
if provider is None: return {"result": "success"}
# create provider
provider = BuiltinToolProvider( @staticmethod
def add_builtin_tool_provider(
user_id: str,
api_type: CredentialType,
tenant_id: str,
provider: str,
credentials: dict,
name: str | None = None,
):
"""
add builtin tool provider
"""
lock = f"builtin_tool_provider_create_lock:{tenant_id}_{provider}"
with redis_client.lock(lock, timeout=20):
# check if the provider count is over the limit
provider_count = (
db.session.query(BuiltinToolProvider).filter_by(tenant_id=tenant_id, provider=provider).count()
)
if provider_count >= BuiltinToolManageService.__MAX_BUILTIN_TOOL_PROVIDER_COUNT__:
raise ValueError(f"you have reached the maximum number of providers for {provider}")
# TODO should we get name from oauth authentication?
name = (
name
if name
else BuiltinToolManageService.generate_builtin_tool_provider_name(
tenant_id=tenant_id, provider=provider, credential_type=api_type
)
)
db_provider = BuiltinToolProvider(
tenant_id=tenant_id, tenant_id=tenant_id,
user_id=user_id, user_id=user_id,
provider=provider_name, provider=provider,
encrypted_credentials=json.dumps(credentials), encrypted_credentials=json.dumps(credentials),
credential_type=api_type.value,
name=name,
) )
db.session.add(provider) provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
else: if not provider_controller.need_credentials:
provider.encrypted_credentials = json.dumps(credentials) raise ValueError(f"provider {provider} does not need credentials")
# delete cache encrypter, cache = BuiltinToolManageService.create_tool_encrypter(
tool_configuration.delete_tool_credentials_cache() tenant_id, db_provider, provider, provider_controller
)
db.session.commit() # encrypt credentials
db_provider.encrypted_credentials = json.dumps(encrypter.encrypt(credentials))
cache.delete()
db.session.add(db_provider)
db.session.commit()
return {"result": "success"} return {"result": "success"}
@staticmethod @staticmethod
def get_builtin_tool_provider_credentials(tenant_id: str, provider_name: str): def create_tool_encrypter(
tenant_id: str,
db_provider: BuiltinToolProvider,
provider: str,
provider_controller: BuiltinToolProviderController,
):
encrypter, cache = create_encrypter(
tenant_id=tenant_id,
config=[
x.to_basic_provider_config()
for x in provider_controller.get_credentials_schema_by_type(db_provider.credential_type)
],
cache=ToolProviderCredentialsCache(tenant_id=tenant_id, provider=provider, credential_id=db_provider.id),
)
return encrypter, cache
@staticmethod
def generate_builtin_tool_provider_name(
tenant_id: str, provider: str, credential_type: CredentialType
) -> str:
try:
db_providers = (
db.session.query(BuiltinToolProvider)
.filter_by(
tenant_id=tenant_id,
provider=provider,
credential_type=credential_type.value,
)
.order_by(BuiltinToolProvider.created_at.desc())
.all()
)
# Get the default name pattern
default_pattern = f"{credential_type.get_name()}"
# Find all names that match the default pattern: "{default_pattern} {number}"
pattern = rf"^{re.escape(default_pattern)}\s+(\d+)$"
numbers = []
for db_provider in db_providers:
if db_provider.name:
match = re.match(pattern, db_provider.name.strip())
if match:
numbers.append(int(match.group(1)))
# If no default pattern names found, start with 1
if not numbers:
return f"{default_pattern} 1"
# Find the next number
max_number = max(numbers)
return f"{default_pattern} {max_number + 1}"
except Exception as e:
logger.warning(f"Error generating next provider name for {provider}: {str(e)}")
# fallback
return f"{credential_type.get_name()} 1"
@staticmethod
def get_builtin_tool_provider_credentials(
tenant_id: str, provider_name: str
) -> list[ToolProviderCredentialApiEntity]:
""" """
get builtin tool provider credentials get builtin tool provider credentials
""" """
provider_obj = BuiltinToolManageService._fetch_builtin_provider(provider_name, tenant_id) with db.session.no_autoflush:
providers = (
db.session.query(BuiltinToolProvider)
.filter_by(tenant_id=tenant_id, provider=provider_name)
.order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc())
.all()
)
if provider_obj is None: if len(providers) == 0:
return {} return []
provider_controller = ToolManager.get_builtin_provider(provider_obj.provider, tenant_id) default_provider = providers[0]
tool_configuration = ProviderConfigEncrypter( default_provider.is_default = True
tenant_id=tenant_id, provider_controller = ToolManager.get_builtin_provider(default_provider.provider, tenant_id)
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()], encrypter, cache = BuiltinToolManageService.create_tool_encrypter(
provider_type=provider_controller.provider_type.value, tenant_id, default_provider, default_provider.provider, provider_controller
provider_identity=provider_controller.entity.identity.name, )
credentials: list[ToolProviderCredentialApiEntity] = []
for provider in providers:
decrypt_credential = encrypter.mask_tool_credentials(encrypter.decrypt(provider.credentials))
credential_entity = ToolTransformService.convert_builtin_provider_to_credential_entity(
provider=provider,
credentials=decrypt_credential,
)
credentials.append(credential_entity)
return credentials
@staticmethod
def get_builtin_tool_provider_credential_info(tenant_id: str, provider: str) -> ToolProviderCredentialInfoApiEntity:
"""
get builtin tool provider credential info
"""
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
supported_credential_types = provider_controller.get_supported_credential_types()
credentials = BuiltinToolManageService.get_builtin_tool_provider_credentials(tenant_id, provider)
credential_info = ToolProviderCredentialInfoApiEntity(
supported_credential_types=supported_credential_types,
is_oauth_custom_client_enabled=BuiltinToolManageService.is_oauth_custom_client_enabled(tenant_id, provider),
credentials=credentials,
) )
credentials = tool_configuration.decrypt(provider_obj.credentials)
credentials = tool_configuration.mask_tool_credentials(credentials) return credential_info
return credentials
@staticmethod @staticmethod
def delete_builtin_tool_provider(user_id: str, tenant_id: str, provider_name: str): def delete_builtin_tool_provider(tenant_id: str, provider: str, credential_id: str):
""" """
delete tool provider delete tool provider
""" """
provider_obj = BuiltinToolManageService._fetch_builtin_provider(provider_name, tenant_id) tool_provider = BuiltinToolManageService.get_builtin_provider_by_id(tenant_id, credential_id)
if provider_obj is None: if tool_provider is None:
raise ValueError(f"you have not added provider {provider_name}") raise ValueError(f"you have not added provider {provider}")
db.session.delete(provider_obj) db.session.delete(tool_provider)
db.session.commit() db.session.commit()
# delete cache # delete cache
provider_controller = ToolManager.get_builtin_provider(provider_name, tenant_id) provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
tool_configuration = ProviderConfigEncrypter( _, cache = BuiltinToolManageService.create_tool_encrypter(
tenant_id=tenant_id, tenant_id, tool_provider, provider, provider_controller
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.entity.identity.name,
) )
tool_configuration.delete_tool_credentials_cache() cache.delete()
return {"result": "success"} return {"result": "success"}
@staticmethod
def set_default_provider(tenant_id: str, user_id: str, provider: str, id: str):
"""
set default provider
"""
with Session(db.engine) as session:
# get provider
target_provider = session.query(BuiltinToolProvider).filter_by(id=id).first()
if target_provider is None:
raise ValueError("provider not found")
# clear default provider
session.query(BuiltinToolProvider).filter_by(
tenant_id=tenant_id, user_id=user_id, provider=provider, default=True
).update({"default": False})
# set new default provider
target_provider.is_default = True
session.commit()
return {"result": "success"}
@staticmethod
def is_oauth_custom_client_enabled(tenant_id: str, provider: str) -> bool:
"""
check if oauth custom client is enabled
"""
tool_provider = ToolProviderID(provider)
with Session(db.engine).no_autoflush as session:
user_client: ToolOAuthTenantClient | None = (
session.query(ToolOAuthTenantClient)
.filter_by(
tenant_id=tenant_id,
provider=tool_provider.provider_name,
plugin_id=tool_provider.plugin_id,
enabled=True,
)
.first()
)
return user_client is not None and user_client.enabled
@staticmethod
def get_oauth_client(tenant_id: str, provider: str) -> dict[str, Any] | None:
"""
get builtin tool provider
"""
tool_provider = ToolProviderID(provider)
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
encrypter, _ = create_encrypter(
tenant_id=tenant_id,
config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()],
cache=NoOpProviderCredentialCache(),
)
with Session(db.engine).no_autoflush as session:
user_client: ToolOAuthTenantClient | None = (
session.query(ToolOAuthTenantClient)
.filter_by(
tenant_id=tenant_id,
provider=tool_provider.provider_name,
plugin_id=tool_provider.plugin_id,
enabled=True,
)
.first()
)
oauth_params: dict[str, Any] | None = None
if user_client:
oauth_params = encrypter.decrypt(user_client.oauth_params)
return oauth_params
system_client: ToolOAuthSystemClient | None = (
session.query(ToolOAuthSystemClient)
.filter_by(plugin_id=tool_provider.plugin_id, provider=tool_provider.provider_name)
.first()
)
if system_client:
oauth_params = encrypter.decrypt(system_client.oauth_params)
return oauth_params
@staticmethod @staticmethod
def get_builtin_tool_provider_icon(provider: str): def get_builtin_tool_provider_icon(provider: str):
""" """
@ -234,9 +447,7 @@ class BuiltinToolManageService:
with db.session.no_autoflush: with db.session.no_autoflush:
# get all user added providers # get all user added providers
db_providers: list[BuiltinToolProvider] = ( db_providers: list[BuiltinToolProvider] = ToolManager.list_default_builtin_providers(tenant_id)
db.session.query(BuiltinToolProvider).filter(BuiltinToolProvider.tenant_id == tenant_id).all() or []
)
# rewrite db_providers # rewrite db_providers
for db_provider in db_providers: for db_provider in db_providers:
@ -275,7 +486,6 @@ class BuiltinToolManageService:
ToolTransformService.convert_tool_entity_to_api_entity( ToolTransformService.convert_tool_entity_to_api_entity(
tenant_id=tenant_id, tenant_id=tenant_id,
tool=tool, tool=tool,
credentials=user_builtin_provider.original_credentials,
labels=ToolLabelManager.get_tool_labels(provider_controller), labels=ToolLabelManager.get_tool_labels(provider_controller),
) )
) )
@ -287,43 +497,159 @@ class BuiltinToolManageService:
return BuiltinToolProviderSort.sort(result) return BuiltinToolProviderSort.sort(result)
@staticmethod @staticmethod
def _fetch_builtin_provider(provider_name: str, tenant_id: str) -> BuiltinToolProvider | None: def get_builtin_provider_by_id(tenant_id: str, credential_id: str) -> Optional[BuiltinToolProvider]:
try: provider: Optional[BuiltinToolProvider] = (
full_provider_name = provider_name db.session.query(BuiltinToolProvider)
provider_id_entity = ToolProviderID(provider_name) .filter(
provider_name = provider_id_entity.provider_name BuiltinToolProvider.tenant_id == tenant_id,
if provider_id_entity.organization != "langgenius": BuiltinToolProvider.id == credential_id,
provider_obj = ( )
db.session.query(BuiltinToolProvider) .first()
.filter( )
BuiltinToolProvider.tenant_id == tenant_id, return provider
BuiltinToolProvider.provider == full_provider_name,
@staticmethod
def get_builtin_provider(provider_name: str, tenant_id: str) -> Optional[BuiltinToolProvider]:
"""
This method is used to fetch the builtin provider from the database
1.if the default provider exists, return the default provider
2.if the default provider does not exist, return the oldest provider
"""
with Session(db.engine) as session:
try:
full_provider_name = provider_name
provider_id_entity = ToolProviderID(provider_name)
provider_name = provider_id_entity.provider_name
if provider_id_entity.organization != "langgenius":
provider = (
session.query(BuiltinToolProvider)
.filter(
BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.provider == full_provider_name,
)
.order_by(
BuiltinToolProvider.is_default.desc(), # default=True first
BuiltinToolProvider.created_at.asc(), # oldest first
)
.first()
) )
.first() else:
) provider = (
else: session.query(BuiltinToolProvider)
provider_obj = ( .filter(
db.session.query(BuiltinToolProvider) BuiltinToolProvider.tenant_id == tenant_id,
.filter( (BuiltinToolProvider.provider == provider_name)
BuiltinToolProvider.tenant_id == tenant_id, | (BuiltinToolProvider.provider == full_provider_name),
(BuiltinToolProvider.provider == provider_name) )
| (BuiltinToolProvider.provider == full_provider_name), .order_by(
BuiltinToolProvider.is_default.desc(), # default=True first
BuiltinToolProvider.created_at.asc(), # oldest first
)
.first()
)
if provider is None:
return None
provider.provider = ToolProviderID(provider.provider).to_string()
return provider
except Exception:
# it's an old provider without organization
return (
session.query(BuiltinToolProvider)
.filter(BuiltinToolProvider.tenant_id == tenant_id, BuiltinToolProvider.provider == provider_name)
.order_by(
BuiltinToolProvider.is_default.desc(), # default=True first
BuiltinToolProvider.created_at.asc(), # oldest first
) )
.first() .first()
) )
if provider_obj is None: @staticmethod
return None def save_custom_oauth_client_params(
tenant_id: str,
provider: str,
client_params: Optional[dict] = None,
enable_oauth_custom_client: Optional[bool] = None,
):
"""
setup oauth custom client
"""
if client_params is None and enable_oauth_custom_client is None:
return {"result": "success"}
provider_obj.provider = ToolProviderID(provider_obj.provider).to_string() tool_provider = ToolProviderID(provider)
return provider_obj provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
except Exception: if not provider_controller:
# it's an old provider without organization raise ToolProviderNotFoundError(f"Provider {provider} not found")
return (
db.session.query(BuiltinToolProvider) if not isinstance(provider_controller, (BuiltinToolProviderController, PluginToolProviderController)):
.filter( raise ValueError(f"Provider {provider} is not a builtin or plugin provider")
BuiltinToolProvider.tenant_id == tenant_id,
(BuiltinToolProvider.provider == provider_name), with Session(db.engine) as session:
custom_client_params = (
session.query(ToolOAuthTenantClient)
.filter_by(
tenant_id=tenant_id,
plugin_id=tool_provider.plugin_id,
provider=tool_provider.provider_name,
) )
.first() .first()
) )
# if the record does not exist, create a basic record
if custom_client_params is None:
custom_client_params = ToolOAuthTenantClient(
tenant_id=tenant_id,
plugin_id=tool_provider.plugin_id,
provider=tool_provider.provider_name,
)
session.add(custom_client_params)
if client_params is not None:
encrypter, _ = create_encrypter(
tenant_id=tenant_id,
config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()],
cache=NoOpProviderCredentialCache(),
)
custom_client_params.encrypted_oauth_params = json.dumps(encrypter.encrypt(client_params))
if enable_oauth_custom_client is not None:
custom_client_params.enabled = enable_oauth_custom_client
session.commit()
return {"result": "success"}
@staticmethod
def get_custom_oauth_client_params(tenant_id: str, provider: str):
"""
get custom oauth client params
"""
with Session(db.engine) as session:
tool_provider = ToolProviderID(provider)
custom_oauth_client_params: ToolOAuthTenantClient | None = (
session.query(ToolOAuthTenantClient)
.filter_by(
tenant_id=tenant_id,
plugin_id=tool_provider.plugin_id,
provider=tool_provider.provider_name,
)
.first()
)
if custom_oauth_client_params is None:
return {}
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
if not provider_controller:
raise ToolProviderNotFoundError(f"Provider {provider} not found")
if not isinstance(provider_controller, BuiltinToolProviderController):
raise ValueError(f"Provider {provider} is not a builtin or plugin provider")
encrypter, _ = create_encrypter(
tenant_id=tenant_id,
config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()],
cache=NoOpProviderCredentialCache(),
)
return encrypter.mask_tool_credentials(encrypter.decrypt(custom_oauth_client_params.oauth_params))

@ -6,20 +6,22 @@ from yarl import URL
from configs import dify_config from configs import dify_config
from core.mcp.types import Tool as MCPTool from core.mcp.types import Tool as MCPTool
from core.helper.provider_cache import ToolProviderCredentialsCache
from core.tools.__base.tool import Tool from core.tools.__base.tool import Tool
from core.tools.__base.tool_runtime import ToolRuntime from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.builtin_tool.provider import BuiltinToolProviderController from core.tools.builtin_tool.provider import BuiltinToolProviderController
from core.tools.custom_tool.provider import ApiToolProviderController from core.tools.custom_tool.provider import ApiToolProviderController
from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity, ToolProviderCredentialApiEntity
from core.tools.entities.common_entities import I18nObject from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_bundle import ApiToolBundle from core.tools.entities.tool_bundle import ApiToolBundle
from core.tools.entities.tool_entities import ( from core.tools.entities.tool_entities import (
ApiProviderAuthType, ApiProviderAuthType,
CredentialType,
ToolParameter, ToolParameter,
ToolProviderType, ToolProviderType,
) )
from core.tools.plugin_tool.provider import PluginToolProviderController from core.tools.plugin_tool.provider import PluginToolProviderController
from core.tools.utils.configuration import ProviderConfigEncrypter from core.tools.utils.configuration import create_encrypter, create_generic_encrypter
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
from core.tools.workflow_as_tool.tool import WorkflowTool from core.tools.workflow_as_tool.tool import WorkflowTool
from models.tools import ApiToolProvider, BuiltinToolProvider, MCPToolProvider, WorkflowToolProvider from models.tools import ApiToolProvider, BuiltinToolProvider, MCPToolProvider, WorkflowToolProvider
@ -110,7 +112,14 @@ class ToolTransformService:
result.plugin_unique_identifier = provider_controller.plugin_unique_identifier result.plugin_unique_identifier = provider_controller.plugin_unique_identifier
# get credentials schema # get credentials schema
schema = {x.to_basic_provider_config().name: x for x in provider_controller.get_credentials_schema()} schema = {
x.to_basic_provider_config().name: x
for x in provider_controller.get_credentials_schema_by_type(
CredentialType.of(db_provider.credential_type)
if db_provider
else CredentialType.API_KEY
)
}
for name, value in schema.items(): for name, value in schema.items():
if result.masked_credentials: if result.masked_credentials:
@ -127,15 +136,23 @@ class ToolTransformService:
credentials = db_provider.credentials credentials = db_provider.credentials
# init tool configuration # init tool configuration
tool_configuration = ProviderConfigEncrypter( encrypter, _ = create_encrypter(
tenant_id=db_provider.tenant_id, tenant_id=db_provider.tenant_id,
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()], config=[
provider_type=provider_controller.provider_type.value, x.to_basic_provider_config()
provider_identity=provider_controller.entity.identity.name, for x in provider_controller.get_credentials_schema_by_type(
CredentialType.of(db_provider.credential_type)
)
],
cache=ToolProviderCredentialsCache(
tenant_id=db_provider.tenant_id,
provider=db_provider.provider,
credential_id=db_provider.id,
),
) )
# decrypt the credentials and mask the credentials # decrypt the credentials and mask the credentials
decrypted_credentials = tool_configuration.decrypt(data=credentials) decrypted_credentials = encrypter.decrypt(data=credentials)
masked_credentials = tool_configuration.mask_tool_credentials(data=decrypted_credentials) masked_credentials = encrypter.mask_tool_credentials(data=decrypted_credentials)
result.masked_credentials = masked_credentials result.masked_credentials = masked_credentials
result.original_credentials = decrypted_credentials result.original_credentials = decrypted_credentials
@ -272,7 +289,7 @@ class ToolTransformService:
if decrypt_credentials: if decrypt_credentials:
# init tool configuration # init tool configuration
tool_configuration = ProviderConfigEncrypter( encrypter, _ = create_generic_encrypter(
tenant_id=db_provider.tenant_id, tenant_id=db_provider.tenant_id,
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()], config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
provider_type=provider_controller.provider_type.value, provider_type=provider_controller.provider_type.value,
@ -280,8 +297,8 @@ class ToolTransformService:
) )
# decrypt the credentials and mask the credentials # decrypt the credentials and mask the credentials
decrypted_credentials = tool_configuration.decrypt(data=credentials) decrypted_credentials = encrypter.decrypt(data=credentials)
masked_credentials = tool_configuration.mask_tool_credentials(data=decrypted_credentials) masked_credentials = encrypter.mask_tool_credentials(data=decrypted_credentials)
result.masked_credentials = masked_credentials result.masked_credentials = masked_credentials
@ -291,7 +308,6 @@ class ToolTransformService:
def convert_tool_entity_to_api_entity( def convert_tool_entity_to_api_entity(
tool: Union[ApiToolBundle, WorkflowTool, Tool], tool: Union[ApiToolBundle, WorkflowTool, Tool],
tenant_id: str, tenant_id: str,
credentials: dict | None = None,
labels: list[str] | None = None, labels: list[str] | None = None,
) -> ToolApiEntity: ) -> ToolApiEntity:
""" """
@ -301,7 +317,7 @@ class ToolTransformService:
# fork tool runtime # fork tool runtime
tool = tool.fork_tool_runtime( tool = tool.fork_tool_runtime(
runtime=ToolRuntime( runtime=ToolRuntime(
credentials=credentials or {}, credentials={},
tenant_id=tenant_id, tenant_id=tenant_id,
) )
) )
@ -342,6 +358,19 @@ class ToolTransformService:
labels=labels or [], labels=labels or [],
) )
@staticmethod
def convert_builtin_provider_to_credential_entity(
provider: BuiltinToolProvider, credentials: dict
) -> ToolProviderCredentialApiEntity:
return ToolProviderCredentialApiEntity(
id=provider.id,
name=provider.name,
provider=provider.provider,
credential_type=CredentialType.of(provider.credential_type),
is_default=provider.is_default,
credentials=credentials,
)
@staticmethod @staticmethod
def convert_mcp_schema_to_parameter(schema: dict) -> list["ToolParameter"]: def convert_mcp_schema_to_parameter(schema: dict) -> list["ToolParameter"]:
""" """

Loading…
Cancel
Save