feat(oauth): add multi credentials support

pull/22036/head
Harry 10 months ago
parent b316867bab
commit 26b46b88c9

@ -6,7 +6,7 @@ from pydantic import BaseModel
from core.plugin.entities.plugin import GenericProviderID, ToolProviderID from core.plugin.entities.plugin import GenericProviderID, ToolProviderID
from core.plugin.entities.plugin_daemon import PluginBasicBooleanResponse, PluginToolProviderEntity from core.plugin.entities.plugin_daemon import PluginBasicBooleanResponse, PluginToolProviderEntity
from core.plugin.impl.base import BasePluginClient from core.plugin.impl.base import BasePluginClient
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolProviderCredentialType
class PluginToolManager(BasePluginClient): class PluginToolManager(BasePluginClient):
@ -78,6 +78,7 @@ class PluginToolManager(BasePluginClient):
tool_provider: str, tool_provider: str,
tool_name: str, tool_name: str,
credentials: dict[str, Any], credentials: dict[str, Any],
credential_type: ToolProviderCredentialType,
tool_parameters: dict[str, Any], tool_parameters: dict[str, Any],
conversation_id: Optional[str] = None, conversation_id: Optional[str] = None,
app_id: Optional[str] = None, app_id: Optional[str] = None,
@ -102,6 +103,7 @@ class PluginToolManager(BasePluginClient):
"provider": tool_provider_id.provider_name, "provider": tool_provider_id.provider_name,
"tool": tool_name, "tool": tool_name,
"credentials": credentials, "credentials": credentials,
"credential_type": credential_type,
"tool_parameters": tool_parameters, "tool_parameters": tool_parameters,
}, },
}, },

@ -4,7 +4,7 @@ from openai import BaseModel
from pydantic import Field from pydantic import Field
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from core.tools.entities.tool_entities import ToolInvokeFrom from core.tools.entities.tool_entities import ToolInvokeFrom, ToolProviderCredentialType
class ToolRuntime(BaseModel): class ToolRuntime(BaseModel):
@ -17,6 +17,7 @@ class ToolRuntime(BaseModel):
invoke_from: Optional[InvokeFrom] = None invoke_from: Optional[InvokeFrom] = None
tool_invoke_from: Optional[ToolInvokeFrom] = None tool_invoke_from: Optional[ToolInvokeFrom] = None
credentials: dict[str, Any] = Field(default_factory=dict) credentials: dict[str, Any] = Field(default_factory=dict)
credential_type: Optional[ToolProviderCredentialType] = ToolProviderCredentialType.API_KEY
runtime_parameters: dict[str, Any] = Field(default_factory=dict) runtime_parameters: dict[str, Any] = Field(default_factory=dict)

@ -44,6 +44,7 @@ class PluginTool(Tool):
tool_provider=self.entity.identity.provider, tool_provider=self.entity.identity.provider,
tool_name=self.entity.identity.name, tool_name=self.entity.identity.name,
credentials=self.runtime.credentials, credentials=self.runtime.credentials,
credential_type=self.runtime.credential_type,
tool_parameters=tool_parameters, tool_parameters=tool_parameters,
conversation_id=conversation_id, conversation_id=conversation_id,
app_id=app_id, app_id=app_id,

