refactor: credentials schemas to array

pull/9184/head
Yeuoly 2 years ago
parent c9f80b46a1
commit 6dfc31a542
No known key found for this signature in database
GPG Key ID: A66E7E320FB19F61

@ -159,3 +159,6 @@ class ProviderConfig(BasicProviderConfig):
help: Optional[I18nObject] = None help: Optional[I18nObject] = None
url: Optional[str] = None url: Optional[str] = None
placeholder: Optional[I18nObject] = None placeholder: Optional[I18nObject] = None
def to_basic_provider_config(self) -> BasicProviderConfig:
return BasicProviderConfig(type=self.type, name=self.name)

@ -1,6 +1,3 @@
from collections.abc import Mapping
from typing import Any
from core.plugin.entities.request import RequestInvokeEncrypt from core.plugin.entities.request import RequestInvokeEncrypt
from core.tools.utils.configuration import ProviderConfigEncrypter from core.tools.utils.configuration import ProviderConfigEncrypter
from models.account import Tenant from models.account import Tenant
@ -11,7 +8,7 @@ class PluginEncrypter:
def invoke_encrypt(cls, tenant: Tenant, payload: RequestInvokeEncrypt) -> dict: def invoke_encrypt(cls, tenant: Tenant, payload: RequestInvokeEncrypt) -> dict:
encrypter = ProviderConfigEncrypter( encrypter = ProviderConfigEncrypter(
tenant_id=tenant.id, tenant_id=tenant.id,
config=payload.data, config=payload.config,
provider_type=payload.namespace, provider_type=payload.namespace,
provider_identity=payload.identity, provider_identity=payload.identity,
) )

@ -1,4 +1,3 @@
from collections.abc import Mapping
from datetime import datetime from datetime import datetime
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
@ -12,7 +11,7 @@ class EndpointDeclaration(BaseModel):
declaration of an endpoint declaration of an endpoint
""" """
settings: Mapping[str, ProviderConfig] = Field(default_factory=Mapping) settings: list[ProviderConfig] = Field(default_factory=list)
class EndpointEntity(BasePluginEntity): class EndpointEntity(BasePluginEntity):

@ -1,4 +1,3 @@
from collections.abc import Mapping
from typing import Any, Literal, Optional from typing import Any, Literal, Optional
from pydantic import BaseModel, ConfigDict, Field, field_validator from pydantic import BaseModel, ConfigDict, Field, field_validator
@ -181,4 +180,4 @@ class RequestInvokeEncrypt(BaseModel):
namespace: Literal["endpoint"] namespace: Literal["endpoint"]
identity: str identity: str
data: dict = Field(default_factory=dict) data: dict = Field(default_factory=dict)
config: Mapping[str, BasicProviderConfig] = Field(default_factory=Mapping) config: list[BasicProviderConfig] = Field(default_factory=list)

@ -1,4 +1,5 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from copy import deepcopy
from typing import Any from typing import Any
from core.entities.provider_entities import ProviderConfig from core.entities.provider_entities import ProviderConfig
@ -16,13 +17,13 @@ class ToolProviderController(ABC):
def __init__(self, entity: ToolProviderEntity) -> None: def __init__(self, entity: ToolProviderEntity) -> None:
self.entity = entity self.entity = entity
def get_credentials_schema(self) -> dict[str, ProviderConfig]: def get_credentials_schema(self) -> list[ProviderConfig]:
""" """
returns the credentials schema of the provider returns the credentials schema of the provider
:return: the credentials schema :return: the credentials schema
""" """
return self.entity.credentials_schema.copy() return deepcopy(self.entity.credentials_schema)
@abstractmethod @abstractmethod
def get_tool(self, tool_name: str) -> Tool: def get_tool(self, tool_name: str) -> Tool:
@ -48,10 +49,13 @@ class ToolProviderController(ABC):
:param credentials: the credentials of the tool :param credentials: the credentials of the tool
""" """
credentials_schema = self.entity.credentials_schema credentials_schema = dict[str, ProviderConfig]()
if credentials_schema is None: if credentials_schema is None:
return return
for credential in self.entity.credentials_schema:
credentials_schema[credential.name] = credential
credentials_need_to_validate: dict[str, ProviderConfig] = {} credentials_need_to_validate: dict[str, ProviderConfig] = {}
for credential_name in credentials_schema: for credential_name in credentials_schema:
credentials_need_to_validate[credential_name] = credentials_schema[credential_name] credentials_need_to_validate[credential_name] = credentials_schema[credential_name]

