feat(oauth): rename ToolProviderCredentialType to CredentialType for consistency

pull/22036/head
Harry 10 months ago
parent 26b46b88c9
commit 9f053f3bbc

@ -19,7 +19,7 @@ from controllers.console.wraps import (
from core.model_runtime.utils.encoders import jsonable_encoder from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.entities.plugin import ToolProviderID from core.plugin.entities.plugin import ToolProviderID
from core.plugin.impl.oauth import OAuthHandler from core.plugin.impl.oauth import OAuthHandler
from core.tools.entities.tool_entities import ToolProviderCredentialType from core.tools.entities.tool_entities import CredentialType
from extensions.ext_database import db from extensions.ext_database import db
from libs.helper import alphanumeric, uuid_value from libs.helper import alphanumeric, uuid_value
from libs.login import login_required from libs.login import login_required
@ -122,7 +122,7 @@ class ToolBuiltinProviderAddApi(Resource):
parser.add_argument("type", type=str, required=True, nullable=False, location="json") parser.add_argument("type", type=str, required=True, nullable=False, location="json")
args = parser.parse_args() args = parser.parse_args()
if args["type"] not in ToolProviderCredentialType.values(): if args["type"] not in CredentialType.values():
raise ValueError(f"Invalid credential type: {args['type']}") raise ValueError(f"Invalid credential type: {args['type']}")
return BuiltinToolManageService.add_builtin_tool_provider( return BuiltinToolManageService.add_builtin_tool_provider(
@ -131,7 +131,7 @@ class ToolBuiltinProviderAddApi(Resource):
provider=provider, provider=provider,
credentials=args["credentials"], credentials=args["credentials"],
name=args["name"], name=args["name"],
api_type=ToolProviderCredentialType.of(args["type"]), api_type=CredentialType.of(args["type"]),
) )
@ -378,7 +378,7 @@ class ToolBuiltinProviderCredentialsSchemaApi(Resource):
return jsonable_encoder( return jsonable_encoder(
BuiltinToolManageService.list_builtin_provider_credentials_schema( BuiltinToolManageService.list_builtin_provider_credentials_schema(
provider, ToolProviderCredentialType.of(credential_type), tenant_id provider, CredentialType.of(credential_type), tenant_id
) )
) )
@ -747,7 +747,7 @@ class ToolOAuthCallback(Resource):
tenant_id=tenant_id, tenant_id=tenant_id,
provider=provider, provider=provider,
credentials=dict(credentials), credentials=dict(credentials),
api_type=ToolProviderCredentialType.OAUTH2, api_type=CredentialType.OAUTH2,
) )
return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth/plugin/{provider}/tool/success") return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth/plugin/{provider}/tool/success")

@ -6,7 +6,7 @@ from pydantic import BaseModel
from core.plugin.entities.plugin import GenericProviderID, ToolProviderID from core.plugin.entities.plugin import GenericProviderID, ToolProviderID
from core.plugin.entities.plugin_daemon import PluginBasicBooleanResponse, PluginToolProviderEntity from core.plugin.entities.plugin_daemon import PluginBasicBooleanResponse, PluginToolProviderEntity
from core.plugin.impl.base import BasePluginClient from core.plugin.impl.base import BasePluginClient
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolProviderCredentialType from core.tools.entities.tool_entities import CredentialType, ToolInvokeMessage, ToolParameter
class PluginToolManager(BasePluginClient): class PluginToolManager(BasePluginClient):
@ -78,7 +78,7 @@ class PluginToolManager(BasePluginClient):
tool_provider: str, tool_provider: str,
tool_name: str, tool_name: str,
credentials: dict[str, Any], credentials: dict[str, Any],
credential_type: ToolProviderCredentialType, credential_type: CredentialType,
tool_parameters: dict[str, Any], tool_parameters: dict[str, Any],
conversation_id: Optional[str] = None, conversation_id: Optional[str] = None,
app_id: Optional[str] = None, app_id: Optional[str] = None,