@ -4,7 +4,7 @@ import mimetypes
from collections.abc import Generator from collections.abc import Generator
from os import listdir, path from os import listdir, path
from threading import Lock from threading import Lock
from typing import TYPE_CHECKING, Any, Union, cast from typing import TYPE_CHECKING, Any, Optional, Union, cast
from yarl import URL from yarl import URL
@ -39,6 +39,7 @@ from core.tools.entities.tool_entities import (
ApiProviderAuthType, ApiProviderAuthType,
ToolInvokeFrom, ToolInvokeFrom,
ToolParameter, ToolParameter,
ToolProviderCredentialType,
ToolProviderType, ToolProviderType,
) )
from core.tools.errors import ToolProviderNotFoundError from core.tools.errors import ToolProviderNotFoundError
@ -148,6 +149,7 @@ class ToolManager:
tenant_id: str, tenant_id: str,
invoke_from: InvokeFrom = InvokeFrom.DEBUGGER, invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT, tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT,
credential_id: Optional[str] = None,
) -> Union[BuiltinTool, PluginTool, ApiTool, WorkflowTool]: ) -> Union[BuiltinTool, PluginTool, ApiTool, WorkflowTool]:
""" """
get the tool runtime get the tool runtime
@ -158,6 +160,7 @@ class ToolManager:
:param tenant_id: the tenant id :param tenant_id: the tenant id
:param invoke_from: invoke from :param invoke_from: invoke from
:param tool_invoke_from: the tool invoke from :param tool_invoke_from: the tool invoke from
:param credential_id: the credential id
:return: the tool :return: the tool
""" """
@ -185,19 +188,31 @@ class ToolManager:
if isinstance(provider_controller, PluginToolProviderController): if isinstance(provider_controller, PluginToolProviderController):
provider_id_entity = ToolProviderID(provider_id) provider_id_entity = ToolProviderID(provider_id)
# get credentials # get credentials
builtin_provider: BuiltinToolProvider | None = ( if credential_id:
db.session.query(BuiltinToolProvider) builtin_provider = (
.filter( db.session.query(BuiltinToolProvider)
BuiltinToolProvider.tenant_id == tenant_id, .filter(
(BuiltinToolProvider.provider == str(provider_id_entity)) BuiltinToolProvider.tenant_id == tenant_id,
| (BuiltinToolProvider.provider == provider_id_entity.provider_name), BuiltinToolProvider.id == credential_id,
)
.first()
)
if builtin_provider is None:
raise ToolProviderNotFoundError(f"builtin provider {credential_id} not found")
else:
builtin_provider = (
db.session.query(BuiltinToolProvider)
.filter(
BuiltinToolProvider.tenant_id == tenant_id,
(BuiltinToolProvider.provider == str(provider_id_entity))
| (BuiltinToolProvider.provider == provider_id_entity.provider_name),
)
.order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc())
.first()
) )
.order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc())
.first()
)
if builtin_provider is None: if builtin_provider is None:
raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found") raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found")
else: else:
builtin_provider = ( builtin_provider = (
db.session.query(BuiltinToolProvider) db.session.query(BuiltinToolProvider)
@ -209,8 +224,6 @@ class ToolManager:
if builtin_provider is None: if builtin_provider is None:
raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found") raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found")
# decrypt the credentials
credentials = builtin_provider.credentials
encrypter, _ = create_encrypter( encrypter, _ = create_encrypter(
tenant_id=tenant_id, tenant_id=tenant_id,
config=[ config=[
@ -221,15 +234,13 @@ class ToolManager:
tenant_id=tenant_id, provider=provider_id, credential_id=builtin_provider.id tenant_id=tenant_id, provider=provider_id, credential_id=builtin_provider.id
), ),
) )
decrypted_credentials = encrypter.decrypt(credentials)
return cast( return cast(
BuiltinTool, BuiltinTool,
builtin_tool.fork_tool_runtime( builtin_tool.fork_tool_runtime(
runtime=ToolRuntime( runtime=ToolRuntime(
tenant_id=tenant_id, tenant_id=tenant_id,
credentials=decrypted_credentials, credentials=encrypter.decrypt(builtin_provider.credentials),
credential_type=ToolProviderCredentialType.of(builtin_provider.credential_type),
runtime_parameters={}, runtime_parameters={},
invoke_from=invoke_from, invoke_from=invoke_from,
tool_invoke_from=tool_invoke_from, tool_invoke_from=tool_invoke_from,
@ -362,6 +373,7 @@ class ToolManager:
tenant_id=tenant_id, tenant_id=tenant_id,
invoke_from=invoke_from, invoke_from=invoke_from,
tool_invoke_from=ToolInvokeFrom.WORKFLOW, tool_invoke_from=ToolInvokeFrom.WORKFLOW,
credential_id=workflow_tool.credential_id,
) )
runtime_parameters = {} runtime_parameters = {}
parameters = tool_runtime.get_merged_runtime_parameters() parameters = tool_runtime.get_merged_runtime_parameters()

@ -14,6 +14,7 @@ class ToolEntity(BaseModel):
tool_name: str tool_name: str
tool_label: str # redundancy tool_label: str # redundancy
tool_configurations: dict[str, Any] tool_configurations: dict[str, Any]
credential_id: str | None = None
plugin_unique_identifier: str | None = None # redundancy plugin_unique_identifier: str | None = None # redundancy
@field_validator("tool_configurations", mode="before") @field_validator("tool_configurations", mode="before")

@ -582,6 +582,11 @@ class AppDslService:
cls.encrypt_dataset_id(dataset_id=dataset_id, tenant_id=app_model.tenant_id) cls.encrypt_dataset_id(dataset_id=dataset_id, tenant_id=app_model.tenant_id)
for dataset_id in dataset_ids for dataset_id in dataset_ids
] ]
# filter credential id from tool node
if node.get("data", {}).get("type", "") == NodeType.TOOL.value:
node["data"]["credential_id"] = None
export_data["workflow"] = workflow_dict export_data["workflow"] = workflow_dict
dependencies = cls._extract_dependencies_from_workflow(workflow) dependencies = cls._extract_dependencies_from_workflow(workflow)
export_data["dependencies"] = [ export_data["dependencies"] = [

Loading…
Cancel
Save