|
|
|
@ -4,7 +4,6 @@ import re
|
|
|
|
from pathlib import Path
|
|
|
|
from pathlib import Path
|
|
|
|
from typing import Optional, Union
|
|
|
|
from typing import Optional, Union
|
|
|
|
|
|
|
|
|
|
|
|
from sqlalchemy import ColumnExpressionArgument
|
|
|
|
|
|
|
|
from sqlalchemy.orm import Session
|
|
|
|
from sqlalchemy.orm import Session
|
|
|
|
|
|
|
|
|
|
|
|
from configs import dify_config
|
|
|
|
from configs import dify_config
|
|
|
|
@ -13,10 +12,12 @@ from core.model_runtime.utils.encoders import jsonable_encoder
|
|
|
|
from core.plugin.entities.plugin import ToolProviderID
|
|
|
|
from core.plugin.entities.plugin import ToolProviderID
|
|
|
|
from core.plugin.impl.exc import PluginDaemonClientSideError
|
|
|
|
from core.plugin.impl.exc import PluginDaemonClientSideError
|
|
|
|
from core.tools.__base.tool_provider import ToolProviderController
|
|
|
|
from core.tools.__base.tool_provider import ToolProviderController
|
|
|
|
|
|
|
|
from core.tools.builtin_tool.provider import BuiltinToolProviderController
|
|
|
|
from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort
|
|
|
|
from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort
|
|
|
|
from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity, ToolProviderCredentialApiEntity
|
|
|
|
from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity, ToolProviderCredentialApiEntity
|
|
|
|
from core.tools.entities.tool_entities import ToolProviderCredentialType
|
|
|
|
from core.tools.entities.tool_entities import ToolProviderCredentialType
|
|
|
|
from core.tools.errors import ToolNotFoundError, ToolProviderCredentialValidationError, ToolProviderNotFoundError
|
|
|
|
from core.tools.errors import ToolNotFoundError, ToolProviderCredentialValidationError, ToolProviderNotFoundError
|
|
|
|
|
|
|
|
from core.tools.plugin_tool.provider import PluginToolProviderController
|
|
|
|
from core.tools.tool_label_manager import ToolLabelManager
|
|
|
|
from core.tools.tool_label_manager import ToolLabelManager
|
|
|
|
from core.tools.tool_manager import ToolManager
|
|
|
|
from core.tools.tool_manager import ToolManager
|
|
|
|
from core.tools.utils.configuration import ProviderConfigEncrypter
|
|
|
|
from core.tools.utils.configuration import ProviderConfigEncrypter
|
|
|
|
@ -29,6 +30,8 @@ logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class BuiltinToolManageService:
|
|
|
|
class BuiltinToolManageService:
|
|
|
|
|
|
|
|
__MAX_BUILTIN_TOOL_PROVIDER_COUNT__ = 100
|
|
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
@staticmethod
|
|
|
|
def list_builtin_tool_provider_tools(tenant_id: str, provider: str) -> list[ToolApiEntity]:
|
|
|
|
def list_builtin_tool_provider_tools(tenant_id: str, provider: str) -> list[ToolApiEntity]:
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
@ -42,22 +45,11 @@ class BuiltinToolManageService:
|
|
|
|
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
|
|
|
|
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
|
|
|
|
tools = provider_controller.get_tools()
|
|
|
|
tools = provider_controller.get_tools()
|
|
|
|
|
|
|
|
|
|
|
|
tool_configuration = BuiltinToolManageService._create_tool_configuration(tenant_id, provider_controller)
|
|
|
|
|
|
|
|
# check if user has added the provider
|
|
|
|
|
|
|
|
builtin_provider = BuiltinToolManageService._fetch_builtin_provider(provider, tenant_id)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
credentials = {}
|
|
|
|
|
|
|
|
if builtin_provider is not None:
|
|
|
|
|
|
|
|
# get credentials
|
|
|
|
|
|
|
|
credentials = builtin_provider.credentials
|
|
|
|
|
|
|
|
credentials = tool_configuration.decrypt(credentials)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
result: list[ToolApiEntity] = []
|
|
|
|
result: list[ToolApiEntity] = []
|
|
|
|
for tool in tools or []:
|
|
|
|
for tool in tools or []:
|
|
|
|
result.append(
|
|
|
|
result.append(
|
|
|
|
ToolTransformService.convert_tool_entity_to_api_entity(
|
|
|
|
ToolTransformService.convert_tool_entity_to_api_entity(
|
|
|
|
tool=tool,
|
|
|
|
tool=tool,
|
|
|
|
credentials=credentials,
|
|
|
|
|
|
|
|
tenant_id=tenant_id,
|
|
|
|
tenant_id=tenant_id,
|
|
|
|
labels=ToolLabelManager.get_tool_labels(provider_controller),
|
|
|
|
labels=ToolLabelManager.get_tool_labels(provider_controller),
|
|
|
|
)
|
|
|
|
)
|
|
|
|
@ -73,7 +65,7 @@ class BuiltinToolManageService:
|
|
|
|
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
|
|
|
|
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
|
|
|
|
tool_configuration = BuiltinToolManageService._create_tool_configuration(tenant_id, provider_controller)
|
|
|
|
tool_configuration = BuiltinToolManageService._create_tool_configuration(tenant_id, provider_controller)
|
|
|
|
# check if user has added the provider
|
|
|
|
# check if user has added the provider
|
|
|
|
builtin_provider = BuiltinToolManageService._fetch_builtin_provider(provider, tenant_id)
|
|
|
|
builtin_provider = BuiltinToolManageService.get_builtin_provider(provider, tenant_id)
|
|
|
|
|
|
|
|
|
|
|
|
credentials = {}
|
|
|
|
credentials = {}
|
|
|
|
if builtin_provider is not None:
|
|
|
|
if builtin_provider is not None:
|
|
|
|
@ -92,16 +84,19 @@ class BuiltinToolManageService:
|
|
|
|
return entity
|
|
|
|
return entity
|
|
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
@staticmethod
|
|
|
|
def list_builtin_provider_credentials_schema(provider_name: str, credential_type: str, tenant_id: str):
|
|
|
|
def list_builtin_provider_credentials_schema(
|
|
|
|
|
|
|
|
provider_name: str, credential_type: ToolProviderCredentialType, tenant_id: str
|
|
|
|
|
|
|
|
):
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
list builtin provider credentials schema
|
|
|
|
list builtin provider credentials schema
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
:param credential_type: credential type
|
|
|
|
:param provider_name: the name of the provider
|
|
|
|
:param provider_name: the name of the provider
|
|
|
|
:param tenant_id: the id of the tenant
|
|
|
|
:param tenant_id: the id of the tenant
|
|
|
|
:return: the list of tool providers
|
|
|
|
:return: the list of tool providers
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
provider = ToolManager.get_builtin_provider(provider_name, tenant_id)
|
|
|
|
provider = ToolManager.get_builtin_provider(provider_name, tenant_id)
|
|
|
|
return jsonable_encoder(provider.get_credentials_schema())
|
|
|
|
return jsonable_encoder(provider.get_credentials_schema(credential_type))
|
|
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
@staticmethod
|
|
|
|
def update_builtin_tool_provider(
|
|
|
|
def update_builtin_tool_provider(
|
|
|
|
@ -111,11 +106,11 @@ class BuiltinToolManageService:
|
|
|
|
update builtin tool provider
|
|
|
|
update builtin tool provider
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
# get if the provider exists
|
|
|
|
# get if the provider exists
|
|
|
|
provider = BuiltinToolManageService._fetch_builtin_provider_by_id(tenant_id, credential_id)
|
|
|
|
provider = BuiltinToolManageService.get_builtin_provider_by_id(tenant_id, credential_id)
|
|
|
|
|
|
|
|
|
|
|
|
if provider is None:
|
|
|
|
if provider is None:
|
|
|
|
raise ValueError(f"you have not added provider {provider_name}")
|
|
|
|
raise ValueError(f"you have not added provider {provider_name}")
|
|
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
try:
|
|
|
|
if ToolProviderCredentialType.of(provider.credential_type).is_editable():
|
|
|
|
if ToolProviderCredentialType.of(provider.credential_type).is_editable():
|
|
|
|
provider_controller = ToolManager.get_builtin_provider(provider_name, tenant_id)
|
|
|
|
provider_controller = ToolManager.get_builtin_provider(provider_name, tenant_id)
|
|
|
|
@ -133,10 +128,12 @@ class BuiltinToolManageService:
|
|
|
|
if key in masked_credentials and value == masked_credentials[key]:
|
|
|
|
if key in masked_credentials and value == masked_credentials[key]:
|
|
|
|
credentials[key] = original_credentials[key]
|
|
|
|
credentials[key] = original_credentials[key]
|
|
|
|
|
|
|
|
|
|
|
|
# Encrypt and save the credentials
|
|
|
|
provider_controller.validate_credentials(user_id, credentials)
|
|
|
|
BuiltinToolManageService._encrypt_and_save_credentials(
|
|
|
|
|
|
|
|
provider_controller, tool_configuration, provider, credentials, user_id
|
|
|
|
# encrypt credentials
|
|
|
|
)
|
|
|
|
encrypted_credentials = tool_configuration.encrypt(credentials)
|
|
|
|
|
|
|
|
provider.encrypted_credentials = json.dumps(encrypted_credentials)
|
|
|
|
|
|
|
|
tool_configuration.delete_tool_credentials_cache()
|
|
|
|
|
|
|
|
|
|
|
|
# update name if provided
|
|
|
|
# update name if provided
|
|
|
|
if name is not None and provider.name != name:
|
|
|
|
if name is not None and provider.name != name:
|
|
|
|
@ -158,68 +155,84 @@ class BuiltinToolManageService:
|
|
|
|
user_id: str,
|
|
|
|
user_id: str,
|
|
|
|
api_type: ToolProviderCredentialType,
|
|
|
|
api_type: ToolProviderCredentialType,
|
|
|
|
tenant_id: str,
|
|
|
|
tenant_id: str,
|
|
|
|
provider_name: str,
|
|
|
|
provider: str,
|
|
|
|
credentials: dict,
|
|
|
|
credentials: dict,
|
|
|
|
name: str | None = None,
|
|
|
|
name: str | None = None,
|
|
|
|
):
|
|
|
|
):
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
add builtin tool provider
|
|
|
|
add builtin tool provider
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
lock = f"builtin_tool_provider_create_lock:{tenant_id}_{provider_name}"
|
|
|
|
lock = f"builtin_tool_provider_create_lock:{tenant_id}_{provider}"
|
|
|
|
with redis_client.lock(lock, timeout=20):
|
|
|
|
with redis_client.lock(lock, timeout=20):
|
|
|
|
if name is None:
|
|
|
|
# check if the provider count is over the limit
|
|
|
|
name = BuiltinToolManageService.get_next_builtin_tool_provider_name(tenant_id, provider_name, api_type)
|
|
|
|
provider_count = (
|
|
|
|
|
|
|
|
db.session.query(BuiltinToolProvider).filter_by(tenant_id=tenant_id, provider=provider).count()
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
if provider_count >= BuiltinToolManageService.__MAX_BUILTIN_TOOL_PROVIDER_COUNT__:
|
|
|
|
|
|
|
|
raise ValueError(f"you have reached the maximum number of providers for {provider}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# TODO should we get name from oauth authentication?
|
|
|
|
|
|
|
|
name = (
|
|
|
|
|
|
|
|
name
|
|
|
|
|
|
|
|
if name
|
|
|
|
|
|
|
|
else BuiltinToolManageService.generate_builtin_tool_provider_name(
|
|
|
|
|
|
|
|
tenant_id, provider, credential_type=api_type
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
provider = BuiltinToolProvider(
|
|
|
|
db_provider = BuiltinToolProvider(
|
|
|
|
tenant_id=tenant_id,
|
|
|
|
tenant_id=tenant_id,
|
|
|
|
user_id=user_id,
|
|
|
|
user_id=user_id,
|
|
|
|
provider=provider_name,
|
|
|
|
provider=provider,
|
|
|
|
encrypted_credentials=json.dumps(credentials),
|
|
|
|
encrypted_credentials=json.dumps(credentials),
|
|
|
|
credential_type=api_type.value,
|
|
|
|
credential_type=api_type.value,
|
|
|
|
name=name,
|
|
|
|
name=name,
|
|
|
|
)
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
provider_controller = ToolManager.get_builtin_provider(provider_name, tenant_id)
|
|
|
|
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
|
|
|
|
if not provider_controller.need_credentials:
|
|
|
|
if not provider_controller.need_credentials:
|
|
|
|
raise ValueError(f"provider {provider_name} does not need credentials")
|
|
|
|
raise ValueError(f"provider {provider} does not need credentials")
|
|
|
|
|
|
|
|
|
|
|
|
tool_configuration = BuiltinToolManageService._create_tool_configuration(tenant_id, provider_controller)
|
|
|
|
tool_configuration = BuiltinToolManageService._create_tool_configuration(tenant_id, provider_controller)
|
|
|
|
|
|
|
|
|
|
|
|
# Encrypt and save the credentials
|
|
|
|
# Encrypt and save the credentials
|
|
|
|
BuiltinToolManageService._encrypt_and_save_credentials(
|
|
|
|
BuiltinToolManageService._encrypt_and_save_credentials(
|
|
|
|
provider_controller, tool_configuration, provider, credentials, user_id
|
|
|
|
provider_controller=provider_controller,
|
|
|
|
|
|
|
|
tool_configuration=tool_configuration,
|
|
|
|
|
|
|
|
provider=db_provider,
|
|
|
|
|
|
|
|
credentials=credentials,
|
|
|
|
|
|
|
|
user_id=user_id,
|
|
|
|
)
|
|
|
|
)
|
|
|
|
db.session.add(provider)
|
|
|
|
db.session.add(db_provider)
|
|
|
|
db.session.commit()
|
|
|
|
db.session.commit()
|
|
|
|
return {"result": "success"}
|
|
|
|
return {"result": "success"}
|
|
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
@staticmethod
|
|
|
|
def get_next_builtin_tool_provider_name(
|
|
|
|
def generate_builtin_tool_provider_name(
|
|
|
|
tenant_id: str, provider_name: str, type: ToolProviderCredentialType
|
|
|
|
tenant_id: str, provider: str, credential_type: ToolProviderCredentialType
|
|
|
|
) -> str:
|
|
|
|
) -> str:
|
|
|
|
try:
|
|
|
|
try:
|
|
|
|
providers = (
|
|
|
|
db_providers = (
|
|
|
|
db.session.query(BuiltinToolProvider)
|
|
|
|
db.session.query(BuiltinToolProvider)
|
|
|
|
.filter_by(
|
|
|
|
.filter_by(
|
|
|
|
tenant_id=tenant_id,
|
|
|
|
tenant_id=tenant_id,
|
|
|
|
provider=provider_name,
|
|
|
|
provider=provider,
|
|
|
|
credential_type=type.value,
|
|
|
|
credential_type=credential_type.value,
|
|
|
|
)
|
|
|
|
)
|
|
|
|
.order_by(BuiltinToolProvider.created_at.desc())
|
|
|
|
.order_by(BuiltinToolProvider.created_at.desc())
|
|
|
|
.limit(10)
|
|
|
|
|
|
|
|
.all()
|
|
|
|
.all()
|
|
|
|
)
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
# Get the default name pattern
|
|
|
|
# Get the default name pattern
|
|
|
|
default_pattern = type.get_name()
|
|
|
|
default_pattern = f"{credential_type.get_name()}"
|
|
|
|
|
|
|
|
|
|
|
|
# Find all names that match the default pattern: "{default_pattern} {number}"
|
|
|
|
# Find all names that match the default pattern: "{default_pattern} {number}"
|
|
|
|
pattern = rf"^{re.escape(default_pattern)}\s+(\d+)$"
|
|
|
|
pattern = rf"^{re.escape(default_pattern)}\s+(\d+)$"
|
|
|
|
numbers = []
|
|
|
|
numbers = []
|
|
|
|
|
|
|
|
|
|
|
|
for provider in providers:
|
|
|
|
for db_provider in db_providers:
|
|
|
|
if provider.name:
|
|
|
|
if db_provider.name:
|
|
|
|
match = re.match(pattern, provider.name.strip())
|
|
|
|
match = re.match(pattern, db_provider.name.strip())
|
|
|
|
if match:
|
|
|
|
if match:
|
|
|
|
numbers.append(int(match.group(1)))
|
|
|
|
numbers.append(int(match.group(1)))
|
|
|
|
|
|
|
|
|
|
|
|
@ -231,9 +244,9 @@ class BuiltinToolManageService:
|
|
|
|
max_number = max(numbers)
|
|
|
|
max_number = max(numbers)
|
|
|
|
return f"{default_pattern} {max_number + 1}"
|
|
|
|
return f"{default_pattern} {max_number + 1}"
|
|
|
|
except Exception as e:
|
|
|
|
except Exception as e:
|
|
|
|
logger.warning(f"Error generating next provider name for {provider_name}: {str(e)}")
|
|
|
|
logger.warning(f"Error generating next provider name for {provider}: {str(e)}")
|
|
|
|
# fallback
|
|
|
|
# fallback
|
|
|
|
return f"{type.get_name()} 1"
|
|
|
|
return f"{credential_type.get_name()} 1"
|
|
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
@staticmethod
|
|
|
|
def get_builtin_tool_provider_credentials(
|
|
|
|
def get_builtin_tool_provider_credentials(
|
|
|
|
@ -242,31 +255,43 @@ class BuiltinToolManageService:
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
get builtin tool provider credentials
|
|
|
|
get builtin tool provider credentials
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
providers = db.session.query(BuiltinToolProvider).filter_by(tenant_id=tenant_id, provider=provider_name).all()
|
|
|
|
with db.session.no_autoflush:
|
|
|
|
|
|
|
|
providers = (
|
|
|
|
|
|
|
|
db.session.query(BuiltinToolProvider).filter_by(tenant_id=tenant_id, provider=provider_name).all()
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
if len(providers) == 0:
|
|
|
|
if len(providers) == 0:
|
|
|
|
return []
|
|
|
|
return []
|
|
|
|
|
|
|
|
|
|
|
|
provider_controller = ToolManager.get_builtin_provider(providers[0].provider, tenant_id)
|
|
|
|
default_provider = sorted(
|
|
|
|
tool_configuration = BuiltinToolManageService._create_tool_configuration(tenant_id, provider_controller)
|
|
|
|
providers,
|
|
|
|
credentials: list[ToolProviderCredentialApiEntity] = []
|
|
|
|
key=lambda p: (
|
|
|
|
for provider in providers:
|
|
|
|
not getattr(p, "is_default", False),
|
|
|
|
decrypt_credential = tool_configuration.mask_tool_credentials(
|
|
|
|
getattr(p, "created_at", None) or 0,
|
|
|
|
tool_configuration.decrypt(provider.credentials)
|
|
|
|
),
|
|
|
|
)
|
|
|
|
)[0]
|
|
|
|
credential_entity = ToolTransformService.convert_builtin_provider_to_credential_entity(
|
|
|
|
|
|
|
|
provider=provider,
|
|
|
|
default_provider.is_default = True
|
|
|
|
credentials=decrypt_credential,
|
|
|
|
provider_controller = ToolManager.get_builtin_provider(default_provider.provider, tenant_id)
|
|
|
|
)
|
|
|
|
tool_configuration = BuiltinToolManageService._create_tool_configuration(tenant_id, provider_controller)
|
|
|
|
credentials.append(credential_entity)
|
|
|
|
credentials: list[ToolProviderCredentialApiEntity] = []
|
|
|
|
return credentials
|
|
|
|
for provider in providers:
|
|
|
|
|
|
|
|
decrypt_credential = tool_configuration.mask_tool_credentials(
|
|
|
|
|
|
|
|
tool_configuration.decrypt(provider.credentials)
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
credential_entity = ToolTransformService.convert_builtin_provider_to_credential_entity(
|
|
|
|
|
|
|
|
provider=provider,
|
|
|
|
|
|
|
|
credentials=decrypt_credential,
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
credentials.append(credential_entity)
|
|
|
|
|
|
|
|
return credentials
|
|
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
@staticmethod
|
|
|
|
def delete_builtin_tool_provider(tenant_id: str, provider_name: str, credential_id: str):
|
|
|
|
def delete_builtin_tool_provider(tenant_id: str, provider_name: str, credential_id: str):
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
delete tool provider
|
|
|
|
delete tool provider
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
tool_provider = BuiltinToolManageService._fetch_builtin_provider_by_id(tenant_id, credential_id)
|
|
|
|
tool_provider = BuiltinToolManageService.get_builtin_provider_by_id(tenant_id, credential_id)
|
|
|
|
|
|
|
|
|
|
|
|
if tool_provider is None:
|
|
|
|
if tool_provider is None:
|
|
|
|
raise ValueError(f"you have not added provider {provider_name}")
|
|
|
|
raise ValueError(f"you have not added provider {provider_name}")
|
|
|
|
@ -387,7 +412,6 @@ class BuiltinToolManageService:
|
|
|
|
ToolTransformService.convert_tool_entity_to_api_entity(
|
|
|
|
ToolTransformService.convert_tool_entity_to_api_entity(
|
|
|
|
tenant_id=tenant_id,
|
|
|
|
tenant_id=tenant_id,
|
|
|
|
tool=tool,
|
|
|
|
tool=tool,
|
|
|
|
credentials=user_builtin_provider.original_credentials,
|
|
|
|
|
|
|
|
labels=ToolLabelManager.get_tool_labels(provider_controller),
|
|
|
|
labels=ToolLabelManager.get_tool_labels(provider_controller),
|
|
|
|
)
|
|
|
|
)
|
|
|
|
)
|
|
|
|
)
|
|
|
|
@ -399,7 +423,7 @@ class BuiltinToolManageService:
|
|
|
|
return BuiltinToolProviderSort.sort(result)
|
|
|
|
return BuiltinToolProviderSort.sort(result)
|
|
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
@staticmethod
|
|
|
|
def _fetch_builtin_provider_by_id(tenant_id: str, credential_id: str) -> Optional[BuiltinToolProvider]:
|
|
|
|
def get_builtin_provider_by_id(tenant_id: str, credential_id: str) -> Optional[BuiltinToolProvider]:
|
|
|
|
provider: Optional[BuiltinToolProvider] = (
|
|
|
|
provider: Optional[BuiltinToolProvider] = (
|
|
|
|
db.session.query(BuiltinToolProvider)
|
|
|
|
db.session.query(BuiltinToolProvider)
|
|
|
|
.filter(
|
|
|
|
.filter(
|
|
|
|
@ -411,48 +435,63 @@ class BuiltinToolManageService:
|
|
|
|
return provider
|
|
|
|
return provider
|
|
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
@staticmethod
|
|
|
|
def _fetch_builtin_provider(provider_name: str, tenant_id: str) -> Optional[BuiltinToolProvider]:
|
|
|
|
def get_builtin_provider(provider_name: str, tenant_id: str) -> Optional[BuiltinToolProvider]:
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
This method is used to fetch the builtin provider from the database
|
|
|
|
This method is used to fetch the builtin provider from the database
|
|
|
|
1.if the default provider exists, return the default provider
|
|
|
|
1.if the default provider exists, return the default provider
|
|
|
|
2.if the default provider does not exist, return the oldest provider
|
|
|
|
2.if the default provider does not exist, return the oldest provider
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
|
|
|
|
with Session(db.engine) as session:
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
|
|
|
full_provider_name = provider_name
|
|
|
|
|
|
|
|
provider_id_entity = ToolProviderID(provider_name)
|
|
|
|
|
|
|
|
provider_name = provider_id_entity.provider_name
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if provider_id_entity.organization != "langgenius":
|
|
|
|
|
|
|
|
provider = (
|
|
|
|
|
|
|
|
session.query(BuiltinToolProvider)
|
|
|
|
|
|
|
|
.filter(
|
|
|
|
|
|
|
|
BuiltinToolProvider.tenant_id == tenant_id,
|
|
|
|
|
|
|
|
BuiltinToolProvider.provider == full_provider_name,
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
.order_by(
|
|
|
|
|
|
|
|
BuiltinToolProvider.is_default.desc(), # default=True first
|
|
|
|
|
|
|
|
BuiltinToolProvider.created_at.asc(), # oldest first
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
.first()
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
provider = (
|
|
|
|
|
|
|
|
session.query(BuiltinToolProvider)
|
|
|
|
|
|
|
|
.filter(
|
|
|
|
|
|
|
|
BuiltinToolProvider.tenant_id == tenant_id,
|
|
|
|
|
|
|
|
(BuiltinToolProvider.provider == provider_name)
|
|
|
|
|
|
|
|
| (BuiltinToolProvider.provider == full_provider_name),
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
.order_by(
|
|
|
|
|
|
|
|
BuiltinToolProvider.is_default.desc(), # default=True first
|
|
|
|
|
|
|
|
BuiltinToolProvider.created_at.asc(), # oldest first
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
.first()
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
def _query(provider_filters: list[ColumnExpressionArgument[bool]]) -> Optional[BuiltinToolProvider]:
|
|
|
|
if provider is None:
|
|
|
|
return (
|
|
|
|
return None
|
|
|
|
db.session.query(BuiltinToolProvider)
|
|
|
|
|
|
|
|
.filter(BuiltinToolProvider.tenant_id == tenant_id, *provider_filters)
|
|
|
|
provider.provider = ToolProviderID(provider.provider).to_string()
|
|
|
|
.order_by(
|
|
|
|
return provider
|
|
|
|
BuiltinToolProvider.is_default.desc(), # default=True first
|
|
|
|
except Exception:
|
|
|
|
BuiltinToolProvider.created_at.asc(), # oldest first
|
|
|
|
# it's an old provider without organization
|
|
|
|
)
|
|
|
|
return (
|
|
|
|
.first()
|
|
|
|
session.query(BuiltinToolProvider)
|
|
|
|
)
|
|
|
|
.filter(BuiltinToolProvider.tenant_id == tenant_id, BuiltinToolProvider.provider == provider_name)
|
|
|
|
|
|
|
|
.order_by(
|
|
|
|
try:
|
|
|
|
BuiltinToolProvider.is_default.desc(), # default=True first
|
|
|
|
full_provider_name = provider_name
|
|
|
|
BuiltinToolProvider.created_at.asc(), # oldest first
|
|
|
|
provider_id_entity = ToolProviderID(provider_name)
|
|
|
|
)
|
|
|
|
provider_name = provider_id_entity.provider_name
|
|
|
|
.first()
|
|
|
|
|
|
|
|
|
|
|
|
if provider_id_entity.organization != "langgenius":
|
|
|
|
|
|
|
|
provider = _query([BuiltinToolProvider.provider == full_provider_name])
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
provider = _query(
|
|
|
|
|
|
|
|
[
|
|
|
|
|
|
|
|
(BuiltinToolProvider.provider == provider_name)
|
|
|
|
|
|
|
|
| (BuiltinToolProvider.provider == full_provider_name)
|
|
|
|
|
|
|
|
]
|
|
|
|
|
|
|
|
)
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
if provider is None:
|
|
|
|
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
provider.provider = ToolProviderID(provider.provider).to_string()
|
|
|
|
|
|
|
|
return provider
|
|
|
|
|
|
|
|
except Exception:
|
|
|
|
|
|
|
|
# it's an old provider without organization
|
|
|
|
|
|
|
|
return _query([BuiltinToolProvider.provider == provider_name])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
@staticmethod
|
|
|
|
def _create_tool_configuration(tenant_id: str, provider_controller: ToolProviderController):
|
|
|
|
def _create_tool_configuration(tenant_id: str, provider_controller: ToolProviderController):
|
|
|
|
return ProviderConfigEncrypter(
|
|
|
|
return ProviderConfigEncrypter(
|
|
|
|
@ -463,7 +502,13 @@ class BuiltinToolManageService:
|
|
|
|
)
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
@staticmethod
|
|
|
|
def _encrypt_and_save_credentials(provider_controller, tool_configuration, provider, credentials, user_id):
|
|
|
|
def _encrypt_and_save_credentials(
|
|
|
|
|
|
|
|
provider_controller: BuiltinToolProviderController | PluginToolProviderController,
|
|
|
|
|
|
|
|
tool_configuration: ProviderConfigEncrypter,
|
|
|
|
|
|
|
|
provider: BuiltinToolProvider,
|
|
|
|
|
|
|
|
credentials: dict,
|
|
|
|
|
|
|
|
user_id: str,
|
|
|
|
|
|
|
|
):
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
Validate and encrypt credentials, then save to database
|
|
|
|
Validate and encrypt credentials, then save to database
|
|
|
|
|
|
|
|
|
|
|
|
@ -480,3 +525,25 @@ class BuiltinToolManageService:
|
|
|
|
encrypted_credentials = tool_configuration.encrypt(credentials)
|
|
|
|
encrypted_credentials = tool_configuration.encrypt(credentials)
|
|
|
|
provider.encrypted_credentials = json.dumps(encrypted_credentials)
|
|
|
|
provider.encrypted_credentials = json.dumps(encrypted_credentials)
|
|
|
|
tool_configuration.delete_tool_credentials_cache()
|
|
|
|
tool_configuration.delete_tool_credentials_cache()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
|
|
|
def setup_oauth_custom_client(tenant_id: str, user_id: str, provider: str, client_params: dict):
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
setup oauth custom client
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
|
|
|
|
|
|
|
|
if not provider_controller:
|
|
|
|
|
|
|
|
raise ToolProviderNotFoundError(f"Provider {provider} not found")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
tool_configuration = BuiltinToolManageService._create_tool_configuration(tenant_id, provider_controller)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Validate and encrypt credentials
|
|
|
|
|
|
|
|
BuiltinToolManageService._encrypt_and_save_credentials(
|
|
|
|
|
|
|
|
provider_controller=provider_controller,
|
|
|
|
|
|
|
|
tool_configuration=tool_configuration,
|
|
|
|
|
|
|
|
provider=None, # No need to save in DB
|
|
|
|
|
|
|
|
credentials=client_params,
|
|
|
|
|
|
|
|
user_id=user_id,
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return {"result": "success"}
|
|
|
|
|