@ -4,7 +4,7 @@ from openai import BaseModel
from pydantic import Field from pydantic import Field
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from core.tools.entities.tool_entities import ToolInvokeFrom, ToolProviderCredentialType from core.tools.entities.tool_entities import CredentialType, ToolInvokeFrom
class ToolRuntime(BaseModel): class ToolRuntime(BaseModel):
@ -17,7 +17,7 @@ class ToolRuntime(BaseModel):
invoke_from: Optional[InvokeFrom] = None invoke_from: Optional[InvokeFrom] = None
tool_invoke_from: Optional[ToolInvokeFrom] = None tool_invoke_from: Optional[ToolInvokeFrom] = None
credentials: dict[str, Any] = Field(default_factory=dict) credentials: dict[str, Any] = Field(default_factory=dict)
credential_type: Optional[ToolProviderCredentialType] = ToolProviderCredentialType.API_KEY credential_type: Optional[CredentialType] = CredentialType.API_KEY
runtime_parameters: dict[str, Any] = Field(default_factory=dict) runtime_parameters: dict[str, Any] = Field(default_factory=dict)

@ -8,9 +8,9 @@ from core.tools.__base.tool_provider import ToolProviderController
from core.tools.__base.tool_runtime import ToolRuntime from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.builtin_tool.tool import BuiltinTool from core.tools.builtin_tool.tool import BuiltinTool
from core.tools.entities.tool_entities import ( from core.tools.entities.tool_entities import (
CredentialType,
OAuthSchema, OAuthSchema,
ToolEntity, ToolEntity,
ToolProviderCredentialType,
ToolProviderEntity, ToolProviderEntity,
ToolProviderType, ToolProviderType,
) )
@ -111,7 +111,7 @@ class BuiltinToolProviderController(ToolProviderController):
:return: the credentials schema :return: the credentials schema
""" """
return self.get_credentials_schema_by_type(ToolProviderCredentialType.API_KEY.value) return self.get_credentials_schema_by_type(CredentialType.API_KEY.value)
def get_credentials_schema_by_type(self, credential_type: str) -> list[ProviderConfig]: def get_credentials_schema_by_type(self, credential_type: str) -> list[ProviderConfig]:
""" """
@ -120,9 +120,9 @@ class BuiltinToolProviderController(ToolProviderController):
:param credential_type: the type of the credential :param credential_type: the type of the credential
:return: the credentials schema of the provider :return: the credentials schema of the provider
""" """
if credential_type == ToolProviderCredentialType.OAUTH2.value: if credential_type == CredentialType.OAUTH2.value:
return self.entity.oauth_schema.credentials_schema.copy() if self.entity.oauth_schema else [] return self.entity.oauth_schema.credentials_schema.copy() if self.entity.oauth_schema else []
if credential_type == ToolProviderCredentialType.API_KEY.value: if credential_type == CredentialType.API_KEY.value:
return self.entity.credentials_schema.copy() if self.entity.credentials_schema else [] return self.entity.credentials_schema.copy() if self.entity.credentials_schema else []
raise ValueError(f"Invalid credential type: {credential_type}") raise ValueError(f"Invalid credential type: {credential_type}")
@ -140,9 +140,9 @@ class BuiltinToolProviderController(ToolProviderController):
""" """
types = [] types = []
if self.entity.credentials_schema is not None: if self.entity.credentials_schema is not None:
types.append(ToolProviderCredentialType.API_KEY.value) types.append(CredentialType.API_KEY.value)
if self.entity.oauth_schema is not None: if self.entity.oauth_schema is not None:
types.append(ToolProviderCredentialType.OAUTH2.value) types.append(CredentialType.OAUTH2.value)
return types return types
def get_tools(self) -> list[BuiltinTool]: def get_tools(self) -> list[BuiltinTool]:

@ -5,7 +5,7 @@ from pydantic import BaseModel, Field, field_validator
from core.model_runtime.utils.encoders import jsonable_encoder from core.model_runtime.utils.encoders import jsonable_encoder
from core.tools.__base.tool import ToolParameter from core.tools.__base.tool import ToolParameter
from core.tools.entities.common_entities import I18nObject from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolProviderCredentialType, ToolProviderType from core.tools.entities.tool_entities import CredentialType, ToolProviderType
class ToolApiEntity(BaseModel): class ToolApiEntity(BaseModel):
@ -76,7 +76,7 @@ class ToolProviderCredentialApiEntity(BaseModel):
id: str = Field(description="The unique id of the credential") id: str = Field(description="The unique id of the credential")
name: str = Field(description="The name of the credential") name: str = Field(description="The name of the credential")
provider: str = Field(description="The provider of the credential") provider: str = Field(description="The provider of the credential")
credential_type: ToolProviderCredentialType = Field(description="The type of the credential") credential_type: CredentialType = Field(description="The type of the credential")
is_default: bool = Field( is_default: bool = Field(
default=False, description="Whether the credential is the default credential for the provider in the workspace" default=False, description="Whether the credential is the default credential for the provider in the workspace"
) )

@ -445,30 +445,30 @@ class ToolSelector(BaseModel):
return self.model_dump() return self.model_dump()
class ToolProviderCredentialType(enum.StrEnum): class CredentialType(enum.StrEnum):
API_KEY = "api-key" API_KEY = "api-key"
OAUTH2 = "oauth2" OAUTH2 = "oauth2"
def get_name(self): def get_name(self):
if self == ToolProviderCredentialType.API_KEY: if self == CredentialType.API_KEY:
return "API KEY" return "API KEY"
elif self == ToolProviderCredentialType.OAUTH2: elif self == CredentialType.OAUTH2:
return "AUTH" return "AUTH"
else: else:
return self.value.replace("-", " ").upper() return self.value.replace("-", " ").upper()
def is_editable(self): def is_editable(self):
return self == ToolProviderCredentialType.API_KEY return self == CredentialType.API_KEY
def is_validate_allowed(self): def is_validate_allowed(self):
return self == ToolProviderCredentialType.API_KEY return self == CredentialType.API_KEY
@classmethod @classmethod
def values(cls): def values(cls):
return [item.value for item in cls] return [item.value for item in cls]
@classmethod @classmethod
def of(cls, credential_type: str) -> "ToolProviderCredentialType": def of(cls, credential_type: str) -> "CredentialType":
type_name = credential_type.lower() type_name = credential_type.lower()
if type_name == "api-key": if type_name == "api-key":
return cls.API_KEY return cls.API_KEY

