feat(oauth): add credential validation for providers

feat/tool-plugin-oauth
Harry 10 months ago
parent 0dc5bfb2c7
commit ef330fec2c

@ -95,9 +95,7 @@ class BuiltinToolManageService:
return entity return entity
@staticmethod @staticmethod
def list_builtin_provider_credentials_schema( def list_builtin_provider_credentials_schema(provider_name: str, credential_type: CredentialType, tenant_id: str):
provider_name: str, credential_type: CredentialType, tenant_id: str
):
""" """
list builtin provider credentials schema list builtin provider credentials schema
@ -141,7 +139,8 @@ class BuiltinToolManageService:
if key in masked_credentials and value == masked_credentials[key]: if key in masked_credentials and value == masked_credentials[key]:
credentials[key] = original_credentials[key] credentials[key] = original_credentials[key]
provider_controller.validate_credentials(user_id, credentials) if CredentialType.of(db_provider.credential_type).is_validate_allowed():
provider_controller.validate_credentials(user_id, credentials)
# encrypt credentials # encrypt credentials
db_provider.encrypted_credentials = json.dumps(encrypter.encrypt(credentials)) db_provider.encrypted_credentials = json.dumps(encrypter.encrypt(credentials))
@ -159,6 +158,7 @@ class BuiltinToolManageService:
ToolNotFoundError, ToolNotFoundError,
ToolProviderCredentialValidationError, ToolProviderCredentialValidationError,
) as e: ) as e:
db.session.rollback()
raise ValueError(str(e)) raise ValueError(str(e))
return {"result": "success"} return {"result": "success"}
@ -176,46 +176,59 @@ class BuiltinToolManageService:
add builtin tool provider add builtin tool provider
""" """
lock = f"builtin_tool_provider_create_lock:{tenant_id}_{provider}" lock = f"builtin_tool_provider_create_lock:{tenant_id}_{provider}"
with redis_client.lock(lock, timeout=20): try:
# check if the provider count is over the limit with redis_client.lock(lock, timeout=20):
provider_count = ( provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
db.session.query(BuiltinToolProvider).filter_by(tenant_id=tenant_id, provider=provider).count() if not provider_controller.need_credentials:
) raise ValueError(f"provider {provider} does not need credentials")
if provider_count >= BuiltinToolManageService.__MAX_BUILTIN_TOOL_PROVIDER_COUNT__:
raise ValueError(f"you have reached the maximum number of providers for {provider}") provider_count = (
db.session.query(BuiltinToolProvider).filter_by(tenant_id=tenant_id, provider=provider).count()
# TODO should we get name from oauth authentication?
name = (
name
if name
else BuiltinToolManageService.generate_builtin_tool_provider_name(
tenant_id=tenant_id, provider=provider, credential_type=api_type
) )
)
db_provider = BuiltinToolProvider( # check if the provider count is reached the limit
tenant_id=tenant_id, if provider_count >= BuiltinToolManageService.__MAX_BUILTIN_TOOL_PROVIDER_COUNT__:
user_id=user_id, raise ValueError(f"you have reached the maximum number of providers for {provider}")
provider=provider,
encrypted_credentials=json.dumps(credentials),
credential_type=api_type.value,
name=name,
)
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id) # validate credentials if allowed
if not provider_controller.need_credentials: if CredentialType.of(api_type).is_validate_allowed():
raise ValueError(f"provider {provider} does not need credentials") provider_controller.validate_credentials(user_id, credentials)
encrypter, cache = BuiltinToolManageService.create_tool_encrypter( # generate name if not provided
tenant_id, db_provider, provider, provider_controller if name is None:
) name = BuiltinToolManageService.generate_builtin_tool_provider_name(
tenant_id=tenant_id, provider=provider, credential_type=api_type
)
# encrypt credentials # create encrypter
db_provider.encrypted_credentials = json.dumps(encrypter.encrypt(credentials)) encrypter, _ = create_provider_encrypter(
tenant_id=tenant_id,
config=[
x.to_basic_provider_config()
for x in provider_controller.get_credentials_schema_by_type(api_type)
],
cache=NoOpProviderCredentialCache(),
)
cache.delete() db_provider = BuiltinToolProvider(
db.session.add(db_provider) tenant_id=tenant_id,
db.session.commit() user_id=user_id,
provider=provider,
encrypted_credentials=json.dumps(encrypter.encrypt(credentials)),
credential_type=api_type.value,
name=name,
)
db.session.add(db_provider)
db.session.commit()
except (
PluginDaemonClientSideError,
ToolProviderNotFoundError,
ToolNotFoundError,
ToolProviderCredentialValidationError,
) as e:
db.session.rollback()
raise ValueError(str(e))
return {"result": "success"} return {"result": "success"}
@staticmethod @staticmethod
@ -236,9 +249,7 @@ class BuiltinToolManageService:
return encrypter, cache return encrypter, cache
@staticmethod @staticmethod
def generate_builtin_tool_provider_name( def generate_builtin_tool_provider_name(tenant_id: str, provider: str, credential_type: CredentialType) -> str:
tenant_id: str, provider: str, credential_type: CredentialType
) -> str:
try: try:
db_providers = ( db_providers = (
db.session.query(BuiltinToolProvider) db.session.query(BuiltinToolProvider)
@ -324,7 +335,7 @@ class BuiltinToolManageService:
is_oauth_custom_client_enabled=BuiltinToolManageService.is_oauth_custom_client_enabled(tenant_id, provider), is_oauth_custom_client_enabled=BuiltinToolManageService.is_oauth_custom_client_enabled(tenant_id, provider),
credentials=credentials, credentials=credentials,
) )
return credential_info return credential_info
@staticmethod @staticmethod
@ -362,8 +373,8 @@ class BuiltinToolManageService:
# clear default provider # clear default provider
session.query(BuiltinToolProvider).filter_by( session.query(BuiltinToolProvider).filter_by(
tenant_id=tenant_id, user_id=user_id, provider=provider, default=True tenant_id=tenant_id, user_id=user_id, provider=provider, is_default=True
).update({"default": False}) ).update({"is_default": False})
# set new default provider # set new default provider
target_provider.is_default = True target_provider.is_default = True

Loading…
Cancel
Save