feat(oauth): update api

pull/22036/head
Harry 11 months ago
parent 6c9e99b0c6
commit ba843c2691

@ -35,6 +35,7 @@ class ModelProviderListApi(Resource):
model_provider_service = ModelProviderService() model_provider_service = ModelProviderService()
provider_list = model_provider_service.get_provider_list(tenant_id=tenant_id, model_type=args.get("model_type")) provider_list = model_provider_service.get_provider_list(tenant_id=tenant_id, model_type=args.get("model_type"))
return jsonable_encoder({"data": provider_list}) return jsonable_encoder({"data": provider_list})

@ -371,12 +371,12 @@ class ToolBuiltinProviderCredentialsSchemaApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, provider): def get(self, provider, credential_type):
user = current_user user = current_user
tenant_id = user.current_tenant_id tenant_id = user.current_tenant_id
return BuiltinToolManageService.list_builtin_provider_credentials_schema(provider, tenant_id) return BuiltinToolManageService.list_builtin_provider_credentials_schema(provider, credential_type, tenant_id)
class ToolApiProviderSchemaApi(Resource): class ToolApiProviderSchemaApi(Resource):
@ -789,7 +789,7 @@ api.add_resource(
) )
api.add_resource( api.add_resource(
ToolBuiltinProviderCredentialsSchemaApi, ToolBuiltinProviderCredentialsSchemaApi,
"/workspaces/current/tool-provider/builtin/<path:provider>/credentials_schema", "/workspaces/current/tool-provider/builtin/<path:provider>/<path:credential_type>/credentials_schema",
) )
api.add_resource(ToolBuiltinProviderIconApi, "/workspaces/current/tool-provider/builtin/<path:provider>/icon") api.add_resource(ToolBuiltinProviderIconApi, "/workspaces/current/tool-provider/builtin/<path:provider>/icon")

