feat(oauth): enhance OAuth client handling and add custom client support

pull/22036/head
Harry 11 months ago
parent 6ef1e017df
commit 988a76066d

@ -675,18 +675,17 @@ class ToolPluginOAuthApi(Resource):
raise Forbidden() raise Forbidden()
tenant_id = user.current_tenant_id tenant_id = user.current_tenant_id
plugin_oauth_config = BuiltinToolManageService.get_builtin_tool_oauth_client( oauth_client_params = BuiltinToolManageService.get_oauth_client(
tenant_id=tenant_id, tenant_id=tenant_id,
provider=provider_name, provider=provider
plugin_id=plugin_id,
) )
if oauth_client_params is None:
raise Forbidden("no oauth available client config found for this tool provider")
oauth_handler = OAuthHandler() oauth_handler = OAuthHandler()
context_id = OAuthProxyService.create_proxy_context( context_id = OAuthProxyService.create_proxy_context(
user_id=current_user.id, tenant_id=tenant_id, plugin_id=plugin_id, provider=provider_name user_id=current_user.id, tenant_id=tenant_id, plugin_id=plugin_id, provider=provider_name
) )
# TODO decrypt oauth params
oauth_params = plugin_oauth_config.oauth_params
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider}/tool/callback" redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider}/tool/callback"
authorization_url_response = oauth_handler.get_authorization_url( authorization_url_response = oauth_handler.get_authorization_url(
tenant_id=tenant_id, tenant_id=tenant_id,
@ -694,7 +693,7 @@ class ToolPluginOAuthApi(Resource):
plugin_id=plugin_id, plugin_id=plugin_id,
provider=provider_name, provider=provider_name,
redirect_uri=redirect_uri, redirect_uri=redirect_uri,
system_credentials=oauth_params, system_credentials=oauth_client_params,
) )
response = make_response(jsonable_encoder(authorization_url_response)) response = make_response(jsonable_encoder(authorization_url_response))
response.set_cookie( response.set_cookie(
@ -724,12 +723,10 @@ class ToolOAuthCallback(Resource):
user_id, tenant_id = context.get("user_id"), context.get("tenant_id") user_id, tenant_id = context.get("user_id"), context.get("tenant_id")
oauth_handler = OAuthHandler() oauth_handler = OAuthHandler()
plugin_oauth_config = BuiltinToolManageService.get_builtin_tool_oauth_client( oauth_client_params = BuiltinToolManageService.get_oauth_client(tenant_id, provider)
tenant_id=tenant_id, if oauth_client_params is None:
provider=provider_name, raise Forbidden("no oauth available client config found for this tool provider")
plugin_id=plugin_id,
)
oauth_params = plugin_oauth_config.oauth_params
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider}/tool/callback" redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider}/tool/callback"
credentials = oauth_handler.get_credentials( credentials = oauth_handler.get_credentials(
tenant_id=tenant_id, tenant_id=tenant_id,
@ -737,7 +734,7 @@ class ToolOAuthCallback(Resource):
plugin_id=plugin_id, plugin_id=plugin_id,
provider=provider_name, provider=provider_name,
redirect_uri=redirect_uri, redirect_uri=redirect_uri,
system_credentials=oauth_params, system_credentials=oauth_client_params,
request=request, request=request,
).credentials ).credentials
@ -774,7 +771,8 @@ class ToolOAuthCustomClient(Resource):
@account_initialization_required @account_initialization_required
def post(self, provider): def post(self, provider):
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("client_params", type=dict, required=True, nullable=False, location="json") parser.add_argument("client_params", type=dict, required=False, nullable=True, location="json")
parser.add_argument("enable_oauth_custom_client", type=bool, required=False, nullable=True, location="json")
args = parser.parse_args() args = parser.parse_args()
user = current_user user = current_user
@ -782,18 +780,21 @@ class ToolOAuthCustomClient(Resource):
if not user.is_admin_or_owner: if not user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
return BuiltinToolManageService.setup_oauth_custom_client( return BuiltinToolManageService.save_custom_oauth_client_params(
tenant_id=user.current_tenant_id, tenant_id=user.current_tenant_id,
provider=provider, provider=provider,
client_params=args["client_params"], client_params=args.get("client_params", {}),
enable_oauth_custom_client=args.get("enable_oauth_custom_client", True),
) )
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, provider): def get(self, provider):
return BuiltinToolManageService.get_builtin_tool_provider_credentials( return jsonable_encoder(
tenant_id=current_user.current_tenant_id, provider_name=provider BuiltinToolManageService.get_custom_oauth_client_params(
tenant_id=current_user.current_tenant_id, provider=provider
)
) )

