feat: refactor datasource authentication APIs for improved credential management

feat/rag-2
Harry 10 months ago
parent 57b48f51b5
commit 17da96bdd8

@ -149,33 +149,40 @@ class DatasourceAuth(Resource):
) )
return {"result": datasources}, 200 return {"result": datasources}, 200
class DatasourceAuthDeleteApi(Resource):
class DatasourceAuthUpdateDeleteApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
def delete(self, provider_id: str, auth_id: str): def post(self, provider_id: str):
datasource_provider_id = DatasourceProviderID(provider_id) datasource_provider_id = DatasourceProviderID(provider_id)
plugin_id = datasource_provider_id.plugin_id plugin_id = datasource_provider_id.plugin_id
provider_name = datasource_provider_id.provider_name provider_name = datasource_provider_id.provider_name
if not current_user.is_editor: if not current_user.is_editor:
raise Forbidden() raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
args = parser.parse_args()
datasource_provider_service = DatasourceProviderService() datasource_provider_service = DatasourceProviderService()
datasource_provider_service.remove_datasource_credentials( datasource_provider_service.remove_datasource_credentials(
tenant_id=current_user.current_tenant_id, tenant_id=current_user.current_tenant_id,
auth_id=auth_id, auth_id=args["credential_id"],
provider=provider_name, provider=provider_name,
plugin_id=plugin_id, plugin_id=plugin_id,
) )
return {"result": "success"}, 200 return {"result": "success"}, 200
class DatasourceAuthUpdateApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
def patch(self, provider_id: str, auth_id: str): def post(self, provider_id: str):
datasource_provider_id = DatasourceProviderID(provider_id) datasource_provider_id = DatasourceProviderID(provider_id)
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") parser.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
parser.add_argument("name", type=str, required=False, nullable=True, location="json")
parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
args = parser.parse_args() args = parser.parse_args()
if not current_user.is_editor: if not current_user.is_editor:
raise Forbidden() raise Forbidden()
@ -183,10 +190,11 @@ class DatasourceAuthUpdateDeleteApi(Resource):
datasource_provider_service = DatasourceProviderService() datasource_provider_service = DatasourceProviderService()
datasource_provider_service.update_datasource_credentials( datasource_provider_service.update_datasource_credentials(
tenant_id=current_user.current_tenant_id, tenant_id=current_user.current_tenant_id,
auth_id=auth_id, auth_id=args["credential_id"],
provider=datasource_provider_id.provider_name, provider=datasource_provider_id.provider_name,
plugin_id=datasource_provider_id.plugin_id, plugin_id=datasource_provider_id.plugin_id,
credentials=args["credentials"], credentials=args.get("credentials", {}),
name=args.get("name", None),
) )
except CredentialsValidateFailedError as ex: except CredentialsValidateFailedError as ex:
raise ValueError(str(ex)) raise ValueError(str(ex))
@ -228,6 +236,17 @@ class DatasourceAuthOauthCustomClient(Resource):
) )
return {"result": "success"}, 200 return {"result": "success"}, 200
@setup_required
@login_required
@account_initialization_required
def delete(self, provider_id: str):
datasource_provider_id = DatasourceProviderID(provider_id)
datasource_provider_service = DatasourceProviderService()
datasource_provider_service.remove_oauth_custom_client_params(
tenant_id=current_user.current_tenant_id,
datasource_provider_id=datasource_provider_id,
)
return {"result": "success"}, 200
class DatasourceAuthDefaultApi(Resource): class DatasourceAuthDefaultApi(Resource):
@setup_required @setup_required
@ -237,14 +256,14 @@ class DatasourceAuthDefaultApi(Resource):
if not current_user.is_editor: if not current_user.is_editor:
raise Forbidden() raise Forbidden()
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json") parser.add_argument("id", type=str, required=True, nullable=False, location="json")
args = parser.parse_args() args = parser.parse_args()
datasource_provider_id = DatasourceProviderID(provider_id) datasource_provider_id = DatasourceProviderID(provider_id)
datasource_provider_service = DatasourceProviderService() datasource_provider_service = DatasourceProviderService()
datasource_provider_service.set_default_datasource_provider( datasource_provider_service.set_default_datasource_provider(
tenant_id=current_user.current_tenant_id, tenant_id=current_user.current_tenant_id,
datasource_provider_id=datasource_provider_id, datasource_provider_id=datasource_provider_id,
credential_id=args["credential_id"], credential_id=args["id"],
) )
return {"result": "success"}, 200 return {"result": "success"}, 200
@ -284,8 +303,13 @@ api.add_resource(
) )
api.add_resource( api.add_resource(
DatasourceAuthUpdateDeleteApi, DatasourceAuthUpdateApi,
"/auth/plugin/datasource/<path:provider_id>/<string:auth_id>", "/auth/plugin/datasource/<path:provider_id>/update",
)
api.add_resource(
DatasourceAuthDeleteApi,
"/auth/plugin/datasource/<path:provider_id>/delete",
) )
api.add_resource( api.add_resource(

@ -4,6 +4,7 @@ from typing import Any
from flask_login import current_user from flask_login import current_user
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from configs import dify_config
from constants import HIDDEN_VALUE, UNKNOWN_VALUE from constants import HIDDEN_VALUE, UNKNOWN_VALUE
from core.helper import encrypter from core.helper import encrypter
from core.helper.name_generator import generate_incremental_name from core.helper.name_generator import generate_incremental_name
@ -29,6 +30,18 @@ class DatasourceProviderService:
def __init__(self) -> None: def __init__(self) -> None:
self.provider_manager = PluginDatasourceManager() self.provider_manager = PluginDatasourceManager()
def remove_oauth_custom_client_params(self, tenant_id: str, datasource_provider_id: DatasourceProviderID):
"""
remove oauth custom client params
"""
with Session(db.engine) as session:
session.query(DatasourceOauthTenantParamConfig).filter_by(
tenant_id=tenant_id,
provider=datasource_provider_id.provider_name,
plugin_id=datasource_provider_id.plugin_id,
).delete()
session.commit()
def get_default_credentials(self, tenant_id: str, provider: str, plugin_id: str) -> dict[str, Any]: def get_default_credentials(self, tenant_id: str, provider: str, plugin_id: str) -> dict[str, Any]:
""" """
get default credentials get default credentials
@ -512,6 +525,10 @@ class DatasourceProviderService:
credentials = self.get_datasource_credentials( credentials = self.get_datasource_credentials(
tenant_id=tenant_id, provider=datasource.provider, plugin_id=datasource.plugin_id tenant_id=tenant_id, provider=datasource.provider, plugin_id=datasource.plugin_id
) )
redirect_uri = (
f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/"
f"{datasource_provider_id}/datasource/callback"
)
datasource_credentials.append( datasource_credentials.append(
{ {
"provider": datasource.provider, "provider": datasource.provider,
@ -542,6 +559,7 @@ class DatasourceProviderService:
tenant_id, datasource_provider_id tenant_id, datasource_provider_id
), ),
"is_system_oauth_params_exists": self.is_system_oauth_params_exist(datasource_provider_id), "is_system_oauth_params_exists": self.is_system_oauth_params_exist(datasource_provider_id),
"redirect_uri": redirect_uri
} }
if datasource.declaration.oauth_schema if datasource.declaration.oauth_schema
else None, else None,
@ -594,38 +612,50 @@ class DatasourceProviderService:
return copy_credentials_list return copy_credentials_list
def update_datasource_credentials( def update_datasource_credentials(
self, tenant_id: str, auth_id: str, provider: str, plugin_id: str, credentials: dict self, tenant_id: str, auth_id: str, provider: str, plugin_id: str, credentials: dict | None, name: str | None
) -> None: ) -> None:
""" """
update datasource credentials. update datasource credentials.
""" """
credential_valid = self.provider_manager.validate_provider_credentials( with Session(db.engine) as session:
tenant_id=tenant_id,
user_id=current_user.id,
provider=provider,
plugin_id=plugin_id,
credentials=credentials,
)
if credential_valid:
# Get all provider configurations of the current workspace
datasource_provider = ( datasource_provider = (
db.session.query(DatasourceProvider) session.query(DatasourceProvider)
.filter_by(tenant_id=tenant_id, id=auth_id, provider=provider, plugin_id=plugin_id) .filter_by(tenant_id=tenant_id, id=auth_id, provider=provider, plugin_id=plugin_id)
.first() .first()
) )
if not datasource_provider: if not datasource_provider:
raise ValueError("Datasource provider not found") raise ValueError("Datasource provider not found")
else: # update name
provider_credential_secret_variables = self.extract_secret_variables( if name and name != datasource_provider.name:
tenant_id=tenant_id, if (
provider_id=f"{plugin_id}/{provider}", session.query(DatasourceProvider)
credential_type=datasource_provider.auth_type, .filter_by(tenant_id=tenant_id, name=name, provider=provider, plugin_id=plugin_id)
) .count()
> 0
):
raise ValueError("name is already exists")
datasource_provider.name = name
# update credentials
if credentials:
try:
self.provider_manager.validate_provider_credentials(
tenant_id=tenant_id,
user_id=current_user.id,
provider=provider,
plugin_id=plugin_id,
credentials=credentials,
)
except Exception as e:
raise ValueError(f"Failed to validate credentials: {str(e)}")
original_credentials = datasource_provider.encrypted_credentials original_credentials = datasource_provider.encrypted_credentials
for key, value in credentials.items(): for key, value in credentials.items():
if key in provider_credential_secret_variables: if key in self.extract_secret_variables(
# if send [__HIDDEN__] in secret input, it will be same as original value tenant_id=tenant_id,
provider_id=f"{plugin_id}/{provider}",
credential_type=datasource_provider.auth_type,
):
if value == HIDDEN_VALUE and key in original_credentials: if value == HIDDEN_VALUE and key in original_credentials:
original_value = encrypter.encrypt_token(tenant_id, original_credentials[key]) original_value = encrypter.encrypt_token(tenant_id, original_credentials[key])
credentials[key] = encrypter.encrypt_token(tenant_id, original_value) credentials[key] = encrypter.encrypt_token(tenant_id, original_value)
@ -633,9 +663,7 @@ class DatasourceProviderService:
credentials[key] = encrypter.encrypt_token(tenant_id, value) credentials[key] = encrypter.encrypt_token(tenant_id, value)
datasource_provider.encrypted_credentials = credentials datasource_provider.encrypted_credentials = credentials
db.session.commit() session.commit()
else:
raise CredentialsValidateFailedError()
def remove_datasource_credentials(self, tenant_id: str, auth_id: str, provider: str, plugin_id: str) -> None: def remove_datasource_credentials(self, tenant_id: str, auth_id: str, provider: str, plugin_id: str) -> None:
""" """

Loading…
Cancel
Save