feat: update provider parameter naming and refactor related logic in datasource_auth.py

feat/rag-2
Harry 10 months ago
parent 0ac5c0bf3e
commit 633bfc25e0

@ -27,26 +27,26 @@ class DatasourcePluginOAuthAuthorizationUrl(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, provider: str): def get(self, provider_id: str):
user = current_user user = current_user
tenant_id = user.current_tenant_id tenant_id = user.current_tenant_id
if not current_user.is_editor: if not current_user.is_editor:
raise Forbidden() raise Forbidden()
provider_id = DatasourceProviderID(provider) datasource_provider_id = DatasourceProviderID(provider_id)
provider_name = provider_id.provider_name provider_name = datasource_provider_id.provider_name
plugin_id = provider_id.plugin_id plugin_id = datasource_provider_id.plugin_id
oauth_config = ( oauth_config = (
db.session.query(DatasourceOauthParamConfig).filter_by(provider=provider_name, plugin_id=plugin_id).first() db.session.query(DatasourceOauthParamConfig).filter_by(provider=provider_name, plugin_id=plugin_id).first()
) )
if not oauth_config: if not oauth_config:
raise ValueError(f"No OAuth Client Config for {provider}") raise ValueError(f"No OAuth Client Config for {provider_id}")
context_id = OAuthProxyService.create_proxy_context( context_id = OAuthProxyService.create_proxy_context(
user_id=current_user.id, tenant_id=tenant_id, plugin_id=plugin_id, provider=provider_name user_id=current_user.id, tenant_id=tenant_id, plugin_id=plugin_id, provider=provider_name
) )
oauth_handler = OAuthHandler() oauth_handler = OAuthHandler()
redirect_uri = f"{dify_config.CONSOLE_WEB_URL}/console/api/oauth/plugin/{provider}/datasource/callback" redirect_uri = f"{dify_config.CONSOLE_WEB_URL}/console/api/oauth/plugin/{provider_id}/datasource/callback"
oauth_client_params = oauth_config.system_credentials oauth_client_params = oauth_config.system_credentials
authorization_url_response = oauth_handler.get_authorization_url( authorization_url_response = oauth_handler.get_authorization_url(
@ -70,7 +70,7 @@ class DatasourcePluginOAuthAuthorizationUrl(Resource):
class DatasourceOAuthCallback(Resource): class DatasourceOAuthCallback(Resource):
@setup_required @setup_required
def get(self, provider: str): def get(self, provider_id: str):
if not current_user.is_editor: if not current_user.is_editor:
raise Forbidden() raise Forbidden()
@ -83,21 +83,21 @@ class DatasourceOAuthCallback(Resource):
raise Forbidden("Invalid context_id") raise Forbidden("Invalid context_id")
user_id, tenant_id = context.get("user_id"), context.get("tenant_id") user_id, tenant_id = context.get("user_id"), context.get("tenant_id")
provider_id = DatasourceProviderID(provider) datasource_provider_id = DatasourceProviderID(provider_id)
provider_name = provider_id.provider_name provider_name = datasource_provider_id.provider_name
plugin_id = provider_id.plugin_id plugin_id = datasource_provider_id.plugin_id
plugin_oauth_config = ( plugin_oauth_config = (
db.session.query(DatasourceOauthParamConfig).filter_by(provider=provider_name, plugin_id=plugin_id).first() db.session.query(DatasourceOauthParamConfig).filter_by(provider=provider_name, plugin_id=plugin_id).first()
) )
if not plugin_oauth_config: if not plugin_oauth_config:
raise NotFound() raise NotFound()
redirect_uri = f"{dify_config.CONSOLE_WEB_URL}/console/api/oauth/plugin/{provider}/datasource/callback" redirect_uri = f"{dify_config.CONSOLE_WEB_URL}/console/api/oauth/plugin/{provider_id}/datasource/callback"
oauth_handler = OAuthHandler() oauth_handler = OAuthHandler()
oauth_response = oauth_handler.get_credentials( oauth_response = oauth_handler.get_credentials(
tenant_id=tenant_id, tenant_id=tenant_id,
user_id=user_id, user_id=user_id,
plugin_id=plugin_id, plugin_id=plugin_id,
provider=provider_id.provider_name, provider=datasource_provider_id.provider_name,
redirect_uri=redirect_uri, redirect_uri=redirect_uri,
system_credentials=plugin_oauth_config.system_credentials, system_credentials=plugin_oauth_config.system_credentials,
request=request, request=request,
@ -105,7 +105,7 @@ class DatasourceOAuthCallback(Resource):
datasource_provider_service = DatasourceProviderService() datasource_provider_service = DatasourceProviderService()
datasource_provider_service.add_datasource_oauth_provider( datasource_provider_service.add_datasource_oauth_provider(
tenant_id=tenant_id, tenant_id=tenant_id,
provider_id=provider_id, provider_id=datasource_provider_id,
credentials=dict(oauth_response.credentials), credentials=dict(oauth_response.credentials),
name=None, name=None,
) )
@ -116,7 +116,7 @@ class DatasourceAuth(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self, provider: str): def post(self, provider_id: str):
if not current_user.is_editor: if not current_user.is_editor:
raise Forbidden() raise Forbidden()
@ -124,13 +124,13 @@ class DatasourceAuth(Resource):
parser.add_argument("name", type=str, required=False, nullable=True, location="json", default=None) parser.add_argument("name", type=str, required=False, nullable=True, location="json", default=None)
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
args = parser.parse_args() args = parser.parse_args()
provider_id = DatasourceProviderID(provider) datasource_provider_id = DatasourceProviderID(provider_id)
datasource_provider_service = DatasourceProviderService() datasource_provider_service = DatasourceProviderService()
try: try:
datasource_provider_service.add_datasource_api_key_provider( datasource_provider_service.add_datasource_api_key_provider(
tenant_id=current_user.current_tenant_id, tenant_id=current_user.current_tenant_id,
provider_id=provider_id, provider_id=datasource_provider_id,
credentials=args["credentials"], credentials=args["credentials"],
name=args["name"], name=args["name"],
) )
@ -142,13 +142,13 @@ class DatasourceAuth(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
def get(self, provider: str): def get(self, provider_id: str):
provider_id = DatasourceProviderID(provider) datasource_provider_id = DatasourceProviderID(provider_id)
datasource_provider_service = DatasourceProviderService() datasource_provider_service = DatasourceProviderService()
datasources = datasource_provider_service.get_datasource_credentials( datasources = datasource_provider_service.get_datasource_credentials(
tenant_id=current_user.current_tenant_id, tenant_id=current_user.current_tenant_id,
provider=provider_id.provider_name, provider=datasource_provider_id.provider_name,
plugin_id=provider_id.plugin_id, plugin_id=datasource_provider_id.plugin_id,
) )
return {"result": datasources}, 200 return {"result": datasources}, 200
@ -157,10 +157,10 @@ class DatasourceAuthUpdateDeleteApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
def delete(self, provider: str, auth_id: str): def delete(self, provider_id: str, auth_id: str):
provider_id = DatasourceProviderID(provider) datasource_provider_id = DatasourceProviderID(provider_id)
plugin_id = provider_id.plugin_id plugin_id = datasource_provider_id.plugin_id
provider_name = provider_id.provider_name provider_name = datasource_provider_id.provider_name
if not current_user.is_editor: if not current_user.is_editor:
raise Forbidden() raise Forbidden()
datasource_provider_service = DatasourceProviderService() datasource_provider_service = DatasourceProviderService()
@ -175,8 +175,8 @@ class DatasourceAuthUpdateDeleteApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
def patch(self, provider: str, auth_id: str): def patch(self, provider_id: str, auth_id: str):
provider_id = DatasourceProviderID(provider) datasource_provider_id = DatasourceProviderID(provider_id)
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
args = parser.parse_args() args = parser.parse_args()
@ -187,8 +187,8 @@ class DatasourceAuthUpdateDeleteApi(Resource):
datasource_provider_service.update_datasource_credentials( datasource_provider_service.update_datasource_credentials(
tenant_id=current_user.current_tenant_id, tenant_id=current_user.current_tenant_id,
auth_id=auth_id, auth_id=auth_id,
provider=provider_id.provider_name, provider=datasource_provider_id.provider_name,
plugin_id=provider_id.plugin_id, plugin_id=datasource_provider_id.plugin_id,
credentials=args["credentials"], credentials=args["credentials"],
) )
except CredentialsValidateFailedError as ex: except CredentialsValidateFailedError as ex:
@ -212,20 +212,20 @@ class DatasourceAuthListApi(Resource):
# Import Rag Pipeline # Import Rag Pipeline
api.add_resource( api.add_resource(
DatasourcePluginOAuthAuthorizationUrl, DatasourcePluginOAuthAuthorizationUrl,
"/oauth/plugin/<path:provider>/datasource/get-authorization-url", "/oauth/plugin/<path:provider_id>/datasource/get-authorization-url",
) )
api.add_resource( api.add_resource(
DatasourceOAuthCallback, DatasourceOAuthCallback,
"/oauth/plugin/<path:provider>/datasource/callback", "/oauth/plugin/<path:provider_id>/datasource/callback",
) )
api.add_resource( api.add_resource(
DatasourceAuth, DatasourceAuth,
"/auth/plugin/datasource/<path:provider>", "/auth/plugin/datasource/<path:provider_id>",
) )
api.add_resource( api.add_resource(
DatasourceAuthUpdateDeleteApi, DatasourceAuthUpdateDeleteApi,
"/auth/plugin/datasource/<path:provider>/<string:auth_id>", "/auth/plugin/datasource/<path:provider_id>/<string:auth_id>",
) )
api.add_resource( api.add_resource(

Loading…
Cancel
Save