@ -34,10 +34,14 @@ class BuiltinToolProviderController(ToolProviderController):
for credential_name in provider_yaml["credentials_for_provider"]: for credential_name in provider_yaml["credentials_for_provider"]:
provider_yaml["credentials_for_provider"][credential_name]["name"] = credential_name provider_yaml["credentials_for_provider"][credential_name]["name"] = credential_name
credentials_schema = []
for credential in provider_yaml.get("credentials_for_provider", {}):
credentials_schema.append(credential)
super().__init__( super().__init__(
entity=ToolProviderEntity( entity=ToolProviderEntity(
identity=provider_yaml["identity"], identity=provider_yaml["identity"],
credentials_schema=provider_yaml.get("credentials_for_provider", {}) or {}, credentials_schema=credentials_schema,
), ),
) )
@ -84,14 +88,14 @@ class BuiltinToolProviderController(ToolProviderController):
self.tools = tools self.tools = tools
return tools return tools
def get_credentials_schema(self) -> dict[str, ProviderConfig]: def get_credentials_schema(self) -> list[ProviderConfig]:
""" """
returns the credentials schema of the provider returns the credentials schema of the provider
:return: the credentials schema :return: the credentials schema
""" """
if not self.entity.credentials_schema: if not self.entity.credentials_schema:
return {} return []
return self.entity.credentials_schema.copy() return self.entity.credentials_schema.copy()

@ -12,4 +12,3 @@ identity:
icon: icon.svg icon: icon.svg
tags: tags:
- productivity - productivity
credentials_for_provider:

@ -12,4 +12,3 @@ identity:
icon: icon.svg icon: icon.svg
tags: tags:
- utilities - utilities
credentials_for_provider:

@ -28,8 +28,8 @@ class ApiToolProviderController(ToolProviderController):
@classmethod @classmethod
def from_db(cls, db_provider: ApiToolProvider, auth_type: ApiProviderAuthType): def from_db(cls, db_provider: ApiToolProvider, auth_type: ApiProviderAuthType):
credentials_schema = { credentials_schema = [
"auth_type": ProviderConfig( ProviderConfig(
name="auth_type", name="auth_type",
required=True, required=True,
type=ProviderConfig.Type.SELECT, type=ProviderConfig.Type.SELECT,
@ -40,24 +40,24 @@ class ApiToolProviderController(ToolProviderController):
default="none", default="none",
help=I18nObject(en_US="The auth type of the api provider", zh_Hans="api provider 的认证类型"), help=I18nObject(en_US="The auth type of the api provider", zh_Hans="api provider 的认证类型"),
) )
} ]
if auth_type == ApiProviderAuthType.API_KEY: if auth_type == ApiProviderAuthType.API_KEY:
credentials_schema = { credentials_schema = [
**credentials_schema, *credentials_schema,
"api_key_header": ProviderConfig( ProviderConfig(
name="api_key_header", name="api_key_header",
required=False, required=False,
default="api_key", default="api_key",
type=ProviderConfig.Type.TEXT_INPUT, type=ProviderConfig.Type.TEXT_INPUT,
help=I18nObject(en_US="The header name of the api key", zh_Hans="携带 api key 的 header 名称"), help=I18nObject(en_US="The header name of the api key", zh_Hans="携带 api key 的 header 名称"),
), ),
"api_key_value": ProviderConfig( ProviderConfig(
name="api_key_value", name="api_key_value",
required=True, required=True,
type=ProviderConfig.Type.SECRET_INPUT, type=ProviderConfig.Type.SECRET_INPUT,
help=I18nObject(en_US="The api key", zh_Hans="api key的值"), help=I18nObject(en_US="The api key", zh_Hans="api key的值"),
), ),
"api_key_header_prefix": ProviderConfig( ProviderConfig(
name="api_key_header_prefix", name="api_key_header_prefix",
required=False, required=False,
default="basic", default="basic",
@ -69,7 +69,7 @@ class ApiToolProviderController(ToolProviderController):
ProviderConfig.Option(value="custom", label=I18nObject(en_US="Custom", zh_Hans="Custom")), ProviderConfig.Option(value="custom", label=I18nObject(en_US="Custom", zh_Hans="Custom")),
], ],
), ),
} ]
elif auth_type == ApiProviderAuthType.NONE: elif auth_type == ApiProviderAuthType.NONE:
pass pass

