feat(oauth): enhance OAuth client handling and add custom client support

pull/22036/head
Harry 7 months ago
parent 6ef1e017df
commit 988a76066d

@ -675,18 +675,17 @@ class ToolPluginOAuthApi(Resource):
raise Forbidden()
tenant_id = user.current_tenant_id
plugin_oauth_config = BuiltinToolManageService.get_builtin_tool_oauth_client(
oauth_client_params = BuiltinToolManageService.get_oauth_client(
tenant_id=tenant_id,
provider=provider_name,
plugin_id=plugin_id,
provider=provider
)
if oauth_client_params is None:
raise Forbidden("no oauth available client config found for this tool provider")
oauth_handler = OAuthHandler()
context_id = OAuthProxyService.create_proxy_context(
user_id=current_user.id, tenant_id=tenant_id, plugin_id=plugin_id, provider=provider_name
)
# TODO decrypt oauth params
oauth_params = plugin_oauth_config.oauth_params
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider}/tool/callback"
authorization_url_response = oauth_handler.get_authorization_url(
tenant_id=tenant_id,
@ -694,7 +693,7 @@ class ToolPluginOAuthApi(Resource):
plugin_id=plugin_id,
provider=provider_name,
redirect_uri=redirect_uri,
system_credentials=oauth_params,
system_credentials=oauth_client_params,
)
response = make_response(jsonable_encoder(authorization_url_response))
response.set_cookie(
@ -724,12 +723,10 @@ class ToolOAuthCallback(Resource):
user_id, tenant_id = context.get("user_id"), context.get("tenant_id")
oauth_handler = OAuthHandler()
plugin_oauth_config = BuiltinToolManageService.get_builtin_tool_oauth_client(
tenant_id=tenant_id,
provider=provider_name,
plugin_id=plugin_id,
)
oauth_params = plugin_oauth_config.oauth_params
oauth_client_params = BuiltinToolManageService.get_oauth_client(tenant_id, provider)
if oauth_client_params is None:
raise Forbidden("no oauth available client config found for this tool provider")
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider}/tool/callback"
credentials = oauth_handler.get_credentials(
tenant_id=tenant_id,
@ -737,7 +734,7 @@ class ToolOAuthCallback(Resource):
plugin_id=plugin_id,
provider=provider_name,
redirect_uri=redirect_uri,
system_credentials=oauth_params,
system_credentials=oauth_client_params,
request=request,
).credentials
@ -774,7 +771,8 @@ class ToolOAuthCustomClient(Resource):
@account_initialization_required
def post(self, provider):
parser = reqparse.RequestParser()
parser.add_argument("client_params", type=dict, required=True, nullable=False, location="json")
parser.add_argument("client_params", type=dict, required=False, nullable=True, location="json")
parser.add_argument("enable_oauth_custom_client", type=bool, required=False, nullable=True, location="json")
args = parser.parse_args()
user = current_user
@ -782,18 +780,21 @@ class ToolOAuthCustomClient(Resource):
if not user.is_admin_or_owner:
raise Forbidden()
return BuiltinToolManageService.setup_oauth_custom_client(
return BuiltinToolManageService.save_custom_oauth_client_params(
tenant_id=user.current_tenant_id,
provider=provider,
client_params=args["client_params"],
client_params=args.get("client_params", {}),
enable_oauth_custom_client=args.get("enable_oauth_custom_client", True),
)
@setup_required
@login_required
@account_initialization_required
def get(self, provider):
return BuiltinToolManageService.get_builtin_tool_provider_credentials(
tenant_id=current_user.current_tenant_id, provider_name=provider
return jsonable_encoder(
BuiltinToolManageService.get_custom_oauth_client_params(
tenant_id=current_user.current_tenant_id, provider=provider
)
)

@ -85,4 +85,7 @@ class ToolProviderCredentialApiEntity(BaseModel):
class ToolProviderCredentialInfoApiEntity(BaseModel):
supported_credential_types: list[str] = Field(description="The supported credential types of the provider")
credentials: list[ToolProviderCredentialApiEntity] = Field(description="The credentials of the provider")
is_oauth_custom_client_enabled: bool = Field(
default=False, description="Whether the OAuth custom client is enabled for the provider"
)
credentials: list[ToolProviderCredentialApiEntity] = Field(description="The credentials of the provider")

@ -2,7 +2,7 @@ import json
import logging
import re
from pathlib import Path
from typing import Optional, Union
from typing import Any, Optional
from sqlalchemy.orm import Session
@ -21,6 +21,7 @@ from core.tools.entities.api_entities import (
)
from core.tools.entities.tool_entities import ToolProviderCredentialType
from core.tools.errors import ToolNotFoundError, ToolProviderCredentialValidationError, ToolProviderNotFoundError
from core.tools.plugin_tool.provider import PluginToolProviderController
from core.tools.tool_label_manager import ToolLabelManager
from core.tools.tool_manager import ToolManager
from core.tools.utils.configuration import create_encrypter
@ -41,7 +42,12 @@ class BuiltinToolManageService:
get builtin tool provider oauth client schema
"""
provider = ToolManager.get_builtin_provider(provider_name, tenant_id)
return provider.get_oauth_client_schema()
return {
"schema": provider.get_oauth_client_schema(),
"is_oauth_custom_client_enabled": BuiltinToolManageService.is_oauth_custom_client_enabled(
tenant_id, provider_name
),
}
@staticmethod
def list_builtin_tool_provider_tools(tenant_id: str, provider: str) -> list[ToolApiEntity]:
@ -139,7 +145,7 @@ class BuiltinToolManageService:
# encrypt credentials
db_provider.encrypted_credentials = json.dumps(encrypter.encrypt(credentials))
cache.delete()
# update name if provided
@ -279,20 +285,16 @@ class BuiltinToolManageService:
"""
with db.session.no_autoflush:
providers = (
db.session.query(BuiltinToolProvider).filter_by(tenant_id=tenant_id, provider=provider_name).all()
db.session.query(BuiltinToolProvider)
.filter_by(tenant_id=tenant_id, provider=provider_name)
.order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc())
.all()
)
if len(providers) == 0:
return []
default_provider = sorted(
providers,
key=lambda p: (
not getattr(p, "is_default", False),
getattr(p, "created_at", None) or 0,
),
)[0]
default_provider = providers[0]
default_provider.is_default = True
provider_controller = ToolManager.get_builtin_provider(default_provider.provider, tenant_id)
encrypter, cache = BuiltinToolManageService.create_tool_encrypter(
@ -319,6 +321,7 @@ class BuiltinToolManageService:
credentials = BuiltinToolManageService.get_builtin_tool_provider_credentials(tenant_id, provider)
credential_info = ToolProviderCredentialInfoApiEntity(
supported_credential_types=supported_credential_types,
is_oauth_custom_client_enabled=BuiltinToolManageService.is_oauth_custom_client_enabled(tenant_id, provider),
credentials=credentials,
)
@ -368,30 +371,61 @@ class BuiltinToolManageService:
return {"result": "success"}
@staticmethod
def get_builtin_tool_oauth_client(
tenant_id: str, provider: str, plugin_id: str
) -> Union[ToolOAuthTenantClient, ToolOAuthSystemClient]:
def is_oauth_custom_client_enabled(tenant_id: str, provider: str) -> bool:
"""
check if oauth custom client is enabled
"""
tool_provider = ToolProviderID(provider)
with Session(db.engine).no_autoflush as session:
user_client: ToolOAuthTenantClient | None = (
session.query(ToolOAuthTenantClient)
.filter_by(
tenant_id=tenant_id,
provider=tool_provider.provider_name,
plugin_id=tool_provider.plugin_id,
enabled=True,
)
.first()
)
return user_client is not None and user_client.enabled
@staticmethod
def get_oauth_client(tenant_id: str, provider: str) -> dict[str, Any] | None:
"""
get builtin tool provider
"""
with Session(db.engine) as session:
user_client = (
tool_provider = ToolProviderID(provider)
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
encrypter, _ = create_encrypter(
tenant_id=tenant_id,
config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()],
cache=NoOpProviderCredentialCache(),
)
with Session(db.engine).no_autoflush as session:
user_client: ToolOAuthTenantClient | None = (
session.query(ToolOAuthTenantClient)
.filter_by(
tenant_id=tenant_id,
provider=provider,
plugin_id=plugin_id,
provider=tool_provider.provider_name,
plugin_id=tool_provider.plugin_id,
enabled=True,
)
.first()
)
oauth_params: dict[str, Any] | None = None
if user_client:
return user_client
oauth_params = encrypter.decrypt(user_client.oauth_params)
return oauth_params
system_client: ToolOAuthSystemClient | None = (
session.query(ToolOAuthSystemClient)
.filter_by(plugin_id=tool_provider.plugin_id, provider=tool_provider.provider_name)
.first()
)
if system_client:
oauth_params = encrypter.decrypt(system_client.oauth_params)
system_client = session.query(ToolOAuthSystemClient).filter_by(provider=provider).first()
if system_client is None:
raise ValueError("no oauth available client config found for this tool provider")
return system_client
return oauth_params
@staticmethod
def get_builtin_tool_provider_icon(provider: str):
@ -533,12 +567,79 @@ class BuiltinToolManageService:
)
@staticmethod
def setup_oauth_custom_client(tenant_id: str, provider: str, client_params: dict):
def save_custom_oauth_client_params(
tenant_id: str,
provider: str,
client_params: Optional[dict] = None,
enable_oauth_custom_client: Optional[bool] = None,
):
"""
setup oauth custom client
"""
if client_params is None and enable_oauth_custom_client is None:
return {"result": "success"}
tool_provider = ToolProviderID(provider)
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
if not provider_controller:
raise ToolProviderNotFoundError(f"Provider {provider} not found")
if not isinstance(provider_controller, (BuiltinToolProviderController, PluginToolProviderController)):
raise ValueError(f"Provider {provider} is not a builtin or plugin provider")
with Session(db.engine) as session:
custom_client_params = (
session.query(ToolOAuthTenantClient)
.filter_by(
tenant_id=tenant_id,
plugin_id=tool_provider.plugin_id,
provider=tool_provider.provider_name,
)
.first()
)
# if the record does not exist, create a basic record
if custom_client_params is None:
custom_client_params = ToolOAuthTenantClient(
tenant_id=tenant_id,
plugin_id=tool_provider.plugin_id,
provider=tool_provider.provider_name,
)
session.add(custom_client_params)
if client_params is not None:
encrypter, _ = create_encrypter(
tenant_id=tenant_id,
config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()],
cache=NoOpProviderCredentialCache(),
)
custom_client_params.encrypted_oauth_params = json.dumps(encrypter.encrypt(client_params))
if enable_oauth_custom_client is not None:
custom_client_params.enabled = enable_oauth_custom_client
session.commit()
return {"result": "success"}
@staticmethod
def get_custom_oauth_client_params(tenant_id: str, provider: str):
"""
get custom oauth client params
"""
with Session(db.engine) as session:
tool_provider = ToolProviderID(provider)
custom_oauth_client_params: ToolOAuthTenantClient | None = (
session.query(ToolOAuthTenantClient)
.filter_by(
tenant_id=tenant_id,
plugin_id=tool_provider.plugin_id,
provider=tool_provider.provider_name,
)
.first()
)
if custom_oauth_client_params is None:
return {}
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
if not provider_controller:
raise ToolProviderNotFoundError(f"Provider {provider} not found")
@ -551,17 +652,4 @@ class BuiltinToolManageService:
config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()],
cache=NoOpProviderCredentialCache(),
)
# encrypt credentials
encrypted_credentials = encrypter.encrypt(client_params)
session.add(
ToolOAuthTenantClient(
tenant_id=tenant_id,
plugin_id=tool_provider.plugin_id,
provider=tool_provider.provider_name,
enabled=True,
encrypted_oauth_params=json.dumps(encrypted_credentials),
)
)
session.commit()
return {"result": "success"}
return encrypter.mask_tool_credentials(encrypter.decrypt(custom_oauth_client_params.oauth_params))

Loading…
Cancel
Save