feat(oauth): refactor session management in tool provider operations

feat/tool-plugin-oauth
Harry 10 months ago
parent ef330fec2c
commit f35b8d6245

@ -7,6 +7,7 @@ from typing import Any, Optional
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from configs import dify_config from configs import dify_config
from constants import HIDDEN_VALUE
from core.helper.position_helper import is_filtered from core.helper.position_helper import is_filtered
from core.helper.provider_cache import NoOpProviderCredentialCache, ToolProviderCredentialsCache from core.helper.provider_cache import NoOpProviderCredentialCache, ToolProviderCredentialsCache
from core.plugin.entities.plugin import ToolProviderID from core.plugin.entities.plugin import ToolProviderID
@ -114,52 +115,65 @@ class BuiltinToolManageService:
""" """
update builtin tool provider update builtin tool provider
""" """
# get if the provider exists with Session(db.engine) as session:
db_provider = BuiltinToolManageService.get_builtin_provider_by_id(tenant_id, credential_id) # get if the provider exists
db_provider = (
if db_provider is None: session.query(BuiltinToolProvider)
raise ValueError(f"you have not added provider {provider}") .filter(
BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.id == credential_id,
)
.first()
)
if db_provider is None:
raise ValueError(f"you have not added provider {provider}")
try: try:
if CredentialType.of(db_provider.credential_type).is_editable(): if CredentialType.of(db_provider.credential_type).is_editable():
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id) provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
if not provider_controller.need_credentials: if not provider_controller.need_credentials:
raise ValueError(f"provider {provider} does not need credentials") raise ValueError(f"provider {provider} does not need credentials")
encrypter, cache = BuiltinToolManageService.create_tool_encrypter( encrypter, cache = BuiltinToolManageService.create_tool_encrypter(
tenant_id, db_provider, provider, provider_controller tenant_id, db_provider, provider, provider_controller
) )
# Decrypt and restore original credentials for masked values original_credentials = encrypter.decrypt(db_provider.credentials)
original_credentials = encrypter.decrypt(db_provider.credentials) new_credentials: dict = {
masked_credentials = encrypter.mask_tool_credentials(original_credentials) key: value if value != HIDDEN_VALUE else original_credentials.get(key, HIDDEN_VALUE)
for key, value in credentials.items()
}
# check if the credential has changed, save the original credential if CredentialType.of(db_provider.credential_type).is_validate_allowed():
for key, value in credentials.items(): provider_controller.validate_credentials(user_id, new_credentials)
if key in masked_credentials and value == masked_credentials[key]:
credentials[key] = original_credentials[key]
if CredentialType.of(db_provider.credential_type).is_validate_allowed(): # encrypt credentials
provider_controller.validate_credentials(user_id, credentials) db_provider.encrypted_credentials = json.dumps(encrypter.encrypt(new_credentials))
# encrypt credentials cache.delete()
db_provider.encrypted_credentials = json.dumps(encrypter.encrypt(credentials))
cache.delete() # update name if provided
if name is not None and db_provider.name != name:
# check if the name is already used
if (
session.query(BuiltinToolProvider)
.filter_by(tenant_id=tenant_id, provider=provider, name=name)
.count()
> 0
):
raise ValueError(f"the credential name '{name}' is already used")
# update name if provided db_provider.name = name
if name is not None and db_provider.name != name:
db_provider.name = name
db.session.commit() session.commit()
except ( except (
PluginDaemonClientSideError, PluginDaemonClientSideError,
ToolProviderNotFoundError, ToolProviderNotFoundError,
ToolNotFoundError, ToolNotFoundError,
ToolProviderCredentialValidationError, ToolProviderCredentialValidationError,
) as e: ) as e:
db.session.rollback() session.rollback()
raise ValueError(str(e)) raise ValueError(str(e))
return {"result": "success"} return {"result": "success"}
@ -175,59 +189,69 @@ class BuiltinToolManageService:
""" """
add builtin tool provider add builtin tool provider
""" """
lock = f"builtin_tool_provider_create_lock:{tenant_id}_{provider}"
try: try:
with redis_client.lock(lock, timeout=20): with Session(db.engine) as session:
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id) lock = f"builtin_tool_provider_create_lock:{tenant_id}_{provider}"
if not provider_controller.need_credentials: with redis_client.lock(lock, timeout=20):
raise ValueError(f"provider {provider} does not need credentials") provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
if not provider_controller.need_credentials:
provider_count = ( raise ValueError(f"provider {provider} does not need credentials")
db.session.query(BuiltinToolProvider).filter_by(tenant_id=tenant_id, provider=provider).count()
) provider_count = (
session.query(BuiltinToolProvider).filter_by(tenant_id=tenant_id, provider=provider).count()
)
# check if the provider count is reached the limit # check if the provider count is reached the limit
if provider_count >= BuiltinToolManageService.__MAX_BUILTIN_TOOL_PROVIDER_COUNT__: if provider_count >= BuiltinToolManageService.__MAX_BUILTIN_TOOL_PROVIDER_COUNT__:
raise ValueError(f"you have reached the maximum number of providers for {provider}") raise ValueError(f"you have reached the maximum number of providers for {provider}")
# validate credentials if allowed # validate credentials if allowed
if CredentialType.of(api_type).is_validate_allowed(): if CredentialType.of(api_type).is_validate_allowed():
provider_controller.validate_credentials(user_id, credentials) provider_controller.validate_credentials(user_id, credentials)
# generate name if not provided # generate name if not provided
if name is None: if name is None:
name = BuiltinToolManageService.generate_builtin_tool_provider_name( name = BuiltinToolManageService.generate_builtin_tool_provider_name(
tenant_id=tenant_id, provider=provider, credential_type=api_type session=session, tenant_id=tenant_id, provider=provider, credential_type=api_type
)
else:
# check if the name is already used
if (
session.query(BuiltinToolProvider)
.filter_by(tenant_id=tenant_id, provider=provider, name=name)
.count()
> 0
):
raise ValueError(f"the credential name '{name}' is already used")
# create encrypter
encrypter, _ = create_provider_encrypter(
tenant_id=tenant_id,
config=[
x.to_basic_provider_config()
for x in provider_controller.get_credentials_schema_by_type(api_type)
],
cache=NoOpProviderCredentialCache(),
) )
# create encrypter db_provider = BuiltinToolProvider(
encrypter, _ = create_provider_encrypter( tenant_id=tenant_id,
tenant_id=tenant_id, user_id=user_id,
config=[ provider=provider,
x.to_basic_provider_config() encrypted_credentials=json.dumps(encrypter.encrypt(credentials)),
for x in provider_controller.get_credentials_schema_by_type(api_type) credential_type=api_type.value,
], name=name,
cache=NoOpProviderCredentialCache(), )
)
db_provider = BuiltinToolProvider(
tenant_id=tenant_id,
user_id=user_id,
provider=provider,
encrypted_credentials=json.dumps(encrypter.encrypt(credentials)),
credential_type=api_type.value,
name=name,
)
db.session.add(db_provider) session.add(db_provider)
db.session.commit() session.commit()
except ( except (
PluginDaemonClientSideError, PluginDaemonClientSideError,
ToolProviderNotFoundError, ToolProviderNotFoundError,
ToolNotFoundError, ToolNotFoundError,
ToolProviderCredentialValidationError, ToolProviderCredentialValidationError,
) as e: ) as e:
db.session.rollback() session.rollback()
raise ValueError(str(e)) raise ValueError(str(e))
return {"result": "success"} return {"result": "success"}
@ -249,10 +273,12 @@ class BuiltinToolManageService:
return encrypter, cache return encrypter, cache
@staticmethod @staticmethod
def generate_builtin_tool_provider_name(tenant_id: str, provider: str, credential_type: CredentialType) -> str: def generate_builtin_tool_provider_name(
session: Session, tenant_id: str, provider: str, credential_type: CredentialType
) -> str:
try: try:
db_providers = ( db_providers = (
db.session.query(BuiltinToolProvider) session.query(BuiltinToolProvider)
.filter_by( .filter_by(
tenant_id=tenant_id, tenant_id=tenant_id,
provider=provider, provider=provider,
@ -308,7 +334,7 @@ class BuiltinToolManageService:
default_provider = providers[0] default_provider = providers[0]
default_provider.is_default = True default_provider.is_default = True
provider_controller = ToolManager.get_builtin_provider(default_provider.provider, tenant_id) provider_controller = ToolManager.get_builtin_provider(default_provider.provider, tenant_id)
encrypter, cache = BuiltinToolManageService.create_tool_encrypter( encrypter, _ = BuiltinToolManageService.create_tool_encrypter(
tenant_id, default_provider, default_provider.provider, provider_controller tenant_id, default_provider, default_provider.provider, provider_controller
) )
@ -343,20 +369,28 @@ class BuiltinToolManageService:
""" """
delete tool provider delete tool provider
""" """
tool_provider = BuiltinToolManageService.get_builtin_provider_by_id(tenant_id, credential_id) with Session(db.engine) as session:
db_provider = (
session.query(BuiltinToolProvider)
.filter(
BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.id == credential_id,
)
.first()
)
if tool_provider is None: if db_provider is None:
raise ValueError(f"you have not added provider {provider}") raise ValueError(f"you have not added provider {provider}")
db.session.delete(tool_provider) session.delete(db_provider)
db.session.commit() session.commit()
# delete cache # delete cache
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id) provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
_, cache = BuiltinToolManageService.create_tool_encrypter( _, cache = BuiltinToolManageService.create_tool_encrypter(
tenant_id, tool_provider, provider, provider_controller tenant_id, db_provider, provider, provider_controller
) )
cache.delete() cache.delete()
return {"result": "success"} return {"result": "success"}
@ -507,18 +541,6 @@ class BuiltinToolManageService:
return BuiltinToolProviderSort.sort(result) return BuiltinToolProviderSort.sort(result)
@staticmethod
def get_builtin_provider_by_id(tenant_id: str, credential_id: str) -> Optional[BuiltinToolProvider]:
provider: Optional[BuiltinToolProvider] = (
db.session.query(BuiltinToolProvider)
.filter(
BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.id == credential_id,
)
.first()
)
return provider
@staticmethod @staticmethod
def get_builtin_provider(provider_name: str, tenant_id: str) -> Optional[BuiltinToolProvider]: def get_builtin_provider(provider_name: str, tenant_id: str) -> Optional[BuiltinToolProvider]:
""" """

Loading…
Cancel
Save