@ -2,7 +2,6 @@ from typing import Literal, Optional
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from core.entities.provider_entities import ProviderConfig
from core.model_runtime.utils.encoders import jsonable_encoder from core.model_runtime.utils.encoders import jsonable_encoder
from core.tools.__base.tool import ToolParameter from core.tools.__base.tool import ToolParameter
from core.tools.entities.common_entities import I18nObject from core.tools.entities.common_entities import I18nObject
@ -62,7 +61,3 @@ class ToolProviderApiEntity(BaseModel):
"tools": tools, "tools": tools,
"labels": self.labels, "labels": self.labels,
} }
class ToolProviderCredentialsApiEntity(BaseModel):
credentials: dict[str, ProviderConfig]

@ -312,7 +312,7 @@ class ToolEntity(BaseModel):
class ToolProviderEntity(BaseModel): class ToolProviderEntity(BaseModel):
identity: ToolProviderIdentity identity: ToolProviderIdentity
credentials_schema: dict[str, ProviderConfig] = Field(default_factory=dict) credentials_schema: list[ProviderConfig] = Field(default_factory=list)
class ToolProviderEntityWithPlugin(ToolProviderEntity): class ToolProviderEntityWithPlugin(ToolProviderEntity):

@ -160,7 +160,7 @@ class ToolManager:
credentials = builtin_provider.credentials credentials = builtin_provider.credentials
tool_configuration = ProviderConfigEncrypter( tool_configuration = ProviderConfigEncrypter(
tenant_id=tenant_id, tenant_id=tenant_id,
config=provider_controller.get_credentials_schema(), config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
provider_type=provider_controller.provider_type.value, provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.entity.identity.name, provider_identity=provider_controller.entity.identity.name,
) )
@ -186,7 +186,7 @@ class ToolManager:
# decrypt the credentials # decrypt the credentials
tool_configuration = ProviderConfigEncrypter( tool_configuration = ProviderConfigEncrypter(
tenant_id=tenant_id, tenant_id=tenant_id,
config=api_provider.get_credentials_schema(), config=[x.to_basic_provider_config() for x in api_provider.get_credentials_schema()],
provider_type=api_provider.provider_type.value, provider_type=api_provider.provider_type.value,
provider_identity=api_provider.entity.identity.name, provider_identity=api_provider.entity.identity.name,
) )
@ -643,7 +643,7 @@ class ToolManager:
# init tool configuration # init tool configuration
tool_configuration = ProviderConfigEncrypter( tool_configuration = ProviderConfigEncrypter(
tenant_id=tenant_id, tenant_id=tenant_id,
config=controller.get_credentials_schema(), config=[x.to_basic_provider_config() for x in controller.get_credentials_schema()],
provider_type=controller.provider_type.value, provider_type=controller.provider_type.value,
provider_identity=controller.entity.identity.name, provider_identity=controller.entity.identity.name,
) )

