feat(oauth): refactor session management in tool provider operations

feat/tool-plugin-oauth
Harry 11 months ago
parent ef330fec2c
commit f35b8d6245

@ -7,6 +7,7 @@ from typing import Any, Optional
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from configs import dify_config from configs import dify_config
from constants import HIDDEN_VALUE
from core.helper.position_helper import is_filtered from core.helper.position_helper import is_filtered
from core.helper.provider_cache import NoOpProviderCredentialCache, ToolProviderCredentialsCache from core.helper.provider_cache import NoOpProviderCredentialCache, ToolProviderCredentialsCache
from core.plugin.entities.plugin import ToolProviderID from core.plugin.entities.plugin import ToolProviderID
@ -114,9 +115,16 @@ class BuiltinToolManageService:
""" """
update builtin tool provider update builtin tool provider
""" """
with Session(db.engine) as session:
# get if the provider exists # get if the provider exists
db_provider = BuiltinToolManageService.get_builtin_provider_by_id(tenant_id, credential_id) db_provider = (
session.query(BuiltinToolProvider)
.filter(
BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.id == credential_id,
)
.first()
)
if db_provider is None: if db_provider is None:
raise ValueError(f"you have not added provider {provider}") raise ValueError(f"you have not added provider {provider}")
@ -130,35 +138,41 @@ class BuiltinToolManageService:
tenant_id, db_provider, provider, provider_controller tenant_id, db_provider, provider, provider_controller
) )
# Decrypt and restore original credentials for masked values
original_credentials = encrypter.decrypt(db_provider.credentials) original_credentials = encrypter.decrypt(db_provider.credentials)
masked_credentials = encrypter.mask_tool_credentials(original_credentials) new_credentials: dict = {
key: value if value != HIDDEN_VALUE else original_credentials.get(key, HIDDEN_VALUE)
# check if the credential has changed, save the original credential for key, value in credentials.items()
for key, value in credentials.items(): }
if key in masked_credentials and value == masked_credentials[key]:
credentials[key] = original_credentials[key]
if CredentialType.of(db_provider.credential_type).is_validate_allowed(): if CredentialType.of(db_provider.credential_type).is_validate_allowed():
provider_controller.validate_credentials(user_id, credentials) provider_controller.validate_credentials(user_id, new_credentials)
# encrypt credentials # encrypt credentials
db_provider.encrypted_credentials = json.dumps(encrypter.encrypt(credentials)) db_provider.encrypted_credentials = json.dumps(encrypter.encrypt(new_credentials))
cache.delete() cache.delete()
# update name if provided # update name if provided
if name is not None and db_provider.name != name: if name is not None and db_provider.name != name:
# check if the name is already used
if (
session.query(BuiltinToolProvider)
.filter_by(tenant_id=tenant_id, provider=provider, name=name)
.count()
> 0
):
raise ValueError(f"the credential name '{name}' is already used")
db_provider.name = name db_provider.name = name
db.session.commit() session.commit()
except ( except (
PluginDaemonClientSideError, PluginDaemonClientSideError,
ToolProviderNotFoundError, ToolProviderNotFoundError,
ToolNotFoundError, ToolNotFoundError,
ToolProviderCredentialValidationError, ToolProviderCredentialValidationError,
) as e: ) as e:
db.session.rollback() session.rollback()
raise ValueError(str(e)) raise ValueError(str(e))
return {"result": "success"} return {"result": "success"}
@ -175,15 +189,16 @@ class BuiltinToolManageService:
""" """
add builtin tool provider add builtin tool provider
""" """
lock = f"builtin_tool_provider_create_lock:{tenant_id}_{provider}"
try: try:
with Session(db.engine) as session:
lock = f"builtin_tool_provider_create_lock:{tenant_id}_{provider}"
with redis_client.lock(lock, timeout=20): with redis_client.lock(lock, timeout=20):
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id) provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
if not provider_controller.need_credentials: if not provider_controller.need_credentials:
raise ValueError(f"provider {provider} does not need credentials") raise ValueError(f"provider {provider} does not need credentials")
provider_count = ( provider_count = (
db.session.query(BuiltinToolProvider).filter_by(tenant_id=tenant_id, provider=provider).count() session.query(BuiltinToolProvider).filter_by(tenant_id=tenant_id, provider=provider).count()
) )
# check if the provider count is reached the limit # check if the provider count is reached the limit
@ -197,8 +212,17 @@ class BuiltinToolManageService:
# generate name if not provided # generate name if not provided
if name is None: if name is None:
name = BuiltinToolManageService.generate_builtin_tool_provider_name( name = BuiltinToolManageService.generate_builtin_tool_provider_name(
tenant_id=tenant_id, provider=provider, credential_type=api_type session=session, tenant_id=tenant_id, provider=provider, credential_type=api_type
) )
else:
# check if the name is already used
if (
session.query(BuiltinToolProvider)
.filter_by(tenant_id=tenant_id, provider=provider, name=name)
.count()
> 0
):
raise ValueError(f"the credential name '{name}' is already used")
# create encrypter # create encrypter
encrypter, _ = create_provider_encrypter( encrypter, _ = create_provider_encrypter(
@ -219,15 +243,15 @@ class BuiltinToolManageService:
name=name, name=name,
) )
db.session.add(db_provider) session.add(db_provider)
db.session.commit() session.commit()
except ( except (
PluginDaemonClientSideError, PluginDaemonClientSideError,
ToolProviderNotFoundError, ToolProviderNotFoundError,
ToolNotFoundError, ToolNotFoundError,
ToolProviderCredentialValidationError, ToolProviderCredentialValidationError,
) as e: ) as e:
db.session.rollback() session.rollback()
raise ValueError(str(e)) raise ValueError(str(e))
return {"result": "success"} return {"result": "success"}
@ -249,10 +273,12 @@ class BuiltinToolManageService:
return encrypter, cache return encrypter, cache
@staticmethod @staticmethod
def generate_builtin_tool_provider_name(tenant_id: str, provider: str, credential_type: CredentialType) -> str: def generate_builtin_tool_provider_name(
session: Session, tenant_id: str, provider: str, credential_type: CredentialType
) -> str:
try: try:
db_providers = ( db_providers = (
db.session.query(BuiltinToolProvider) session.query(BuiltinToolProvider)
.filter_by( .filter_by(
tenant_id=tenant_id, tenant_id=tenant_id,
provider=provider, provider=provider,
@ -308,7 +334,7 @@ class BuiltinToolManageService:
default_provider = providers[0] default_provider = providers[0]
default_provider.is_default = True default_provider.is_default = True
provider_controller = ToolManager.get_builtin_provider(default_provider.provider, tenant_id) provider_controller = ToolManager.get_builtin_provider(default_provider.provider, tenant_id)
encrypter, cache = BuiltinToolManageService.create_tool_encrypter( encrypter, _ = BuiltinToolManageService.create_tool_encrypter(
tenant_id, default_provider, default_provider.provider, provider_controller tenant_id, default_provider, default_provider.provider, provider_controller
) )
@ -343,18 +369,26 @@ class BuiltinToolManageService:
""" """
delete tool provider delete tool provider
""" """
tool_provider = BuiltinToolManageService.get_builtin_provider_by_id(tenant_id, credential_id) with Session(db.engine) as session:
db_provider = (
session.query(BuiltinToolProvider)
.filter(
BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.id == credential_id,
)
.first()
)
if tool_provider is None: if db_provider is None:
raise ValueError(f"you have not added provider {provider}") raise ValueError(f"you have not added provider {provider}")
db.session.delete(tool_provider) session.delete(db_provider)
db.session.commit() session.commit()
# delete cache # delete cache
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id) provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
_, cache = BuiltinToolManageService.create_tool_encrypter( _, cache = BuiltinToolManageService.create_tool_encrypter(
tenant_id, tool_provider, provider, provider_controller tenant_id, db_provider, provider, provider_controller
) )
cache.delete() cache.delete()
@ -507,18 +541,6 @@ class BuiltinToolManageService:
return BuiltinToolProviderSort.sort(result) return BuiltinToolProviderSort.sort(result)
@staticmethod
def get_builtin_provider_by_id(tenant_id: str, credential_id: str) -> Optional[BuiltinToolProvider]:
provider: Optional[BuiltinToolProvider] = (
db.session.query(BuiltinToolProvider)
.filter(
BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.id == credential_id,
)
.first()
)
return provider
@staticmethod @staticmethod
def get_builtin_provider(provider_name: str, tenant_id: str) -> Optional[BuiltinToolProvider]: def get_builtin_provider(provider_name: str, tenant_id: str) -> Optional[BuiltinToolProvider]:
""" """

Loading…
Cancel
Save