feat: refactor OAuth provider handling and improve provider name generation

feat/rag-2
Harry 10 months ago
parent 9f2a9ad271
commit 0ac5c0bf3e

@ -1,6 +1,5 @@
import random from fastapi.encoders import jsonable_encoder
from flask import make_response, redirect, request
from flask import redirect, request
from flask_login import current_user # type: ignore from flask_login import current_user # type: ignore
from flask_restful import ( # type: ignore from flask_restful import ( # type: ignore
Resource, # type: ignore Resource, # type: ignore
@ -15,76 +14,101 @@ from controllers.console.wraps import (
setup_required, setup_required,
) )
from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.plugin.entities.plugin import DatasourceProviderID
from core.plugin.impl.oauth import OAuthHandler from core.plugin.impl.oauth import OAuthHandler
from extensions.ext_database import db from extensions.ext_database import db
from libs.login import login_required from libs.login import login_required
from models.oauth import DatasourceOauthParamConfig, DatasourceProvider from models.oauth import DatasourceOauthParamConfig
from services.datasource_provider_service import DatasourceProviderService from services.datasource_provider_service import DatasourceProviderService
from services.plugin.oauth_service import OAuthProxyService
class DatasourcePluginOauthApi(Resource): class DatasourcePluginOAuthAuthorizationUrl(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self): def get(self, provider: str):
parser = reqparse.RequestParser() user = current_user
parser.add_argument("provider", type=str, required=True, nullable=False, location="args") tenant_id = user.current_tenant_id
parser.add_argument("plugin_id", type=str, required=True, nullable=False, location="args")
args = parser.parse_args()
provider = args["provider"]
plugin_id = args["plugin_id"]
# Check user role first
if not current_user.is_editor: if not current_user.is_editor:
raise Forbidden() raise Forbidden()
# get all plugin oauth configs
plugin_oauth_config = ( provider_id = DatasourceProviderID(provider)
db.session.query(DatasourceOauthParamConfig).filter_by(provider=provider, plugin_id=plugin_id).first() provider_name = provider_id.provider_name
plugin_id = provider_id.plugin_id
oauth_config = (
db.session.query(DatasourceOauthParamConfig).filter_by(provider=provider_name, plugin_id=plugin_id).first()
)
if not oauth_config:
raise ValueError(f"No OAuth Client Config for {provider}")
context_id = OAuthProxyService.create_proxy_context(
user_id=current_user.id, tenant_id=tenant_id, plugin_id=plugin_id, provider=provider_name
) )
if not plugin_oauth_config:
raise NotFound()
oauth_handler = OAuthHandler() oauth_handler = OAuthHandler()
redirect_url = ( redirect_uri = f"{dify_config.CONSOLE_WEB_URL}/console/api/oauth/plugin/{provider}/datasource/callback"
f"{dify_config.CONSOLE_WEB_URL}/oauth/datasource/callback?provider={provider}&plugin_id={plugin_id}" oauth_client_params = oauth_config.system_credentials
authorization_url_response = oauth_handler.get_authorization_url(
tenant_id=tenant_id,
user_id=user.id,
plugin_id=plugin_id,
provider=provider_name,
redirect_uri=redirect_uri,
system_credentials=oauth_client_params,
) )
system_credentials = plugin_oauth_config.system_credentials response = make_response(jsonable_encoder(authorization_url_response))
if system_credentials: response.set_cookie(
system_credentials["redirect_url"] = redirect_url "context_id",
response = oauth_handler.get_authorization_url( context_id,
current_user.current_tenant.id, current_user.id, plugin_id, provider, system_credentials=system_credentials httponly=True,
samesite="Lax",
max_age=OAuthProxyService.__MAX_AGE__,
) )
return response.model_dump() return response
class DatasourceOauthCallback(Resource): class DatasourceOAuthCallback(Resource):
@setup_required @setup_required
@login_required def get(self, provider: str):
@account_initialization_required if not current_user.is_editor:
def get(self): raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("provider", type=str, required=True, nullable=False, location="args") context_id = request.cookies.get("context_id")
parser.add_argument("plugin_id", type=str, required=True, nullable=False, location="args") if not context_id:
args = parser.parse_args() raise Forbidden("context_id not found")
provider = args["provider"]
plugin_id = args["plugin_id"] context = OAuthProxyService.use_proxy_context(context_id)
oauth_handler = OAuthHandler() if context is None:
raise Forbidden("Invalid context_id")
user_id, tenant_id = context.get("user_id"), context.get("tenant_id")
provider_id = DatasourceProviderID(provider)
provider_name = provider_id.provider_name
plugin_id = provider_id.plugin_id
plugin_oauth_config = ( plugin_oauth_config = (
db.session.query(DatasourceOauthParamConfig).filter_by(provider=provider, plugin_id=plugin_id).first() db.session.query(DatasourceOauthParamConfig).filter_by(provider=provider_name, plugin_id=plugin_id).first()
) )
if not plugin_oauth_config: if not plugin_oauth_config:
raise NotFound() raise NotFound()
credentials = oauth_handler.get_credentials( redirect_uri = f"{dify_config.CONSOLE_WEB_URL}/console/api/oauth/plugin/{provider}/datasource/callback"
current_user.current_tenant.id, oauth_handler = OAuthHandler()
current_user.id, oauth_response = oauth_handler.get_credentials(
plugin_id, tenant_id=tenant_id,
provider, user_id=user_id,
plugin_id=plugin_id,
provider=provider_id.provider_name,
redirect_uri=redirect_uri,
system_credentials=plugin_oauth_config.system_credentials, system_credentials=plugin_oauth_config.system_credentials,
request=request, request=request,
) )
datasource_provider = DatasourceProvider( datasource_provider_service = DatasourceProviderService()
plugin_id=plugin_id, provider=provider, auth_type="oauth", encrypted_credentials=credentials datasource_provider_service.add_datasource_oauth_provider(
tenant_id=tenant_id,
provider_id=provider_id,
credentials=dict(oauth_response.credentials),
name=None,
) )
db.session.add(datasource_provider)
db.session.commit()
return redirect(f"{dify_config.CONSOLE_WEB_URL}") return redirect(f"{dify_config.CONSOLE_WEB_URL}")
@ -92,26 +116,23 @@ class DatasourceAuth(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self): def post(self, provider: str):
if not current_user.is_editor: if not current_user.is_editor:
raise Forbidden() raise Forbidden()
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("provider", type=str, required=True, nullable=False, location="json") parser.add_argument("name", type=str, required=False, nullable=True, location="json", default=None)
parser.add_argument("name", type=str, required=False, nullable=False, location="json", default="test")
parser.add_argument("plugin_id", type=str, required=True, nullable=False, location="json")
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
args = parser.parse_args() args = parser.parse_args()
provider_id = DatasourceProviderID(provider)
datasource_provider_service = DatasourceProviderService() datasource_provider_service = DatasourceProviderService()
try: try:
datasource_provider_service.datasource_provider_credentials_validate( datasource_provider_service.add_datasource_api_key_provider(
tenant_id=current_user.current_tenant_id, tenant_id=current_user.current_tenant_id,
provider=args["provider"], provider_id=provider_id,
plugin_id=args["plugin_id"],
credentials=args["credentials"], credentials=args["credentials"],
name="test" + str(random.randint(1, 1000000)), # noqa: S311 name=args["name"],
) )
except CredentialsValidateFailedError as ex: except CredentialsValidateFailedError as ex:
raise ValueError(str(ex)) raise ValueError(str(ex))
@ -121,14 +142,13 @@ class DatasourceAuth(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self): def get(self, provider: str):
parser = reqparse.RequestParser() provider_id = DatasourceProviderID(provider)
parser.add_argument("provider", type=str, required=True, nullable=False, location="args")
parser.add_argument("plugin_id", type=str, required=True, nullable=False, location="args")
args = parser.parse_args()
datasource_provider_service = DatasourceProviderService() datasource_provider_service = DatasourceProviderService()
datasources = datasource_provider_service.get_datasource_credentials( datasources = datasource_provider_service.get_datasource_credentials(
tenant_id=current_user.current_tenant_id, provider=args["provider"], plugin_id=args["plugin_id"] tenant_id=current_user.current_tenant_id,
provider=provider_id.provider_name,
plugin_id=provider_id.plugin_id,
) )
return {"result": datasources}, 200 return {"result": datasources}, 200
@ -137,29 +157,27 @@ class DatasourceAuthUpdateDeleteApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
def delete(self, auth_id: str): def delete(self, provider: str, auth_id: str):
parser = reqparse.RequestParser() provider_id = DatasourceProviderID(provider)
parser.add_argument("provider", type=str, required=True, nullable=False, location="args") plugin_id = provider_id.plugin_id
parser.add_argument("plugin_id", type=str, required=True, nullable=False, location="args") provider_name = provider_id.provider_name
args = parser.parse_args()
if not current_user.is_editor: if not current_user.is_editor:
raise Forbidden() raise Forbidden()
datasource_provider_service = DatasourceProviderService() datasource_provider_service = DatasourceProviderService()
datasource_provider_service.remove_datasource_credentials( datasource_provider_service.remove_datasource_credentials(
tenant_id=current_user.current_tenant_id, tenant_id=current_user.current_tenant_id,
auth_id=auth_id, auth_id=auth_id,
provider=args["provider"], provider=provider_name,
plugin_id=args["plugin_id"], plugin_id=plugin_id,
) )
return {"result": "success"}, 200 return {"result": "success"}, 200
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
def patch(self, auth_id: str): def patch(self, provider: str, auth_id: str):
provider_id = DatasourceProviderID(provider)
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("provider", type=str, required=True, nullable=False, location="args")
parser.add_argument("plugin_id", type=str, required=True, nullable=False, location="args")
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
args = parser.parse_args() args = parser.parse_args()
if not current_user.is_editor: if not current_user.is_editor:
@ -169,8 +187,8 @@ class DatasourceAuthUpdateDeleteApi(Resource):
datasource_provider_service.update_datasource_credentials( datasource_provider_service.update_datasource_credentials(
tenant_id=current_user.current_tenant_id, tenant_id=current_user.current_tenant_id,
auth_id=auth_id, auth_id=auth_id,
provider=args["provider"], provider=provider_id.provider_name,
plugin_id=args["plugin_id"], plugin_id=provider_id.plugin_id,
credentials=args["credentials"], credentials=args["credentials"],
) )
except CredentialsValidateFailedError as ex: except CredentialsValidateFailedError as ex:
@ -193,21 +211,21 @@ class DatasourceAuthListApi(Resource):
# Import Rag Pipeline # Import Rag Pipeline
api.add_resource( api.add_resource(
DatasourcePluginOauthApi, DatasourcePluginOAuthAuthorizationUrl,
"/oauth/plugin/datasource", "/oauth/plugin/<path:provider>/datasource/get-authorization-url",
) )
api.add_resource( api.add_resource(
DatasourceOauthCallback, DatasourceOAuthCallback,
"/oauth/plugin/datasource/callback", "/oauth/plugin/<path:provider>/datasource/callback",
) )
api.add_resource( api.add_resource(
DatasourceAuth, DatasourceAuth,
"/auth/plugin/datasource", "/auth/plugin/datasource/<path:provider>",
) )
api.add_resource( api.add_resource(
DatasourceAuthUpdateDeleteApi, DatasourceAuthUpdateDeleteApi,
"/auth/plugin/datasource/<string:auth_id>", "/auth/plugin/datasource/<path:provider>/<string:auth_id>",
) )
api.add_resource( api.add_resource(

@ -0,0 +1,35 @@
import logging
import re
from collections.abc import Sequence
from typing import Any
from core.tools.entities.tool_entities import CredentialType
logger = logging.getLogger(__name__)
def generate_provider_name(
providers: Sequence[Any],
credential_type: CredentialType,
fallback_context: str = "provider"
) -> str:
try:
default_pattern = f"{credential_type.get_name()}"
pattern = rf"^{re.escape(default_pattern)}\s+(\d+)$"
numbers = []
for provider in providers:
if provider.name:
match = re.match(pattern, provider.name.strip())
if match:
numbers.append(int(match.group(1)))
if not numbers:
return f"{default_pattern} 1"
max_number = max(numbers)
return f"{default_pattern} {max_number + 1}"
except Exception as e:
logger.warning(f"Error generating next provider name for {fallback_context}: {str(e)}")
return f"{credential_type.get_name()} 1"

@ -1,13 +1,18 @@
import logging import logging
from flask_login import current_user from flask_login import current_user
from sqlalchemy.orm import Session
from constants import HIDDEN_VALUE from constants import HIDDEN_VALUE
from core.helper import encrypter from core.helper import encrypter
from core.helper.provider_name_generator import generate_provider_name
from core.model_runtime.entities.provider_entities import FormType from core.model_runtime.entities.provider_entities import FormType
from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.plugin.entities.plugin import DatasourceProviderID
from core.plugin.impl.datasource import PluginDatasourceManager from core.plugin.impl.datasource import PluginDatasourceManager
from core.tools.entities.tool_entities import CredentialType
from extensions.ext_database import db from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.oauth import DatasourceProvider from models.oauth import DatasourceProvider
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -21,8 +26,71 @@ class DatasourceProviderService:
def __init__(self) -> None: def __init__(self) -> None:
self.provider_manager = PluginDatasourceManager() self.provider_manager = PluginDatasourceManager()
def datasource_provider_credentials_validate( @staticmethod
self, tenant_id: str, provider: str, plugin_id: str, credentials: dict, name: str def generate_next_datasource_provider_name(
session: Session, tenant_id: str, provider_id: DatasourceProviderID, credential_type: CredentialType
) -> str:
db_providers = (
session.query(DatasourceProvider)
.filter_by(
tenant_id=tenant_id,
provider=provider_id.provider_name,
plugin_id=provider_id.plugin_id,
auth_type=credential_type.value,
)
.all()
)
return generate_provider_name(db_providers, credential_type, f"datasource provider {provider_id}")
def add_datasource_oauth_provider(
self,
name: str | None,
tenant_id: str,
provider_id: DatasourceProviderID,
credentials: dict,
) -> None:
"""
add datasource oauth provider
"""
credential_type = CredentialType.OAUTH2
with Session(db.engine) as session:
lock = f"datasource_provider_create_lock:{tenant_id}_{provider_id}_{credential_type.value}"
with redis_client.lock(lock, timeout=20):
db_provider_name = name or self.generate_next_datasource_provider_name(
session=session,
tenant_id=tenant_id,
provider_id=provider_id,
credential_type=credential_type,
)
if session.query(DatasourceProvider).filter_by(tenant_id=tenant_id, name=db_provider_name).count() > 0:
raise ValueError("name is already exists")
provider_credential_secret_variables = self.extract_secret_variables(
tenant_id=tenant_id, provider_id=f"{provider_id}"
)
for key, value in credentials.items():
if key in provider_credential_secret_variables:
# if send [__HIDDEN__] in secret input, it will be same as original value
credentials[key] = encrypter.encrypt_token(tenant_id, value)
datasource_provider = DatasourceProvider(
tenant_id=tenant_id,
name=db_provider_name,
provider=provider_id.provider_name,
plugin_id=provider_id.plugin_id,
auth_type=credential_type.value,
encrypted_credentials=credentials,
)
session.add(datasource_provider)
session.commit()
def add_datasource_api_key_provider(
self,
name: str | None,
tenant_id: str,
provider_id: DatasourceProviderID,
credentials: dict,
) -> None: ) -> None:
""" """
validate datasource provider credentials. validate datasource provider credentials.
@ -31,45 +99,49 @@ class DatasourceProviderService:
:param provider: :param provider:
:param credentials: :param credentials:
""" """
# check name is exist provider_name = provider_id.provider_name
datasource_provider = db.session.query(DatasourceProvider).filter_by(tenant_id=tenant_id, name=name).first() plugin_id = provider_id.plugin_id
if datasource_provider: with Session(db.engine) as session:
raise ValueError("Authorization name is already exists") lock = f"datasource_provider_create_lock:{tenant_id}_{provider_id}_api_key"
with redis_client.lock(lock, timeout=20):
db_provider_name = name or self.generate_next_datasource_provider_name(
session=session,
tenant_id=tenant_id,
provider_id=provider_id,
credential_type=CredentialType.API_KEY,
)
credential_valid = self.provider_manager.validate_provider_credentials( # check name is exist
tenant_id=tenant_id, if session.query(DatasourceProvider).filter_by(tenant_id=tenant_id, name=db_provider_name).count() > 0:
user_id=current_user.id, raise ValueError("Authorization name is already exists")
provider=provider,
plugin_id=plugin_id,
credentials=credentials,
)
if credential_valid:
# Get all provider configurations of the current workspace
datasource_provider = (
db.session.query(DatasourceProvider)
.filter_by(tenant_id=tenant_id, plugin_id=plugin_id, provider=provider, auth_type="api_key")
.first()
)
provider_credential_secret_variables = self.extract_secret_variables( credential_valid = self.provider_manager.validate_provider_credentials(
tenant_id=tenant_id, provider_id=f"{plugin_id}/{provider}" tenant_id=tenant_id,
) user_id=current_user.id,
for key, value in credentials.items(): provider=provider_name,
if key in provider_credential_secret_variables: plugin_id=plugin_id,
# if send [__HIDDEN__] in secret input, it will be same as original value credentials=credentials,
credentials[key] = encrypter.encrypt_token(tenant_id, value) )
datasource_provider = DatasourceProvider( if credential_valid:
tenant_id=tenant_id, provider_credential_secret_variables = self.extract_secret_variables(
name=name, tenant_id=tenant_id, provider_id=f"{provider_id}"
provider=provider, )
plugin_id=plugin_id, for key, value in credentials.items():
auth_type="api_key", if key in provider_credential_secret_variables:
encrypted_credentials=credentials, # if send [__HIDDEN__] in secret input, it will be same as original value
) credentials[key] = encrypter.encrypt_token(tenant_id, value)
db.session.add(datasource_provider) datasource_provider = DatasourceProvider(
db.session.commit() tenant_id=tenant_id,
else: name=db_provider_name,
raise CredentialsValidateFailedError() provider=provider_name,
plugin_id=plugin_id,
auth_type="api_key",
encrypted_credentials=credentials,
)
db.session.add(datasource_provider)
db.session.commit()
else:
raise CredentialsValidateFailedError()
def extract_secret_variables(self, tenant_id: str, provider_id: str) -> list[str]: def extract_secret_variables(self, tenant_id: str, provider_id: str) -> list[str]:
""" """

@ -1,6 +1,5 @@
import json import json
import logging import logging
import re
from collections.abc import Mapping from collections.abc import Mapping
from pathlib import Path from pathlib import Path
from typing import Any, Optional from typing import Any, Optional
@ -11,6 +10,7 @@ from configs import dify_config
from constants import HIDDEN_VALUE, UNKNOWN_VALUE from constants import HIDDEN_VALUE, UNKNOWN_VALUE
from core.helper.position_helper import is_filtered from core.helper.position_helper import is_filtered
from core.helper.provider_cache import NoOpProviderCredentialCache, ToolProviderCredentialsCache from core.helper.provider_cache import NoOpProviderCredentialCache, ToolProviderCredentialsCache
from core.helper.provider_name_generator import generate_provider_name
from core.plugin.entities.plugin import ToolProviderID from core.plugin.entities.plugin import ToolProviderID
from core.tools.builtin_tool.provider import BuiltinToolProviderController from core.tools.builtin_tool.provider import BuiltinToolProviderController
from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort
@ -299,42 +299,18 @@ class BuiltinToolManageService:
def generate_builtin_tool_provider_name( def generate_builtin_tool_provider_name(
session: Session, tenant_id: str, provider: str, credential_type: CredentialType session: Session, tenant_id: str, provider: str, credential_type: CredentialType
) -> str: ) -> str:
try: db_providers = (
db_providers = ( session.query(BuiltinToolProvider)
session.query(BuiltinToolProvider) .filter_by(
.filter_by( tenant_id=tenant_id,
tenant_id=tenant_id, provider=provider,
provider=provider, credential_type=credential_type.value,
credential_type=credential_type.value,
)
.order_by(BuiltinToolProvider.created_at.desc())
.all()
) )
.order_by(BuiltinToolProvider.created_at.desc())
.all()
)
# Get the default name pattern return generate_provider_name(db_providers, credential_type, f"builtin tool provider {provider}")
default_pattern = f"{credential_type.get_name()}"
# Find all names that match the default pattern: "{default_pattern} {number}"
pattern = rf"^{re.escape(default_pattern)}\s+(\d+)$"
numbers = []
for db_provider in db_providers:
if db_provider.name:
match = re.match(pattern, db_provider.name.strip())
if match:
numbers.append(int(match.group(1)))
# If no default pattern names found, start with 1
if not numbers:
return f"{default_pattern} 1"
# Find the next number
max_number = max(numbers)
return f"{default_pattern} {max_number + 1}"
except Exception as e:
logger.warning(f"Error generating next provider name for {provider}: {str(e)}")
# fallback
return f"{credential_type.get_name()} 1"
@staticmethod @staticmethod
def get_builtin_tool_provider_credentials( def get_builtin_tool_provider_credentials(

Loading…
Cancel
Save