feat(oauth): refactor session management in tool provider operations

feat/tool-plugin-oauth
Harry 10 months ago
parent ef330fec2c
commit f35b8d6245

@ -7,6 +7,7 @@ from typing import Any, Optional
from sqlalchemy.orm import Session
from configs import dify_config
from constants import HIDDEN_VALUE
from core.helper.position_helper import is_filtered
from core.helper.provider_cache import NoOpProviderCredentialCache, ToolProviderCredentialsCache
from core.plugin.entities.plugin import ToolProviderID
@ -114,9 +115,16 @@ class BuiltinToolManageService:
"""
update builtin tool provider
"""
with Session(db.engine) as session:
# get if the provider exists
db_provider = BuiltinToolManageService.get_builtin_provider_by_id(tenant_id, credential_id)
db_provider = (
session.query(BuiltinToolProvider)
.filter(
BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.id == credential_id,
)
.first()
)
if db_provider is None:
raise ValueError(f"you have not added provider {provider}")
@ -130,35 +138,41 @@ class BuiltinToolManageService:
tenant_id, db_provider, provider, provider_controller
)
# Decrypt and restore original credentials for masked values
original_credentials = encrypter.decrypt(db_provider.credentials)
masked_credentials = encrypter.mask_tool_credentials(original_credentials)
# check if the credential has changed, save the original credential
for key, value in credentials.items():
if key in masked_credentials and value == masked_credentials[key]:
credentials[key] = original_credentials[key]
new_credentials: dict = {
key: value if value != HIDDEN_VALUE else original_credentials.get(key, HIDDEN_VALUE)
for key, value in credentials.items()
}
if CredentialType.of(db_provider.credential_type).is_validate_allowed():
provider_controller.validate_credentials(user_id, credentials)
provider_controller.validate_credentials(user_id, new_credentials)
# encrypt credentials
db_provider.encrypted_credentials = json.dumps(encrypter.encrypt(credentials))
db_provider.encrypted_credentials = json.dumps(encrypter.encrypt(new_credentials))
cache.delete()
# update name if provided
if name is not None and db_provider.name != name:
# check if the name is already used
if (
session.query(BuiltinToolProvider)
.filter_by(tenant_id=tenant_id, provider=provider, name=name)
.count()
> 0
):
raise ValueError(f"the credential name '{name}' is already used")
db_provider.name = name
db.session.commit()
session.commit()
except (
PluginDaemonClientSideError,
ToolProviderNotFoundError,
ToolNotFoundError,
ToolProviderCredentialValidationError,
) as e:
db.session.rollback()
session.rollback()
raise ValueError(str(e))
return {"result": "success"}
@ -175,15 +189,16 @@ class BuiltinToolManageService:
"""
add builtin tool provider
"""
lock = f"builtin_tool_provider_create_lock:{tenant_id}_{provider}"
try:
with Session(db.engine) as session:
lock = f"builtin_tool_provider_create_lock:{tenant_id}_{provider}"
with redis_client.lock(lock, timeout=20):
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
if not provider_controller.need_credentials:
raise ValueError(f"provider {provider} does not need credentials")
provider_count = (
db.session.query(BuiltinToolProvider).filter_by(tenant_id=tenant_id, provider=provider).count()
session.query(BuiltinToolProvider).filter_by(tenant_id=tenant_id, provider=provider).count()
)
# check if the provider count is reached the limit
@ -197,8 +212,17 @@ class BuiltinToolManageService:
# generate name if not provided
if name is None:
name = BuiltinToolManageService.generate_builtin_tool_provider_name(
tenant_id=tenant_id, provider=provider, credential_type=api_type
session=session, tenant_id=tenant_id, provider=provider, credential_type=api_type
)
else:
# check if the name is already used
if (
session.query(BuiltinToolProvider)
.filter_by(tenant_id=tenant_id, provider=provider, name=name)
.count()
> 0
):
raise ValueError(f"the credential name '{name}' is already used")
# create encrypter
encrypter, _ = create_provider_encrypter(
@ -219,15 +243,15 @@ class BuiltinToolManageService:
name=name,
)
db.session.add(db_provider)
db.session.commit()
session.add(db_provider)
session.commit()
except (
PluginDaemonClientSideError,
ToolProviderNotFoundError,
ToolNotFoundError,
ToolProviderCredentialValidationError,
) as e:
db.session.rollback()
session.rollback()
raise ValueError(str(e))
return {"result": "success"}
@ -249,10 +273,12 @@ class BuiltinToolManageService:
return encrypter, cache
@staticmethod
def generate_builtin_tool_provider_name(tenant_id: str, provider: str, credential_type: CredentialType) -> str:
def generate_builtin_tool_provider_name(
session: Session, tenant_id: str, provider: str, credential_type: CredentialType
) -> str:
try:
db_providers = (
db.session.query(BuiltinToolProvider)
session.query(BuiltinToolProvider)
.filter_by(
tenant_id=tenant_id,
provider=provider,
@ -308,7 +334,7 @@ class BuiltinToolManageService:
default_provider = providers[0]
default_provider.is_default = True
provider_controller = ToolManager.get_builtin_provider(default_provider.provider, tenant_id)
encrypter, cache = BuiltinToolManageService.create_tool_encrypter(
encrypter, _ = BuiltinToolManageService.create_tool_encrypter(
tenant_id, default_provider, default_provider.provider, provider_controller
)
@ -343,18 +369,26 @@ class BuiltinToolManageService:
"""
delete tool provider
"""
tool_provider = BuiltinToolManageService.get_builtin_provider_by_id(tenant_id, credential_id)
with Session(db.engine) as session:
db_provider = (
session.query(BuiltinToolProvider)
.filter(
BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.id == credential_id,
)
.first()
)
if tool_provider is None:
if db_provider is None:
raise ValueError(f"you have not added provider {provider}")
db.session.delete(tool_provider)
db.session.commit()
session.delete(db_provider)
session.commit()
# delete cache
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
_, cache = BuiltinToolManageService.create_tool_encrypter(
tenant_id, tool_provider, provider, provider_controller
tenant_id, db_provider, provider, provider_controller
)
cache.delete()
@ -507,18 +541,6 @@ class BuiltinToolManageService:
return BuiltinToolProviderSort.sort(result)
@staticmethod
def get_builtin_provider_by_id(tenant_id: str, credential_id: str) -> Optional[BuiltinToolProvider]:
provider: Optional[BuiltinToolProvider] = (
db.session.query(BuiltinToolProvider)
.filter(
BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.id == credential_id,
)
.first()
)
return provider
@staticmethod
def get_builtin_provider(provider_name: str, tenant_id: str) -> Optional[BuiltinToolProvider]:
"""

Loading…
Cancel
Save