@ -20,7 +20,6 @@ from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
if TYPE_CHECKING: if TYPE_CHECKING:
from core.workflow.nodes.tool.entities import ToolEntity from core.workflow.nodes.tool.entities import ToolEntity
from configs import dify_config from configs import dify_config
from core.agent.entities import AgentToolEntity from core.agent.entities import AgentToolEntity
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
@ -35,18 +34,10 @@ from core.tools.custom_tool.provider import ApiToolProviderController
from core.tools.custom_tool.tool import ApiTool from core.tools.custom_tool.tool import ApiTool
from core.tools.entities.api_entities import ToolProviderApiEntity, ToolProviderTypeApiLiteral from core.tools.entities.api_entities import ToolProviderApiEntity, ToolProviderTypeApiLiteral
from core.tools.entities.common_entities import I18nObject from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ( from core.tools.entities.tool_entities import ApiProviderAuthType, ToolInvokeFrom, ToolParameter, ToolProviderType
ApiProviderAuthType, from core.tools.errors import ToolProviderNotFoundError
ToolInvokeFrom,
ToolParameter,
ToolProviderType,
)
from core.tools.errors import ToolNotFoundError, ToolProviderNotFoundError
from core.tools.tool_label_manager import ToolLabelManager from core.tools.tool_label_manager import ToolLabelManager
from core.tools.utils.configuration import ( from core.tools.utils.configuration import ProviderConfigEncrypter, ToolParameterConfigurationManager
ProviderConfigEncrypter,
ToolParameterConfigurationManager,
)
from core.tools.workflow_as_tool.tool import WorkflowTool from core.tools.workflow_as_tool.tool import WorkflowTool
from extensions.ext_database import db from extensions.ext_database import db
from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider
@ -64,8 +55,11 @@ class ToolManager:
@classmethod @classmethod
def get_hardcoded_provider(cls, provider: str) -> BuiltinToolProviderController: def get_hardcoded_provider(cls, provider: str) -> BuiltinToolProviderController:
""" """
get the hardcoded provider get the hardcoded provider
""" """
if len(cls._hardcoded_providers) == 0: if len(cls._hardcoded_providers) == 0:
# init the builtin providers # init the builtin providers
cls.load_hardcoded_providers_cache() cls.load_hardcoded_providers_cache()
@ -109,7 +103,12 @@ class ToolManager:
contexts.plugin_tool_providers.set({}) contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(Lock()) contexts.plugin_tool_providers_lock.set(Lock())
plugin_tool_providers = contexts.plugin_tool_providers.get()
if provider in plugin_tool_providers:
return plugin_tool_providers[provider]
with contexts.plugin_tool_providers_lock.get(): with contexts.plugin_tool_providers_lock.get():
# double check
plugin_tool_providers = contexts.plugin_tool_providers.get() plugin_tool_providers = contexts.plugin_tool_providers.get()
if provider in plugin_tool_providers: if provider in plugin_tool_providers:
return plugin_tool_providers[provider] return plugin_tool_providers[provider]
@ -127,26 +126,8 @@ class ToolManager:
) )
plugin_tool_providers[provider] = controller plugin_tool_providers[provider] = controller
return controller return controller
@classmethod
def get_builtin_tool(cls, provider: str, tool_name: str, tenant_id: str) -> BuiltinTool | PluginTool | None:
"""
get the builtin tool
:param provider: the name of the provider
:param tool_name: the name of the tool
:param tenant_id: the id of the tenant
:return: the provider, the tool
"""
provider_controller = cls.get_builtin_provider(provider, tenant_id)
tool = provider_controller.get_tool(tool_name)
if tool is None:
raise ToolNotFoundError(f"tool {tool_name} not found")
return tool
@classmethod @classmethod
def get_tool_runtime( def get_tool_runtime(
cls, cls,
@ -563,6 +544,22 @@ class ToolManager:
return cls._builtin_tools_labels[tool_name] return cls._builtin_tools_labels[tool_name]
@classmethod
def list_default_builtin_providers(cls, tenant_id: str) -> list[BuiltinToolProvider]:
"""
list all the builtin providers
"""
# according to multi credentials, select the one with is_default=True first, then created_at oldest
# for compatibility with old version
sql = """
SELECT DISTINCT ON (tenant_id, provider) id
FROM tool_builtin_providers
WHERE tenant_id = :tenant_id
ORDER BY tenant_id, provider, is_default DESC, created_at DESC
"""
ids = [row.id for row in db.session.execute(db.text(sql), {"tenant_id": tenant_id}).all()]
return db.session.query(BuiltinToolProvider).filter(BuiltinToolProvider.id.in_(ids)).all()
@classmethod @classmethod
def list_providers_from_api( def list_providers_from_api(
cls, user_id: str, tenant_id: str, typ: ToolProviderTypeApiLiteral cls, user_id: str, tenant_id: str, typ: ToolProviderTypeApiLiteral
@ -577,30 +574,13 @@ class ToolManager:
with db.session.no_autoflush: with db.session.no_autoflush:
if "builtin" in filters: if "builtin" in filters:
def get_builtin_providers(tenant_id):
# according to multi credentials, select the one with is_default=True first, then created_at oldest
# for compatibility with old version
sql = """
SELECT DISTINCT ON (tenant_id, provider) id
FROM tool_builtin_providers
WHERE tenant_id = :tenant_id
ORDER BY tenant_id, provider, is_default DESC, created_at DESC
"""
ids = [row.id for row in db.session.execute(db.text(sql), {"tenant_id": tenant_id}).all()]
return db.session.query(BuiltinToolProvider).filter(BuiltinToolProvider.id.in_(ids)).all()
builtin_providers = cls.list_builtin_providers(tenant_id) builtin_providers = cls.list_builtin_providers(tenant_id)
# get builtin providers # key: provider name, value: provider
db_builtin_providers = get_builtin_providers(tenant_id) db_builtin_providers = {
str(ToolProviderID(provider.provider)): provider
# rewrite db_builtin_providers for provider in cls.list_default_builtin_providers(tenant_id)
for db_provider in db_builtin_providers: }
db_provider.provider = str(ToolProviderID(db_provider.provider))
def find_db_builtin_provider(provider):
return next((x for x in db_builtin_providers if x.provider == provider), None)
# append builtin providers # append builtin providers
for provider in builtin_providers: for provider in builtin_providers:
@ -612,10 +592,9 @@ class ToolManager:
name_func=lambda x: x.identity.name, name_func=lambda x: x.identity.name,
): ):
continue continue
user_provider = ToolTransformService.builtin_provider_to_user_provider( user_provider = ToolTransformService.builtin_provider_to_user_provider(
provider_controller=provider, provider_controller=provider,
db_provider=find_db_builtin_provider(provider.entity.identity.name), db_provider=db_builtin_providers.get(provider.entity.identity.name),
decrypt_credentials=False, decrypt_credentials=False,
) )
@ -625,7 +604,6 @@ class ToolManager:
result_providers[f"builtin_provider.{user_provider.name}"] = user_provider result_providers[f"builtin_provider.{user_provider.name}"] = user_provider
# get db api providers # get db api providers
if "api" in filters: if "api" in filters:
db_api_providers: list[ApiToolProvider] = ( db_api_providers: list[ApiToolProvider] = (
db.session.query(ApiToolProvider).filter(ApiToolProvider.tenant_id == tenant_id).all() db.session.query(ApiToolProvider).filter(ApiToolProvider.tenant_id == tenant_id).all()

@ -1,67 +0,0 @@
import secrets
import urllib.parse
from collections.abc import Mapping
from typing import Any
import requests
from dify_plugin import ToolProvider
from dify_plugin.errors.tool import ToolProviderCredentialValidationError
from werkzeug import Request
class GithubProvider(ToolProvider):
_AUTH_URL = "https://github.com/login/oauth/authorize"
_TOKEN_URL = "https://github.com/login/oauth/access_token"
_API_USER_URL = "https://api.github.com/user"
def _oauth_get_authorization_url(self, system_credentials: Mapping[str, Any]) -> str:
"""
Generate the authorization URL for the Github OAuth.
"""
state = secrets.token_urlsafe(16)
params = {
"client_id": system_credentials["client_id"],
"redirect_uri": system_credentials["redirect_uri"],
"scope": system_credentials.get("scope", "read:user"),
"state": state,
# Optionally: allow_signup, login, etc.
}
return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}"
def _oauth_get_credentials(self, system_credentials: Mapping[str, Any], request: Request) -> Mapping[str, Any]:
"""
Exchange code for access_token.
"""
code = request.args.get("code")
state = request.args.get("state")
if not code:
raise ValueError("No code provided")
# Optionally: validate state here
data = {
"client_id": system_credentials["client_id"],
"client_secret": system_credentials["client_secret"],
"code": code,
"redirect_uri": system_credentials["redirect_uri"],
}
headers = {"Accept": "application/json"}
response = requests.post(self._TOKEN_URL, data=data, headers=headers, timeout=10)
response_json = response.json()
access_token = response_json.get("access_token")
if not access_token:
raise ValueError(f"Error in GitHub OAuth: {response_json}")
return {"access_token": access_token}
def _validate_credentials(self, credentials: dict) -> None:
try:
if "access_token" not in credentials or not credentials.get("access_token"):
raise ToolProviderCredentialValidationError("GitHub API Access Token is required.")
headers = {
"Authorization": f"Bearer {credentials['access_token']}",
"Accept": "application/vnd.github+json",
}
response = requests.get(self._API_USER_URL, headers=headers, timeout=10)
if response.status_code != 200:
raise ToolProviderCredentialValidationError(response.json().get("message"))
except Exception as e:
raise ToolProviderCredentialValidationError(str(e))

@ -2,6 +2,7 @@ import json
import logging import logging
import re import re
from pathlib import Path from pathlib import Path
from typing import Optional, Union
from sqlalchemy import ColumnExpressionArgument from sqlalchemy import ColumnExpressionArgument
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@ -11,6 +12,7 @@ from core.helper.position_helper import is_filtered
from core.model_runtime.utils.encoders import jsonable_encoder from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.entities.plugin import ToolProviderID from core.plugin.entities.plugin import ToolProviderID
from core.plugin.impl.exc import PluginDaemonClientSideError from core.plugin.impl.exc import PluginDaemonClientSideError
from core.tools.__base.tool_provider import ToolProviderController
from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort
from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity, ToolProviderCredentialApiEntity from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity, ToolProviderCredentialApiEntity
from core.tools.entities.tool_entities import ToolProviderCredentialType from core.tools.entities.tool_entities import ToolProviderCredentialType
@ -40,12 +42,7 @@ class BuiltinToolManageService:
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id) provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
tools = provider_controller.get_tools() tools = provider_controller.get_tools()
tool_provider_configurations = ProviderConfigEncrypter( tool_configuration = BuiltinToolManageService._create_tool_configuration(tenant_id, provider_controller)
tenant_id=tenant_id,
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.entity.identity.name,
)
# check if user has added the provider # check if user has added the provider
builtin_provider = BuiltinToolManageService._fetch_builtin_provider(provider, tenant_id) builtin_provider = BuiltinToolManageService._fetch_builtin_provider(provider, tenant_id)
@ -53,7 +50,7 @@ class BuiltinToolManageService:
if builtin_provider is not None: if builtin_provider is not None:
# get credentials # get credentials
credentials = builtin_provider.credentials credentials = builtin_provider.credentials
credentials = tool_provider_configurations.decrypt(credentials) credentials = tool_configuration.decrypt(credentials)
result: list[ToolApiEntity] = [] result: list[ToolApiEntity] = []
for tool in tools or []: for tool in tools or []:
@ -74,12 +71,7 @@ class BuiltinToolManageService:
get builtin tool provider info get builtin tool provider info
""" """
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id) provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
tool_provider_configurations = ProviderConfigEncrypter( tool_configuration = BuiltinToolManageService._create_tool_configuration(tenant_id, provider_controller)
tenant_id=tenant_id,
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.entity.identity.name,
)
# check if user has added the provider # check if user has added the provider
builtin_provider = BuiltinToolManageService._fetch_builtin_provider(provider, tenant_id) builtin_provider = BuiltinToolManageService._fetch_builtin_provider(provider, tenant_id)
@ -87,7 +79,7 @@ class BuiltinToolManageService:
if builtin_provider is not None: if builtin_provider is not None:
# get credentials # get credentials
credentials = builtin_provider.credentials credentials = builtin_provider.credentials
credentials = tool_provider_configurations.decrypt(credentials) credentials = tool_configuration.decrypt(credentials)
entity = ToolTransformService.builtin_provider_to_user_provider( entity = ToolTransformService.builtin_provider_to_user_provider(
provider_controller=provider_controller, provider_controller=provider_controller,
@ -100,7 +92,7 @@ class BuiltinToolManageService:
return entity return entity
@staticmethod @staticmethod
def list_builtin_provider_credentials_schema(provider_name: str, tenant_id: str): def list_builtin_provider_credentials_schema(provider_name: str, credential_type: str, tenant_id: str):
""" """
list builtin provider credentials schema list builtin provider credentials schema
@ -130,28 +122,21 @@ class BuiltinToolManageService:
if not provider_controller.need_credentials: if not provider_controller.need_credentials:
raise ValueError(f"provider {provider_name} does not need credentials") raise ValueError(f"provider {provider_name} does not need credentials")
tool_configuration = ProviderConfigEncrypter( tool_configuration = BuiltinToolManageService._create_tool_configuration(tenant_id, provider_controller)
tenant_id=tenant_id,
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.entity.identity.name,
)
# Decrypt and restore original credentials for masked values # Decrypt and restore original credentials for masked values
original_credentials = tool_configuration.decrypt(provider.credentials) original_credentials = tool_configuration.decrypt(provider.credentials)
masked_credentials = tool_configuration.mask_tool_credentials(original_credentials) masked_credentials = tool_configuration.mask_tool_credentials(original_credentials)
# check if the credential has changed, save the original credential # check if the credential has changed, save the original credential
for name, value in credentials.items(): for key, value in credentials.items():
if name in masked_credentials and value == masked_credentials[name]: # type: ignore if key in masked_credentials and value == masked_credentials[key]:
credentials[name] = original_credentials[name] # type: ignore credentials[key] = original_credentials[key]
# Encrypt and save the credentials # Encrypt and save the credentials
BuiltinToolManageService._encrypt_and_save_credentials( BuiltinToolManageService._encrypt_and_save_credentials(
provider_controller, tool_configuration, provider, credentials, user_id provider_controller, tool_configuration, provider, credentials, user_id
) )
else:
raise ValueError(f"provider {provider_name} is not editable, you can only delete it and add a new one")
# update name if provided # update name if provided
if name is not None and provider.name != name: if name is not None and provider.name != name:
@ -180,8 +165,8 @@ class BuiltinToolManageService:
""" """
add builtin tool provider add builtin tool provider
""" """
lock_name = f"builtin_tool_provider_credential_lock_{tenant_id}_{provider_name}_{api_type.value}" lock = f"builtin_tool_provider_create_lock:{tenant_id}_{provider_name}"
with redis_client.lock(lock_name, timeout=20): with redis_client.lock(lock, timeout=20):
if name is None: if name is None:
name = BuiltinToolManageService.get_next_builtin_tool_provider_name(tenant_id, provider_name, api_type) name = BuiltinToolManageService.get_next_builtin_tool_provider_name(tenant_id, provider_name, api_type)
@ -198,12 +183,7 @@ class BuiltinToolManageService:
if not provider_controller.need_credentials: if not provider_controller.need_credentials:
raise ValueError(f"provider {provider_name} does not need credentials") raise ValueError(f"provider {provider_name} does not need credentials")
tool_configuration = ProviderConfigEncrypter( tool_configuration = BuiltinToolManageService._create_tool_configuration(tenant_id, provider_controller)
tenant_id=tenant_id,
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.entity.identity.name,
)
# Encrypt and save the credentials # Encrypt and save the credentials
BuiltinToolManageService._encrypt_and_save_credentials( BuiltinToolManageService._encrypt_and_save_credentials(
@ -268,23 +248,17 @@ class BuiltinToolManageService:
return [] return []
provider_controller = ToolManager.get_builtin_provider(providers[0].provider, tenant_id) provider_controller = ToolManager.get_builtin_provider(providers[0].provider, tenant_id)
tool_configuration = ProviderConfigEncrypter( tool_configuration = BuiltinToolManageService._create_tool_configuration(tenant_id, provider_controller)
tenant_id=tenant_id,
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.entity.identity.name,
)
credentials: list[ToolProviderCredentialApiEntity] = [] credentials: list[ToolProviderCredentialApiEntity] = []
for provider in providers: for provider in providers:
decrypt_credential = tool_configuration.mask_tool_credentials( decrypt_credential = tool_configuration.mask_tool_credentials(
tool_configuration.decrypt(provider.credentials) tool_configuration.decrypt(provider.credentials)
) )
credentials.append( credential_entity = ToolTransformService.convert_builtin_provider_to_credential_entity(
ToolTransformService.convert_builtin_provider_to_credential_api_entity(
provider=provider, provider=provider,
credentials=decrypt_credential, credentials=decrypt_credential,
) )
) credentials.append(credential_entity)
return credentials return credentials
@staticmethod @staticmethod
@ -292,22 +266,17 @@ class BuiltinToolManageService:
""" """
delete tool provider delete tool provider
""" """
provider_obj = BuiltinToolManageService._fetch_builtin_provider_by_id(tenant_id, credential_id) tool_provider = BuiltinToolManageService._fetch_builtin_provider_by_id(tenant_id, credential_id)
if provider_obj is None: if tool_provider is None:
raise ValueError(f"you have not added provider {provider_name}") raise ValueError(f"you have not added provider {provider_name}")
db.session.delete(provider_obj) db.session.delete(tool_provider)
db.session.commit() db.session.commit()
# delete cache # delete cache
provider_controller = ToolManager.get_builtin_provider(provider_name, tenant_id) provider_controller = ToolManager.get_builtin_provider(provider_name, tenant_id)
tool_configuration = ProviderConfigEncrypter( tool_configuration = BuiltinToolManageService._create_tool_configuration(tenant_id, provider_controller)
tenant_id=tenant_id,
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.entity.identity.name,
)
tool_configuration.delete_tool_credentials_cache() tool_configuration.delete_tool_credentials_cache()
return {"result": "success"} return {"result": "success"}
@ -334,7 +303,9 @@ class BuiltinToolManageService:
return {"result": "success"} return {"result": "success"}
@staticmethod @staticmethod
def get_builtin_tool_oauth_client(tenant_id: str, provider: str, plugin_id: str): def get_builtin_tool_oauth_client(
tenant_id: str, provider: str, plugin_id: str
) -> Union[ToolOAuthTenantClient, ToolOAuthSystemClient]:
""" """
get builtin tool provider get builtin tool provider
""" """
@ -350,14 +321,12 @@ class BuiltinToolManageService:
.first() .first()
) )
if user_client: if user_client:
plugin_oauth_config = user_client return user_client
else:
plugin_oauth_config = session.query(ToolOAuthSystemClient).filter_by(provider=provider).first()
if plugin_oauth_config: system_client = session.query(ToolOAuthSystemClient).filter_by(provider=provider).first()
return plugin_oauth_config if system_client is None:
raise ValueError("no oauth available client config found for this tool provider")
raise ValueError("no oauth available config found for this plugin") return system_client
@staticmethod @staticmethod
def get_builtin_tool_provider_icon(provider: str): def get_builtin_tool_provider_icon(provider: str):
@ -379,9 +348,7 @@ class BuiltinToolManageService:
with db.session.no_autoflush: with db.session.no_autoflush:
# get all user added providers # get all user added providers
db_providers: list[BuiltinToolProvider] = ( db_providers: list[BuiltinToolProvider] = ToolManager.list_default_builtin_providers(tenant_id)
db.session.query(BuiltinToolProvider).filter(BuiltinToolProvider.tenant_id == tenant_id).all() or []
)
# rewrite db_providers # rewrite db_providers
for db_provider in db_providers: for db_provider in db_providers:
@ -432,8 +399,8 @@ class BuiltinToolManageService:
return BuiltinToolProviderSort.sort(result) return BuiltinToolProviderSort.sort(result)
@staticmethod @staticmethod
def _fetch_builtin_provider_by_id(tenant_id: str, credential_id: str) -> BuiltinToolProvider | None: def _fetch_builtin_provider_by_id(tenant_id: str, credential_id: str) -> Optional[BuiltinToolProvider]:
provider = ( provider: Optional[BuiltinToolProvider] = (
db.session.query(BuiltinToolProvider) db.session.query(BuiltinToolProvider)
.filter( .filter(
BuiltinToolProvider.tenant_id == tenant_id, BuiltinToolProvider.tenant_id == tenant_id,
@ -444,14 +411,14 @@ class BuiltinToolManageService:
return provider return provider
@staticmethod @staticmethod
def _fetch_builtin_provider(provider_name: str, tenant_id: str) -> BuiltinToolProvider | None: def _fetch_builtin_provider(provider_name: str, tenant_id: str) -> Optional[BuiltinToolProvider]:
""" """
This method is used to fetch the builtin provider from the database This method is used to fetch the builtin provider from the database
1.if the default provider exists, return the default provider 1.if the default provider exists, return the default provider
2.if the default provider does not exist, return the oldest provider 2.if the default provider does not exist, return the oldest provider
""" """
def _query(provider_filters: list[ColumnExpressionArgument[bool]]): def _query(provider_filters: list[ColumnExpressionArgument[bool]]) -> Optional[BuiltinToolProvider]:
return ( return (
db.session.query(BuiltinToolProvider) db.session.query(BuiltinToolProvider)
.filter(BuiltinToolProvider.tenant_id == tenant_id, *provider_filters) .filter(BuiltinToolProvider.tenant_id == tenant_id, *provider_filters)
@ -484,21 +451,16 @@ class BuiltinToolManageService:
return provider return provider
except Exception: except Exception:
# it's an old provider without organization # it's an old provider without organization
provider_obj = _query([BuiltinToolProvider.provider == provider_name]) return _query([BuiltinToolProvider.provider == provider_name])
return provider_obj
@staticmethod @staticmethod
def _decrypt_and_restore_credentials(tool_configuration, provider, credentials): def _create_tool_configuration(tenant_id: str, provider_controller: ToolProviderController):
""" return ProviderConfigEncrypter(
Decrypt original credentials and restore masked values from the input credentials tenant_id=tenant_id,
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
:param tool_configuration: the tool configuration encrypter provider_type=provider_controller.provider_type.value,
:param provider: the provider object from database provider_identity=provider_controller.entity.identity.name,
:param credentials: the input credentials from user )
:return: the processed credentials with original values restored
"""
return credentials
@staticmethod @staticmethod
def _encrypt_and_save_credentials(provider_controller, tool_configuration, provider, credentials, user_id): def _encrypt_and_save_credentials(provider_controller, tool_configuration, provider, credentials, user_id):

@ -307,7 +307,7 @@ class ToolTransformService:
) )
@staticmethod @staticmethod
def convert_builtin_provider_to_credential_api_entity( def convert_builtin_provider_to_credential_entity(
provider: BuiltinToolProvider, credentials: dict provider: BuiltinToolProvider, credentials: dict
) -> ToolProviderCredentialApiEntity: ) -> ToolProviderCredentialApiEntity:
return ToolProviderCredentialApiEntity( return ToolProviderCredentialApiEntity(

Loading…
Cancel
Save