|
|
|
|
@ -7,6 +7,7 @@ from typing import Any, Optional
|
|
|
|
|
from sqlalchemy.orm import Session
|
|
|
|
|
|
|
|
|
|
from configs import dify_config
|
|
|
|
|
from constants import HIDDEN_VALUE
|
|
|
|
|
from core.helper.position_helper import is_filtered
|
|
|
|
|
from core.helper.provider_cache import NoOpProviderCredentialCache, ToolProviderCredentialsCache
|
|
|
|
|
from core.plugin.entities.plugin import ToolProviderID
|
|
|
|
|
@ -114,52 +115,65 @@ class BuiltinToolManageService:
|
|
|
|
|
"""
|
|
|
|
|
update builtin tool provider
|
|
|
|
|
"""
|
|
|
|
|
# get if the provider exists
|
|
|
|
|
db_provider = BuiltinToolManageService.get_builtin_provider_by_id(tenant_id, credential_id)
|
|
|
|
|
|
|
|
|
|
if db_provider is None:
|
|
|
|
|
raise ValueError(f"you have not added provider {provider}")
|
|
|
|
|
with Session(db.engine) as session:
|
|
|
|
|
# get if the provider exists
|
|
|
|
|
db_provider = (
|
|
|
|
|
session.query(BuiltinToolProvider)
|
|
|
|
|
.filter(
|
|
|
|
|
BuiltinToolProvider.tenant_id == tenant_id,
|
|
|
|
|
BuiltinToolProvider.id == credential_id,
|
|
|
|
|
)
|
|
|
|
|
.first()
|
|
|
|
|
)
|
|
|
|
|
if db_provider is None:
|
|
|
|
|
raise ValueError(f"you have not added provider {provider}")
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
if CredentialType.of(db_provider.credential_type).is_editable():
|
|
|
|
|
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
|
|
|
|
|
if not provider_controller.need_credentials:
|
|
|
|
|
raise ValueError(f"provider {provider} does not need credentials")
|
|
|
|
|
try:
|
|
|
|
|
if CredentialType.of(db_provider.credential_type).is_editable():
|
|
|
|
|
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
|
|
|
|
|
if not provider_controller.need_credentials:
|
|
|
|
|
raise ValueError(f"provider {provider} does not need credentials")
|
|
|
|
|
|
|
|
|
|
encrypter, cache = BuiltinToolManageService.create_tool_encrypter(
|
|
|
|
|
tenant_id, db_provider, provider, provider_controller
|
|
|
|
|
)
|
|
|
|
|
encrypter, cache = BuiltinToolManageService.create_tool_encrypter(
|
|
|
|
|
tenant_id, db_provider, provider, provider_controller
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# Decrypt and restore original credentials for masked values
|
|
|
|
|
original_credentials = encrypter.decrypt(db_provider.credentials)
|
|
|
|
|
masked_credentials = encrypter.mask_tool_credentials(original_credentials)
|
|
|
|
|
original_credentials = encrypter.decrypt(db_provider.credentials)
|
|
|
|
|
new_credentials: dict = {
|
|
|
|
|
key: value if value != HIDDEN_VALUE else original_credentials.get(key, HIDDEN_VALUE)
|
|
|
|
|
for key, value in credentials.items()
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
# check if the credential has changed, save the original credential
|
|
|
|
|
for key, value in credentials.items():
|
|
|
|
|
if key in masked_credentials and value == masked_credentials[key]:
|
|
|
|
|
credentials[key] = original_credentials[key]
|
|
|
|
|
if CredentialType.of(db_provider.credential_type).is_validate_allowed():
|
|
|
|
|
provider_controller.validate_credentials(user_id, new_credentials)
|
|
|
|
|
|
|
|
|
|
if CredentialType.of(db_provider.credential_type).is_validate_allowed():
|
|
|
|
|
provider_controller.validate_credentials(user_id, credentials)
|
|
|
|
|
# encrypt credentials
|
|
|
|
|
db_provider.encrypted_credentials = json.dumps(encrypter.encrypt(new_credentials))
|
|
|
|
|
|
|
|
|
|
# encrypt credentials
|
|
|
|
|
db_provider.encrypted_credentials = json.dumps(encrypter.encrypt(credentials))
|
|
|
|
|
cache.delete()
|
|
|
|
|
|
|
|
|
|
cache.delete()
|
|
|
|
|
# update name if provided
|
|
|
|
|
if name is not None and db_provider.name != name:
|
|
|
|
|
# check if the name is already used
|
|
|
|
|
if (
|
|
|
|
|
session.query(BuiltinToolProvider)
|
|
|
|
|
.filter_by(tenant_id=tenant_id, provider=provider, name=name)
|
|
|
|
|
.count()
|
|
|
|
|
> 0
|
|
|
|
|
):
|
|
|
|
|
raise ValueError(f"the credential name '{name}' is already used")
|
|
|
|
|
|
|
|
|
|
# update name if provided
|
|
|
|
|
if name is not None and db_provider.name != name:
|
|
|
|
|
db_provider.name = name
|
|
|
|
|
db_provider.name = name
|
|
|
|
|
|
|
|
|
|
db.session.commit()
|
|
|
|
|
except (
|
|
|
|
|
PluginDaemonClientSideError,
|
|
|
|
|
ToolProviderNotFoundError,
|
|
|
|
|
ToolNotFoundError,
|
|
|
|
|
ToolProviderCredentialValidationError,
|
|
|
|
|
) as e:
|
|
|
|
|
db.session.rollback()
|
|
|
|
|
raise ValueError(str(e))
|
|
|
|
|
session.commit()
|
|
|
|
|
except (
|
|
|
|
|
PluginDaemonClientSideError,
|
|
|
|
|
ToolProviderNotFoundError,
|
|
|
|
|
ToolNotFoundError,
|
|
|
|
|
ToolProviderCredentialValidationError,
|
|
|
|
|
) as e:
|
|
|
|
|
session.rollback()
|
|
|
|
|
raise ValueError(str(e))
|
|
|
|
|
|
|
|
|
|
return {"result": "success"}
|
|
|
|
|
|
|
|
|
|
@ -175,59 +189,69 @@ class BuiltinToolManageService:
|
|
|
|
|
"""
|
|
|
|
|
add builtin tool provider
|
|
|
|
|
"""
|
|
|
|
|
lock = f"builtin_tool_provider_create_lock:{tenant_id}_{provider}"
|
|
|
|
|
try:
|
|
|
|
|
with redis_client.lock(lock, timeout=20):
|
|
|
|
|
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
|
|
|
|
|
if not provider_controller.need_credentials:
|
|
|
|
|
raise ValueError(f"provider {provider} does not need credentials")
|
|
|
|
|
|
|
|
|
|
provider_count = (
|
|
|
|
|
db.session.query(BuiltinToolProvider).filter_by(tenant_id=tenant_id, provider=provider).count()
|
|
|
|
|
)
|
|
|
|
|
with Session(db.engine) as session:
|
|
|
|
|
lock = f"builtin_tool_provider_create_lock:{tenant_id}_{provider}"
|
|
|
|
|
with redis_client.lock(lock, timeout=20):
|
|
|
|
|
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
|
|
|
|
|
if not provider_controller.need_credentials:
|
|
|
|
|
raise ValueError(f"provider {provider} does not need credentials")
|
|
|
|
|
|
|
|
|
|
provider_count = (
|
|
|
|
|
session.query(BuiltinToolProvider).filter_by(tenant_id=tenant_id, provider=provider).count()
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# check if the provider count is reached the limit
|
|
|
|
|
if provider_count >= BuiltinToolManageService.__MAX_BUILTIN_TOOL_PROVIDER_COUNT__:
|
|
|
|
|
raise ValueError(f"you have reached the maximum number of providers for {provider}")
|
|
|
|
|
# check if the provider count is reached the limit
|
|
|
|
|
if provider_count >= BuiltinToolManageService.__MAX_BUILTIN_TOOL_PROVIDER_COUNT__:
|
|
|
|
|
raise ValueError(f"you have reached the maximum number of providers for {provider}")
|
|
|
|
|
|
|
|
|
|
# validate credentials if allowed
|
|
|
|
|
if CredentialType.of(api_type).is_validate_allowed():
|
|
|
|
|
provider_controller.validate_credentials(user_id, credentials)
|
|
|
|
|
# validate credentials if allowed
|
|
|
|
|
if CredentialType.of(api_type).is_validate_allowed():
|
|
|
|
|
provider_controller.validate_credentials(user_id, credentials)
|
|
|
|
|
|
|
|
|
|
# generate name if not provided
|
|
|
|
|
if name is None:
|
|
|
|
|
name = BuiltinToolManageService.generate_builtin_tool_provider_name(
|
|
|
|
|
tenant_id=tenant_id, provider=provider, credential_type=api_type
|
|
|
|
|
# generate name if not provided
|
|
|
|
|
if name is None:
|
|
|
|
|
name = BuiltinToolManageService.generate_builtin_tool_provider_name(
|
|
|
|
|
session=session, tenant_id=tenant_id, provider=provider, credential_type=api_type
|
|
|
|
|
)
|
|
|
|
|
else:
|
|
|
|
|
# check if the name is already used
|
|
|
|
|
if (
|
|
|
|
|
session.query(BuiltinToolProvider)
|
|
|
|
|
.filter_by(tenant_id=tenant_id, provider=provider, name=name)
|
|
|
|
|
.count()
|
|
|
|
|
> 0
|
|
|
|
|
):
|
|
|
|
|
raise ValueError(f"the credential name '{name}' is already used")
|
|
|
|
|
|
|
|
|
|
# create encrypter
|
|
|
|
|
encrypter, _ = create_provider_encrypter(
|
|
|
|
|
tenant_id=tenant_id,
|
|
|
|
|
config=[
|
|
|
|
|
x.to_basic_provider_config()
|
|
|
|
|
for x in provider_controller.get_credentials_schema_by_type(api_type)
|
|
|
|
|
],
|
|
|
|
|
cache=NoOpProviderCredentialCache(),
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# create encrypter
|
|
|
|
|
encrypter, _ = create_provider_encrypter(
|
|
|
|
|
tenant_id=tenant_id,
|
|
|
|
|
config=[
|
|
|
|
|
x.to_basic_provider_config()
|
|
|
|
|
for x in provider_controller.get_credentials_schema_by_type(api_type)
|
|
|
|
|
],
|
|
|
|
|
cache=NoOpProviderCredentialCache(),
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
db_provider = BuiltinToolProvider(
|
|
|
|
|
tenant_id=tenant_id,
|
|
|
|
|
user_id=user_id,
|
|
|
|
|
provider=provider,
|
|
|
|
|
encrypted_credentials=json.dumps(encrypter.encrypt(credentials)),
|
|
|
|
|
credential_type=api_type.value,
|
|
|
|
|
name=name,
|
|
|
|
|
)
|
|
|
|
|
db_provider = BuiltinToolProvider(
|
|
|
|
|
tenant_id=tenant_id,
|
|
|
|
|
user_id=user_id,
|
|
|
|
|
provider=provider,
|
|
|
|
|
encrypted_credentials=json.dumps(encrypter.encrypt(credentials)),
|
|
|
|
|
credential_type=api_type.value,
|
|
|
|
|
name=name,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
db.session.add(db_provider)
|
|
|
|
|
db.session.commit()
|
|
|
|
|
session.add(db_provider)
|
|
|
|
|
session.commit()
|
|
|
|
|
except (
|
|
|
|
|
PluginDaemonClientSideError,
|
|
|
|
|
ToolProviderNotFoundError,
|
|
|
|
|
ToolNotFoundError,
|
|
|
|
|
ToolProviderCredentialValidationError,
|
|
|
|
|
) as e:
|
|
|
|
|
db.session.rollback()
|
|
|
|
|
session.rollback()
|
|
|
|
|
raise ValueError(str(e))
|
|
|
|
|
return {"result": "success"}
|
|
|
|
|
|
|
|
|
|
@ -249,10 +273,12 @@ class BuiltinToolManageService:
|
|
|
|
|
return encrypter, cache
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def generate_builtin_tool_provider_name(tenant_id: str, provider: str, credential_type: CredentialType) -> str:
|
|
|
|
|
def generate_builtin_tool_provider_name(
|
|
|
|
|
session: Session, tenant_id: str, provider: str, credential_type: CredentialType
|
|
|
|
|
) -> str:
|
|
|
|
|
try:
|
|
|
|
|
db_providers = (
|
|
|
|
|
db.session.query(BuiltinToolProvider)
|
|
|
|
|
session.query(BuiltinToolProvider)
|
|
|
|
|
.filter_by(
|
|
|
|
|
tenant_id=tenant_id,
|
|
|
|
|
provider=provider,
|
|
|
|
|
@ -308,7 +334,7 @@ class BuiltinToolManageService:
|
|
|
|
|
default_provider = providers[0]
|
|
|
|
|
default_provider.is_default = True
|
|
|
|
|
provider_controller = ToolManager.get_builtin_provider(default_provider.provider, tenant_id)
|
|
|
|
|
encrypter, cache = BuiltinToolManageService.create_tool_encrypter(
|
|
|
|
|
encrypter, _ = BuiltinToolManageService.create_tool_encrypter(
|
|
|
|
|
tenant_id, default_provider, default_provider.provider, provider_controller
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
@ -343,20 +369,28 @@ class BuiltinToolManageService:
|
|
|
|
|
"""
|
|
|
|
|
delete tool provider
|
|
|
|
|
"""
|
|
|
|
|
tool_provider = BuiltinToolManageService.get_builtin_provider_by_id(tenant_id, credential_id)
|
|
|
|
|
with Session(db.engine) as session:
|
|
|
|
|
db_provider = (
|
|
|
|
|
session.query(BuiltinToolProvider)
|
|
|
|
|
.filter(
|
|
|
|
|
BuiltinToolProvider.tenant_id == tenant_id,
|
|
|
|
|
BuiltinToolProvider.id == credential_id,
|
|
|
|
|
)
|
|
|
|
|
.first()
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
if tool_provider is None:
|
|
|
|
|
raise ValueError(f"you have not added provider {provider}")
|
|
|
|
|
if db_provider is None:
|
|
|
|
|
raise ValueError(f"you have not added provider {provider}")
|
|
|
|
|
|
|
|
|
|
db.session.delete(tool_provider)
|
|
|
|
|
db.session.commit()
|
|
|
|
|
session.delete(db_provider)
|
|
|
|
|
session.commit()
|
|
|
|
|
|
|
|
|
|
# delete cache
|
|
|
|
|
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
|
|
|
|
|
_, cache = BuiltinToolManageService.create_tool_encrypter(
|
|
|
|
|
tenant_id, tool_provider, provider, provider_controller
|
|
|
|
|
)
|
|
|
|
|
cache.delete()
|
|
|
|
|
# delete cache
|
|
|
|
|
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
|
|
|
|
|
_, cache = BuiltinToolManageService.create_tool_encrypter(
|
|
|
|
|
tenant_id, db_provider, provider, provider_controller
|
|
|
|
|
)
|
|
|
|
|
cache.delete()
|
|
|
|
|
|
|
|
|
|
return {"result": "success"}
|
|
|
|
|
|
|
|
|
|
@ -507,18 +541,6 @@ class BuiltinToolManageService:
|
|
|
|
|
|
|
|
|
|
return BuiltinToolProviderSort.sort(result)
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def get_builtin_provider_by_id(tenant_id: str, credential_id: str) -> Optional[BuiltinToolProvider]:
|
|
|
|
|
provider: Optional[BuiltinToolProvider] = (
|
|
|
|
|
db.session.query(BuiltinToolProvider)
|
|
|
|
|
.filter(
|
|
|
|
|
BuiltinToolProvider.tenant_id == tenant_id,
|
|
|
|
|
BuiltinToolProvider.id == credential_id,
|
|
|
|
|
)
|
|
|
|
|
.first()
|
|
|
|
|
)
|
|
|
|
|
return provider
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def get_builtin_provider(provider_name: str, tenant_id: str) -> Optional[BuiltinToolProvider]:
|
|
|
|
|
"""
|
|
|
|
|
|