@ -37,9 +37,9 @@ from core.tools.entities.api_entities import ToolProviderApiEntity, ToolProvider
from core.tools.entities.common_entities import I18nObject from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ( from core.tools.entities.tool_entities import (
ApiProviderAuthType, ApiProviderAuthType,
CredentialType,
ToolInvokeFrom, ToolInvokeFrom,
ToolParameter, ToolParameter,
ToolProviderCredentialType,
ToolProviderType, ToolProviderType,
) )
from core.tools.errors import ToolProviderNotFoundError from core.tools.errors import ToolProviderNotFoundError
@ -240,7 +240,7 @@ class ToolManager:
runtime=ToolRuntime( runtime=ToolRuntime(
tenant_id=tenant_id, tenant_id=tenant_id,
credentials=encrypter.decrypt(builtin_provider.credentials), credentials=encrypter.decrypt(builtin_provider.credentials),
credential_type=ToolProviderCredentialType.of(builtin_provider.credential_type), credential_type=CredentialType.of(builtin_provider.credential_type),
runtime_parameters={}, runtime_parameters={},
invoke_from=invoke_from, invoke_from=invoke_from,
tool_invoke_from=tool_invoke_from, tool_invoke_from=tool_invoke_from,

@ -19,7 +19,7 @@ from core.tools.entities.api_entities import (
ToolProviderCredentialApiEntity, ToolProviderCredentialApiEntity,
ToolProviderCredentialInfoApiEntity, ToolProviderCredentialInfoApiEntity,
) )
from core.tools.entities.tool_entities import ToolProviderCredentialType from core.tools.entities.tool_entities import CredentialType
from core.tools.errors import ToolNotFoundError, ToolProviderCredentialValidationError, ToolProviderNotFoundError from core.tools.errors import ToolNotFoundError, ToolProviderCredentialValidationError, ToolProviderNotFoundError
from core.tools.plugin_tool.provider import PluginToolProviderController from core.tools.plugin_tool.provider import PluginToolProviderController
from core.tools.tool_label_manager import ToolLabelManager from core.tools.tool_label_manager import ToolLabelManager
@ -96,7 +96,7 @@ class BuiltinToolManageService:
@staticmethod @staticmethod
def list_builtin_provider_credentials_schema( def list_builtin_provider_credentials_schema(
provider_name: str, credential_type: ToolProviderCredentialType, tenant_id: str provider_name: str, credential_type: CredentialType, tenant_id: str
): ):
""" """
list builtin provider credentials schema list builtin provider credentials schema
@ -123,7 +123,7 @@ class BuiltinToolManageService:
raise ValueError(f"you have not added provider {provider}") raise ValueError(f"you have not added provider {provider}")
try: try:
if ToolProviderCredentialType.of(db_provider.credential_type).is_editable(): if CredentialType.of(db_provider.credential_type).is_editable():
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id) provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
if not provider_controller.need_credentials: if not provider_controller.need_credentials:
raise ValueError(f"provider {provider} does not need credentials") raise ValueError(f"provider {provider} does not need credentials")
@ -166,7 +166,7 @@ class BuiltinToolManageService:
@staticmethod @staticmethod
def add_builtin_tool_provider( def add_builtin_tool_provider(
user_id: str, user_id: str,
api_type: ToolProviderCredentialType, api_type: CredentialType,
tenant_id: str, tenant_id: str,
provider: str, provider: str,
credentials: dict, credentials: dict,
@ -237,7 +237,7 @@ class BuiltinToolManageService:
@staticmethod @staticmethod
def generate_builtin_tool_provider_name( def generate_builtin_tool_provider_name(
tenant_id: str, provider: str, credential_type: ToolProviderCredentialType tenant_id: str, provider: str, credential_type: CredentialType
) -> str: ) -> str:
try: try:
db_providers = ( db_providers = (

@ -15,8 +15,8 @@ from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_bundle import ApiToolBundle from core.tools.entities.tool_bundle import ApiToolBundle
from core.tools.entities.tool_entities import ( from core.tools.entities.tool_entities import (
ApiProviderAuthType, ApiProviderAuthType,
CredentialType,
ToolParameter, ToolParameter,
ToolProviderCredentialType,
ToolProviderType, ToolProviderType,
) )
from core.tools.plugin_tool.provider import PluginToolProviderController from core.tools.plugin_tool.provider import PluginToolProviderController
@ -113,9 +113,9 @@ class ToolTransformService:
schema = { schema = {
x.to_basic_provider_config().name: x x.to_basic_provider_config().name: x
for x in provider_controller.get_credentials_schema_by_type( for x in provider_controller.get_credentials_schema_by_type(
ToolProviderCredentialType.of(db_provider.credential_type) CredentialType.of(db_provider.credential_type)
if db_provider if db_provider
else ToolProviderCredentialType.API_KEY else CredentialType.API_KEY
) )
} }
@ -139,7 +139,7 @@ class ToolTransformService:
config=[ config=[
x.to_basic_provider_config() x.to_basic_provider_config()
for x in provider_controller.get_credentials_schema_by_type( for x in provider_controller.get_credentials_schema_by_type(
ToolProviderCredentialType.of(db_provider.credential_type) CredentialType.of(db_provider.credential_type)
) )
], ],
cache=ToolProviderCredentialsCache( cache=ToolProviderCredentialsCache(
@ -329,7 +329,7 @@ class ToolTransformService:
id=provider.id, id=provider.id,
name=provider.name, name=provider.name,
provider=provider.provider, provider=provider.provider,
credential_type=ToolProviderCredentialType.of(provider.credential_type), credential_type=CredentialType.of(provider.credential_type),
is_default=provider.is_default, is_default=provider.is_default,
credentials=credentials, credentials=credentials,
) )

Loading…
Cancel
Save