|
|
|
|
@ -2,7 +2,7 @@ import json
|
|
|
|
|
import logging
|
|
|
|
|
import re
|
|
|
|
|
from pathlib import Path
|
|
|
|
|
from typing import Optional, Union
|
|
|
|
|
from typing import Any, Optional
|
|
|
|
|
|
|
|
|
|
from sqlalchemy.orm import Session
|
|
|
|
|
|
|
|
|
|
@ -21,6 +21,7 @@ from core.tools.entities.api_entities import (
|
|
|
|
|
)
|
|
|
|
|
from core.tools.entities.tool_entities import ToolProviderCredentialType
|
|
|
|
|
from core.tools.errors import ToolNotFoundError, ToolProviderCredentialValidationError, ToolProviderNotFoundError
|
|
|
|
|
from core.tools.plugin_tool.provider import PluginToolProviderController
|
|
|
|
|
from core.tools.tool_label_manager import ToolLabelManager
|
|
|
|
|
from core.tools.tool_manager import ToolManager
|
|
|
|
|
from core.tools.utils.configuration import create_encrypter
|
|
|
|
|
@ -41,7 +42,12 @@ class BuiltinToolManageService:
|
|
|
|
|
get builtin tool provider oauth client schema
|
|
|
|
|
"""
|
|
|
|
|
provider = ToolManager.get_builtin_provider(provider_name, tenant_id)
|
|
|
|
|
return provider.get_oauth_client_schema()
|
|
|
|
|
return {
|
|
|
|
|
"schema": provider.get_oauth_client_schema(),
|
|
|
|
|
"is_oauth_custom_client_enabled": BuiltinToolManageService.is_oauth_custom_client_enabled(
|
|
|
|
|
tenant_id, provider_name
|
|
|
|
|
),
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def list_builtin_tool_provider_tools(tenant_id: str, provider: str) -> list[ToolApiEntity]:
|
|
|
|
|
@ -139,7 +145,7 @@ class BuiltinToolManageService:
|
|
|
|
|
|
|
|
|
|
# encrypt credentials
|
|
|
|
|
db_provider.encrypted_credentials = json.dumps(encrypter.encrypt(credentials))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
cache.delete()
|
|
|
|
|
|
|
|
|
|
# update name if provided
|
|
|
|
|
@ -279,20 +285,16 @@ class BuiltinToolManageService:
|
|
|
|
|
"""
|
|
|
|
|
with db.session.no_autoflush:
|
|
|
|
|
providers = (
|
|
|
|
|
db.session.query(BuiltinToolProvider).filter_by(tenant_id=tenant_id, provider=provider_name).all()
|
|
|
|
|
db.session.query(BuiltinToolProvider)
|
|
|
|
|
.filter_by(tenant_id=tenant_id, provider=provider_name)
|
|
|
|
|
.order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc())
|
|
|
|
|
.all()
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
if len(providers) == 0:
|
|
|
|
|
return []
|
|
|
|
|
|
|
|
|
|
default_provider = sorted(
|
|
|
|
|
providers,
|
|
|
|
|
key=lambda p: (
|
|
|
|
|
not getattr(p, "is_default", False),
|
|
|
|
|
getattr(p, "created_at", None) or 0,
|
|
|
|
|
),
|
|
|
|
|
)[0]
|
|
|
|
|
|
|
|
|
|
default_provider = providers[0]
|
|
|
|
|
default_provider.is_default = True
|
|
|
|
|
provider_controller = ToolManager.get_builtin_provider(default_provider.provider, tenant_id)
|
|
|
|
|
encrypter, cache = BuiltinToolManageService.create_tool_encrypter(
|
|
|
|
|
@ -319,6 +321,7 @@ class BuiltinToolManageService:
|
|
|
|
|
credentials = BuiltinToolManageService.get_builtin_tool_provider_credentials(tenant_id, provider)
|
|
|
|
|
credential_info = ToolProviderCredentialInfoApiEntity(
|
|
|
|
|
supported_credential_types=supported_credential_types,
|
|
|
|
|
is_oauth_custom_client_enabled=BuiltinToolManageService.is_oauth_custom_client_enabled(tenant_id, provider),
|
|
|
|
|
credentials=credentials,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
@ -368,30 +371,61 @@ class BuiltinToolManageService:
|
|
|
|
|
return {"result": "success"}
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def get_builtin_tool_oauth_client(
|
|
|
|
|
tenant_id: str, provider: str, plugin_id: str
|
|
|
|
|
) -> Union[ToolOAuthTenantClient, ToolOAuthSystemClient]:
|
|
|
|
|
def is_oauth_custom_client_enabled(tenant_id: str, provider: str) -> bool:
|
|
|
|
|
"""
|
|
|
|
|
check if oauth custom client is enabled
|
|
|
|
|
"""
|
|
|
|
|
tool_provider = ToolProviderID(provider)
|
|
|
|
|
with Session(db.engine).no_autoflush as session:
|
|
|
|
|
user_client: ToolOAuthTenantClient | None = (
|
|
|
|
|
session.query(ToolOAuthTenantClient)
|
|
|
|
|
.filter_by(
|
|
|
|
|
tenant_id=tenant_id,
|
|
|
|
|
provider=tool_provider.provider_name,
|
|
|
|
|
plugin_id=tool_provider.plugin_id,
|
|
|
|
|
enabled=True,
|
|
|
|
|
)
|
|
|
|
|
.first()
|
|
|
|
|
)
|
|
|
|
|
return user_client is not None and user_client.enabled
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def get_oauth_client(tenant_id: str, provider: str) -> dict[str, Any] | None:
|
|
|
|
|
"""
|
|
|
|
|
get builtin tool provider
|
|
|
|
|
"""
|
|
|
|
|
with Session(db.engine) as session:
|
|
|
|
|
user_client = (
|
|
|
|
|
tool_provider = ToolProviderID(provider)
|
|
|
|
|
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
|
|
|
|
|
encrypter, _ = create_encrypter(
|
|
|
|
|
tenant_id=tenant_id,
|
|
|
|
|
config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()],
|
|
|
|
|
cache=NoOpProviderCredentialCache(),
|
|
|
|
|
)
|
|
|
|
|
with Session(db.engine).no_autoflush as session:
|
|
|
|
|
user_client: ToolOAuthTenantClient | None = (
|
|
|
|
|
session.query(ToolOAuthTenantClient)
|
|
|
|
|
.filter_by(
|
|
|
|
|
tenant_id=tenant_id,
|
|
|
|
|
provider=provider,
|
|
|
|
|
plugin_id=plugin_id,
|
|
|
|
|
provider=tool_provider.provider_name,
|
|
|
|
|
plugin_id=tool_provider.plugin_id,
|
|
|
|
|
enabled=True,
|
|
|
|
|
)
|
|
|
|
|
.first()
|
|
|
|
|
)
|
|
|
|
|
oauth_params: dict[str, Any] | None = None
|
|
|
|
|
if user_client:
|
|
|
|
|
return user_client
|
|
|
|
|
oauth_params = encrypter.decrypt(user_client.oauth_params)
|
|
|
|
|
return oauth_params
|
|
|
|
|
|
|
|
|
|
system_client: ToolOAuthSystemClient | None = (
|
|
|
|
|
session.query(ToolOAuthSystemClient)
|
|
|
|
|
.filter_by(plugin_id=tool_provider.plugin_id, provider=tool_provider.provider_name)
|
|
|
|
|
.first()
|
|
|
|
|
)
|
|
|
|
|
if system_client:
|
|
|
|
|
oauth_params = encrypter.decrypt(system_client.oauth_params)
|
|
|
|
|
|
|
|
|
|
system_client = session.query(ToolOAuthSystemClient).filter_by(provider=provider).first()
|
|
|
|
|
if system_client is None:
|
|
|
|
|
raise ValueError("no oauth available client config found for this tool provider")
|
|
|
|
|
return system_client
|
|
|
|
|
return oauth_params
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def get_builtin_tool_provider_icon(provider: str):
|
|
|
|
|
@ -533,12 +567,79 @@ class BuiltinToolManageService:
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def setup_oauth_custom_client(tenant_id: str, provider: str, client_params: dict):
|
|
|
|
|
def save_custom_oauth_client_params(
|
|
|
|
|
tenant_id: str,
|
|
|
|
|
provider: str,
|
|
|
|
|
client_params: Optional[dict] = None,
|
|
|
|
|
enable_oauth_custom_client: Optional[bool] = None,
|
|
|
|
|
):
|
|
|
|
|
"""
|
|
|
|
|
setup oauth custom client
|
|
|
|
|
"""
|
|
|
|
|
if client_params is None and enable_oauth_custom_client is None:
|
|
|
|
|
return {"result": "success"}
|
|
|
|
|
|
|
|
|
|
tool_provider = ToolProviderID(provider)
|
|
|
|
|
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
|
|
|
|
|
if not provider_controller:
|
|
|
|
|
raise ToolProviderNotFoundError(f"Provider {provider} not found")
|
|
|
|
|
|
|
|
|
|
if not isinstance(provider_controller, (BuiltinToolProviderController, PluginToolProviderController)):
|
|
|
|
|
raise ValueError(f"Provider {provider} is not a builtin or plugin provider")
|
|
|
|
|
|
|
|
|
|
with Session(db.engine) as session:
|
|
|
|
|
custom_client_params = (
|
|
|
|
|
session.query(ToolOAuthTenantClient)
|
|
|
|
|
.filter_by(
|
|
|
|
|
tenant_id=tenant_id,
|
|
|
|
|
plugin_id=tool_provider.plugin_id,
|
|
|
|
|
provider=tool_provider.provider_name,
|
|
|
|
|
)
|
|
|
|
|
.first()
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# if the record does not exist, create a basic record
|
|
|
|
|
if custom_client_params is None:
|
|
|
|
|
custom_client_params = ToolOAuthTenantClient(
|
|
|
|
|
tenant_id=tenant_id,
|
|
|
|
|
plugin_id=tool_provider.plugin_id,
|
|
|
|
|
provider=tool_provider.provider_name,
|
|
|
|
|
)
|
|
|
|
|
session.add(custom_client_params)
|
|
|
|
|
|
|
|
|
|
if client_params is not None:
|
|
|
|
|
encrypter, _ = create_encrypter(
|
|
|
|
|
tenant_id=tenant_id,
|
|
|
|
|
config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()],
|
|
|
|
|
cache=NoOpProviderCredentialCache(),
|
|
|
|
|
)
|
|
|
|
|
custom_client_params.encrypted_oauth_params = json.dumps(encrypter.encrypt(client_params))
|
|
|
|
|
|
|
|
|
|
if enable_oauth_custom_client is not None:
|
|
|
|
|
custom_client_params.enabled = enable_oauth_custom_client
|
|
|
|
|
|
|
|
|
|
session.commit()
|
|
|
|
|
return {"result": "success"}
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def get_custom_oauth_client_params(tenant_id: str, provider: str):
|
|
|
|
|
"""
|
|
|
|
|
get custom oauth client params
|
|
|
|
|
"""
|
|
|
|
|
with Session(db.engine) as session:
|
|
|
|
|
tool_provider = ToolProviderID(provider)
|
|
|
|
|
custom_oauth_client_params: ToolOAuthTenantClient | None = (
|
|
|
|
|
session.query(ToolOAuthTenantClient)
|
|
|
|
|
.filter_by(
|
|
|
|
|
tenant_id=tenant_id,
|
|
|
|
|
plugin_id=tool_provider.plugin_id,
|
|
|
|
|
provider=tool_provider.provider_name,
|
|
|
|
|
)
|
|
|
|
|
.first()
|
|
|
|
|
)
|
|
|
|
|
if custom_oauth_client_params is None:
|
|
|
|
|
return {}
|
|
|
|
|
|
|
|
|
|
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
|
|
|
|
|
if not provider_controller:
|
|
|
|
|
raise ToolProviderNotFoundError(f"Provider {provider} not found")
|
|
|
|
|
@ -551,17 +652,4 @@ class BuiltinToolManageService:
|
|
|
|
|
config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()],
|
|
|
|
|
cache=NoOpProviderCredentialCache(),
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# encrypt credentials
|
|
|
|
|
encrypted_credentials = encrypter.encrypt(client_params)
|
|
|
|
|
session.add(
|
|
|
|
|
ToolOAuthTenantClient(
|
|
|
|
|
tenant_id=tenant_id,
|
|
|
|
|
plugin_id=tool_provider.plugin_id,
|
|
|
|
|
provider=tool_provider.provider_name,
|
|
|
|
|
enabled=True,
|
|
|
|
|
encrypted_oauth_params=json.dumps(encrypted_credentials),
|
|
|
|
|
)
|
|
|
|
|
)
|
|
|
|
|
session.commit()
|
|
|
|
|
return {"result": "success"}
|
|
|
|
|
return encrypter.mask_tool_credentials(encrypter.decrypt(custom_oauth_client_params.oauth_params))
|
|
|
|
|
|