diff --git a/api/core/datasource/datasource_tool/provider.py b/api/core/datasource/datasource_tool/provider.py index 3104728947..820224eeaa 100644 --- a/api/core/datasource/datasource_tool/provider.py +++ b/api/core/datasource/datasource_tool/provider.py @@ -78,3 +78,68 @@ class DatasourceToolProviderController(BuiltinToolProviderController): ) for datasource_entity in self.entity.datasources ] + + def validate_credentials_format(self, credentials: dict[str, Any]) -> None: + """ + validate the format of the credentials of the provider and set the default value if needed + + :param credentials: the credentials of the tool + """ + credentials_schema = dict[str, ProviderConfig]() + if credentials_schema is None: + return + + for credential in self.entity.credentials_schema: + credentials_schema[credential.name] = credential + + credentials_need_to_validate: dict[str, ProviderConfig] = {} + for credential_name in credentials_schema: + credentials_need_to_validate[credential_name] = credentials_schema[credential_name] + + for credential_name in credentials: + if credential_name not in credentials_need_to_validate: + raise ToolProviderCredentialValidationError( + f"credential {credential_name} not found in provider {self.entity.identity.name}" + ) + + # check type + credential_schema = credentials_need_to_validate[credential_name] + if not credential_schema.required and credentials[credential_name] is None: + continue + + if credential_schema.type in {ProviderConfig.Type.SECRET_INPUT, ProviderConfig.Type.TEXT_INPUT}: + if not isinstance(credentials[credential_name], str): + raise ToolProviderCredentialValidationError(f"credential {credential_name} should be string") + + elif credential_schema.type == ProviderConfig.Type.SELECT: + if not isinstance(credentials[credential_name], str): + raise ToolProviderCredentialValidationError(f"credential {credential_name} should be string") + + options = credential_schema.options + if not isinstance(options, list): + raise ToolProviderCredentialValidationError(f"credential {credential_name} options should be list") + + if credentials[credential_name] not in [x.value for x in options]: + raise ToolProviderCredentialValidationError( + f"credential {credential_name} should be one of {options}" + ) + + credentials_need_to_validate.pop(credential_name) + + for credential_name in credentials_need_to_validate: + credential_schema = credentials_need_to_validate[credential_name] + if credential_schema.required: + raise ToolProviderCredentialValidationError(f"credential {credential_name} is required") + + # the credential is not set currently, set the default value if needed + if credential_schema.default is not None: + default_value = credential_schema.default + # parse default value into the correct type + if credential_schema.type in { + ProviderConfig.Type.SECRET_INPUT, + ProviderConfig.Type.TEXT_INPUT, + ProviderConfig.Type.SELECT, + }: + default_value = str(default_value) + + credentials[credential_name] = default_value \ No newline at end of file diff --git a/api/core/datasource/datasource_tool/tool.py b/api/core/datasource/datasource_tool/tool.py index 1c8572c2c5..d55c28a9b9 100644 --- a/api/core/datasource/datasource_tool/tool.py +++ b/api/core/datasource/datasource_tool/tool.py @@ -1,14 +1,16 @@ from collections.abc import Generator from typing import Any, Optional +from core.datasource.__base.datasource import Datasource from core.datasource.__base.datasource_runtime import DatasourceRuntime -from core.datasource.entities.datasource_entities import DatasourceEntity, DatasourceParameter, DatasourceProviderType +from core.datasource.entities.datasource_entities import ( + DatasourceEntity, + DatasourceInvokeMessage, + DatasourceParameter, + DatasourceProviderType, +) from core.plugin.manager.datasource import PluginDatasourceManager -from core.plugin.manager.tool import PluginToolManager from core.plugin.utils.converter import convert_parameters_to_plugin_format -from core.tools.__base.tool import Tool -from core.tools.__base.tool_runtime import ToolRuntime -from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolParameter, ToolProviderType class DatasourcePlugin(Datasource): @@ -16,11 +18,14 @@ class DatasourcePlugin(Datasource): icon: str plugin_unique_identifier: str runtime_parameters: Optional[list[DatasourceParameter]] + entity: DatasourceEntity + runtime: DatasourceRuntime def __init__( - self, entity: DatasourceEntity, runtime: ToolRuntime, tenant_id: str, icon: str, plugin_unique_identifier: str + self, entity: DatasourceEntity, runtime: DatasourceRuntime, tenant_id: str, icon: str, plugin_unique_identifier: str ) -> None: - super().__init__(entity, runtime) + self.entity = entity + self.runtime = runtime self.tenant_id = tenant_id self.icon = icon self.plugin_unique_identifier = plugin_unique_identifier @@ -34,7 +39,7 @@ class DatasourcePlugin(Datasource): user_id: str, datasource_parameters: dict[str, Any], rag_pipeline_id: Optional[str] = None, - ) -> Generator[ToolInvokeMessage, None, None]: + ) -> Generator[DatasourceInvokeMessage, None, None]: manager = PluginDatasourceManager() datasource_parameters = convert_parameters_to_plugin_format(datasource_parameters) @@ -54,7 +59,7 @@ class DatasourcePlugin(Datasource): user_id: str, datasource_parameters: dict[str, Any], rag_pipeline_id: Optional[str] = None, - ) -> Generator[ToolInvokeMessage, None, None]: + ) -> Generator[DatasourceInvokeMessage, None, None]: manager = PluginDatasourceManager() datasource_parameters = convert_parameters_to_plugin_format(datasource_parameters) diff --git a/api/core/datasource/entities/datasource_entities.py b/api/core/datasource/entities/datasource_entities.py index 39c28c0d7d..de580b270e 100644 --- a/api/core/datasource/entities/datasource_entities.py +++ b/api/core/datasource/entities/datasource_entities.py @@ -105,7 +105,7 @@ class ApiProviderAuthType(Enum): raise ValueError(f"invalid mode value {value}") -class ToolInvokeMessage(BaseModel): +class DatasourceInvokeMessage(BaseModel): class TextMessage(BaseModel): text: str @@ -200,7 +200,7 @@ class ToolInvokeMessage(BaseModel): return v -class ToolInvokeMessageBinary(BaseModel): +class DatasourceInvokeMessageBinary(BaseModel): mimetype: str = Field(..., description="The mimetype of the binary") url: str = Field(..., description="The url of the binary") file_var: Optional[dict[str, Any]] = None