@ -85,4 +85,7 @@ class ToolProviderCredentialApiEntity(BaseModel):
class ToolProviderCredentialInfoApiEntity(BaseModel): class ToolProviderCredentialInfoApiEntity(BaseModel):
supported_credential_types: list[str] = Field(description="The supported credential types of the provider") supported_credential_types: list[str] = Field(description="The supported credential types of the provider")
is_oauth_custom_client_enabled: bool = Field(
default=False, description="Whether the OAuth custom client is enabled for the provider"
)
credentials: list[ToolProviderCredentialApiEntity] = Field(description="The credentials of the provider") credentials: list[ToolProviderCredentialApiEntity] = Field(description="The credentials of the provider")

@ -2,7 +2,7 @@ import json
import logging import logging
import re import re
from pathlib import Path from pathlib import Path
from typing import Optional, Union from typing import Any, Optional
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@ -21,6 +21,7 @@ from core.tools.entities.api_entities import (
) )
from core.tools.entities.tool_entities import ToolProviderCredentialType from core.tools.entities.tool_entities import ToolProviderCredentialType
from core.tools.errors import ToolNotFoundError, ToolProviderCredentialValidationError, ToolProviderNotFoundError from core.tools.errors import ToolNotFoundError, ToolProviderCredentialValidationError, ToolProviderNotFoundError
from core.tools.plugin_tool.provider import PluginToolProviderController
from core.tools.tool_label_manager import ToolLabelManager from core.tools.tool_label_manager import ToolLabelManager
from core.tools.tool_manager import ToolManager from core.tools.tool_manager import ToolManager
from core.tools.utils.configuration import create_encrypter from core.tools.utils.configuration import create_encrypter
@ -41,7 +42,12 @@ class BuiltinToolManageService:
get builtin tool provider oauth client schema get builtin tool provider oauth client schema
""" """
provider = ToolManager.get_builtin_provider(provider_name, tenant_id) provider = ToolManager.get_builtin_provider(provider_name, tenant_id)
return provider.get_oauth_client_schema() return {
"schema": provider.get_oauth_client_schema(),
"is_oauth_custom_client_enabled": BuiltinToolManageService.is_oauth_custom_client_enabled(
tenant_id, provider_name
),
}
@staticmethod @staticmethod
def list_builtin_tool_provider_tools(tenant_id: str, provider: str) -> list[ToolApiEntity]: def list_builtin_tool_provider_tools(tenant_id: str, provider: str) -> list[ToolApiEntity]:
@ -279,20 +285,16 @@ class BuiltinToolManageService:
""" """
with db.session.no_autoflush: with db.session.no_autoflush:
providers = ( providers = (
db.session.query(BuiltinToolProvider).filter_by(tenant_id=tenant_id, provider=provider_name).all() db.session.query(BuiltinToolProvider)
.filter_by(tenant_id=tenant_id, provider=provider_name)
.order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc())
.all()
) )
if len(providers) == 0: if len(providers) == 0:
return [] return []
default_provider = sorted( default_provider = providers[0]
providers,
key=lambda p: (
not getattr(p, "is_default", False),
getattr(p, "created_at", None) or 0,
),
)[0]
default_provider.is_default = True default_provider.is_default = True
provider_controller = ToolManager.get_builtin_provider(default_provider.provider, tenant_id) provider_controller = ToolManager.get_builtin_provider(default_provider.provider, tenant_id)
encrypter, cache = BuiltinToolManageService.create_tool_encrypter( encrypter, cache = BuiltinToolManageService.create_tool_encrypter(
@ -319,6 +321,7 @@ class BuiltinToolManageService:
credentials = BuiltinToolManageService.get_builtin_tool_provider_credentials(tenant_id, provider) credentials = BuiltinToolManageService.get_builtin_tool_provider_credentials(tenant_id, provider)
credential_info = ToolProviderCredentialInfoApiEntity( credential_info = ToolProviderCredentialInfoApiEntity(
supported_credential_types=supported_credential_types, supported_credential_types=supported_credential_types,
is_oauth_custom_client_enabled=BuiltinToolManageService.is_oauth_custom_client_enabled(tenant_id, provider),
credentials=credentials, credentials=credentials,
) )
@ -368,30 +371,61 @@ class BuiltinToolManageService:
return {"result": "success"} return {"result": "success"}
@staticmethod @staticmethod
def get_builtin_tool_oauth_client( def is_oauth_custom_client_enabled(tenant_id: str, provider: str) -> bool:
tenant_id: str, provider: str, plugin_id: str """
) -> Union[ToolOAuthTenantClient, ToolOAuthSystemClient]: check if oauth custom client is enabled
"""
tool_provider = ToolProviderID(provider)
with Session(db.engine).no_autoflush as session:
user_client: ToolOAuthTenantClient | None = (
session.query(ToolOAuthTenantClient)
.filter_by(
tenant_id=tenant_id,
provider=tool_provider.provider_name,
plugin_id=tool_provider.plugin_id,
enabled=True,
)
.first()
)
return user_client is not None and user_client.enabled
@staticmethod
def get_oauth_client(tenant_id: str, provider: str) -> dict[str, Any] | None:
""" """
get builtin tool provider get builtin tool provider
""" """
with Session(db.engine) as session: tool_provider = ToolProviderID(provider)
user_client = ( provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
encrypter, _ = create_encrypter(
tenant_id=tenant_id,
config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()],
cache=NoOpProviderCredentialCache(),
)
with Session(db.engine).no_autoflush as session:
user_client: ToolOAuthTenantClient | None = (
session.query(ToolOAuthTenantClient) session.query(ToolOAuthTenantClient)
.filter_by( .filter_by(
tenant_id=tenant_id, tenant_id=tenant_id,
provider=provider, provider=tool_provider.provider_name,
plugin_id=plugin_id, plugin_id=tool_provider.plugin_id,
enabled=True, enabled=True,
) )
.first() .first()
) )
oauth_params: dict[str, Any] | None = None
if user_client: if user_client:
return user_client oauth_params = encrypter.decrypt(user_client.oauth_params)
return oauth_params
system_client: ToolOAuthSystemClient | None = (
session.query(ToolOAuthSystemClient)
.filter_by(plugin_id=tool_provider.plugin_id, provider=tool_provider.provider_name)
.first()
)
if system_client:
oauth_params = encrypter.decrypt(system_client.oauth_params)
system_client = session.query(ToolOAuthSystemClient).filter_by(provider=provider).first() return oauth_params
if system_client is None:
raise ValueError("no oauth available client config found for this tool provider")
return system_client
@staticmethod @staticmethod
def get_builtin_tool_provider_icon(provider: str): def get_builtin_tool_provider_icon(provider: str):
@ -533,12 +567,79 @@ class BuiltinToolManageService:
) )
@staticmethod @staticmethod
def setup_oauth_custom_client(tenant_id: str, provider: str, client_params: dict): def save_custom_oauth_client_params(
tenant_id: str,
provider: str,
client_params: Optional[dict] = None,
enable_oauth_custom_client: Optional[bool] = None,
):
""" """
setup oauth custom client setup oauth custom client
""" """
if client_params is None and enable_oauth_custom_client is None:
return {"result": "success"}
tool_provider = ToolProviderID(provider)
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
if not provider_controller:
raise ToolProviderNotFoundError(f"Provider {provider} not found")
if not isinstance(provider_controller, (BuiltinToolProviderController, PluginToolProviderController)):
raise ValueError(f"Provider {provider} is not a builtin or plugin provider")
with Session(db.engine) as session:
custom_client_params = (
session.query(ToolOAuthTenantClient)
.filter_by(
tenant_id=tenant_id,
plugin_id=tool_provider.plugin_id,
provider=tool_provider.provider_name,
)
.first()
)
# if the record does not exist, create a basic record
if custom_client_params is None:
custom_client_params = ToolOAuthTenantClient(
tenant_id=tenant_id,
plugin_id=tool_provider.plugin_id,
provider=tool_provider.provider_name,
)
session.add(custom_client_params)
if client_params is not None:
encrypter, _ = create_encrypter(
tenant_id=tenant_id,
config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()],
cache=NoOpProviderCredentialCache(),
)
custom_client_params.encrypted_oauth_params = json.dumps(encrypter.encrypt(client_params))
if enable_oauth_custom_client is not None:
custom_client_params.enabled = enable_oauth_custom_client
session.commit()
return {"result": "success"}
@staticmethod
def get_custom_oauth_client_params(tenant_id: str, provider: str):
"""
get custom oauth client params
"""
with Session(db.engine) as session: with Session(db.engine) as session:
tool_provider = ToolProviderID(provider) tool_provider = ToolProviderID(provider)
custom_oauth_client_params: ToolOAuthTenantClient | None = (
session.query(ToolOAuthTenantClient)
.filter_by(
tenant_id=tenant_id,
plugin_id=tool_provider.plugin_id,
provider=tool_provider.provider_name,
)
.first()
)
if custom_oauth_client_params is None:
return {}
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id) provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
if not provider_controller: if not provider_controller:
raise ToolProviderNotFoundError(f"Provider {provider} not found") raise ToolProviderNotFoundError(f"Provider {provider} not found")
@ -551,17 +652,4 @@ class BuiltinToolManageService:
config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()], config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()],
cache=NoOpProviderCredentialCache(), cache=NoOpProviderCredentialCache(),
) )
return encrypter.mask_tool_credentials(encrypter.decrypt(custom_oauth_client_params.oauth_params))
# encrypt credentials
encrypted_credentials = encrypter.encrypt(client_params)
session.add(
ToolOAuthTenantClient(
tenant_id=tenant_id,
plugin_id=tool_provider.plugin_id,
provider=tool_provider.provider_name,
enabled=True,
encrypted_oauth_params=json.dumps(encrypted_credentials),
)
)
session.commit()
return {"result": "success"}

Loading…
Cancel
Save