@ -1,4 +1,3 @@
from collections.abc import Mapping
from copy import deepcopy from copy import deepcopy
from typing import Any from typing import Any
@ -17,7 +16,7 @@ from core.tools.entities.tool_entities import (
class ProviderConfigEncrypter(BaseModel): class ProviderConfigEncrypter(BaseModel):
tenant_id: str tenant_id: str
config: Mapping[str, BasicProviderConfig] config: list[BasicProviderConfig]
provider_type: str provider_type: str
provider_identity: str provider_identity: str
@ -36,7 +35,10 @@ class ProviderConfigEncrypter(BaseModel):
data = self._deep_copy(data) data = self._deep_copy(data)
# get fields need to be decrypted # get fields need to be decrypted
fields = self.config fields = dict[str, BasicProviderConfig]()
for credential in self.config:
fields[credential.name] = credential
for field_name, field in fields.items(): for field_name, field in fields.items():
if field.type == BasicProviderConfig.Type.SECRET_INPUT: if field.type == BasicProviderConfig.Type.SECRET_INPUT:
if field_name in data: if field_name in data:
@ -54,7 +56,10 @@ class ProviderConfigEncrypter(BaseModel):
data = self._deep_copy(data) data = self._deep_copy(data)
# get fields need to be decrypted # get fields need to be decrypted
fields = self.config fields = dict[str, BasicProviderConfig]()
for credential in self.config:
fields[credential.name] = credential
for field_name, field in fields.items(): for field_name, field in fields.items():
if field.type == BasicProviderConfig.Type.SECRET_INPUT: if field.type == BasicProviderConfig.Type.SECRET_INPUT:
if field_name in data: if field_name in data:
@ -83,7 +88,10 @@ class ProviderConfigEncrypter(BaseModel):
return cached_credentials return cached_credentials
data = self._deep_copy(data) data = self._deep_copy(data)
# get fields need to be decrypted # get fields need to be decrypted
fields = self.config fields = dict[str, BasicProviderConfig]()
for credential in self.config:
fields[credential.name] = credential
for field_name, field in fields.items(): for field_name, field in fields.items():
if field.type == BasicProviderConfig.Type.SECRET_INPUT: if field.type == BasicProviderConfig.Type.SECRET_INPUT:
if field_name in data: if field_name in data:

@ -35,7 +35,7 @@ class BuiltinToolManageService:
tool_provider_configurations = ProviderConfigEncrypter( tool_provider_configurations = ProviderConfigEncrypter(
tenant_id=tenant_id, tenant_id=tenant_id,
config=provider_controller.get_credentials_schema(), config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
provider_type=provider_controller.provider_type.value, provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.entity.identity.name, provider_identity=provider_controller.entity.identity.name,
) )
@ -78,7 +78,7 @@ class BuiltinToolManageService:
:return: the list of tool providers :return: the list of tool providers
""" """
provider = ToolManager.get_builtin_provider(provider_name, tenant_id) provider = ToolManager.get_builtin_provider(provider_name, tenant_id)
return jsonable_encoder([v for _, v in (provider.get_credentials_schema() or {}).items()]) return jsonable_encoder(provider.get_credentials_schema())
@staticmethod @staticmethod
def update_builtin_tool_provider(user_id: str, tenant_id: str, provider_name: str, credentials: dict): def update_builtin_tool_provider(user_id: str, tenant_id: str, provider_name: str, credentials: dict):
@ -102,7 +102,7 @@ class BuiltinToolManageService:
raise ValueError(f"provider {provider_name} does not need credentials") raise ValueError(f"provider {provider_name} does not need credentials")
tool_configuration = ProviderConfigEncrypter( tool_configuration = ProviderConfigEncrypter(
tenant_id=tenant_id, tenant_id=tenant_id,
config=provider_controller.get_credentials_schema(), config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
provider_type=provider_controller.provider_type.value, provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.entity.identity.name, provider_identity=provider_controller.entity.identity.name,
) )
@ -164,7 +164,7 @@ class BuiltinToolManageService:
provider_controller = ToolManager.get_builtin_provider(provider_obj.provider, tenant_id) provider_controller = ToolManager.get_builtin_provider(provider_obj.provider, tenant_id)
tool_configuration = ProviderConfigEncrypter( tool_configuration = ProviderConfigEncrypter(
tenant_id=tenant_id, tenant_id=tenant_id,
config=provider_controller.get_credentials_schema(), config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
provider_type=provider_controller.provider_type.value, provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.entity.identity.name, provider_identity=provider_controller.entity.identity.name,
) )
@ -196,7 +196,7 @@ class BuiltinToolManageService:
provider_controller = ToolManager.get_builtin_provider(provider_name, tenant_id) provider_controller = ToolManager.get_builtin_provider(provider_name, tenant_id)
tool_configuration = ProviderConfigEncrypter( tool_configuration = ProviderConfigEncrypter(
tenant_id=tenant_id, tenant_id=tenant_id,
config=provider_controller.get_credentials_schema(), config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
provider_type=provider_controller.provider_type.value, provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.entity.identity.name, provider_identity=provider_controller.entity.identity.name,
) )

@ -85,7 +85,8 @@ class ToolTransformService:
) )
# get credentials schema # get credentials schema
schema = provider_controller.get_credentials_schema() schema = {x.to_basic_provider_config().name: x for x in provider_controller.get_credentials_schema()}
for name, value in schema.items(): for name, value in schema.items():
if result.masked_credentials: if result.masked_credentials:
result.masked_credentials[name] = "" result.masked_credentials[name] = ""
@ -103,7 +104,7 @@ class ToolTransformService:
# init tool configuration # init tool configuration
tool_configuration = ProviderConfigEncrypter( tool_configuration = ProviderConfigEncrypter(
tenant_id=db_provider.tenant_id, tenant_id=db_provider.tenant_id,
config=provider_controller.get_credentials_schema(), config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
provider_type=provider_controller.provider_type.value, provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.entity.identity.name, provider_identity=provider_controller.entity.identity.name,
) )
@ -208,7 +209,7 @@ class ToolTransformService:
# init tool configuration # init tool configuration
tool_configuration = ProviderConfigEncrypter( tool_configuration = ProviderConfigEncrypter(
tenant_id=db_provider.tenant_id, tenant_id=db_provider.tenant_id,
config=provider_controller.get_credentials_schema(), config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
provider_type=provider_controller.provider_type.value, provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.entity.identity.name, provider_identity=provider_controller.entity.identity.name,
) )

Loading…
Cancel
Save