feat(oauth): add credential handling and context support for tool invocations

feat/tool-plugin-oauth
Harry 7 months ago
parent 8fc5ccab35
commit 7de3436e6b

@ -175,6 +175,7 @@ class PluginInvokeToolApi(Resource):
provider=payload.provider, provider=payload.provider,
tool_name=payload.tool, tool_name=payload.tool,
tool_parameters=payload.tool_parameters, tool_parameters=payload.tool_parameters,
credential_id=payload.credential_id
), ),
) )

@ -16,6 +16,7 @@ class AgentToolEntity(BaseModel):
tool_name: str tool_name: str
tool_parameters: dict[str, Any] = Field(default_factory=dict) tool_parameters: dict[str, Any] = Field(default_factory=dict)
plugin_unique_identifier: str | None = None plugin_unique_identifier: str | None = None
credential_id: str | None = None
class AgentPromptEntity(BaseModel): class AgentPromptEntity(BaseModel):

@ -4,6 +4,7 @@ from typing import Any, Optional
from core.agent.entities import AgentInvokeMessage from core.agent.entities import AgentInvokeMessage
from core.agent.plugin_entities import AgentStrategyParameter from core.agent.plugin_entities import AgentStrategyParameter
from core.plugin.entities.request import InvokeCredentials
class BaseAgentStrategy(ABC): class BaseAgentStrategy(ABC):
@ -18,11 +19,12 @@ class BaseAgentStrategy(ABC):
conversation_id: Optional[str] = None, conversation_id: Optional[str] = None,
app_id: Optional[str] = None, app_id: Optional[str] = None,
message_id: Optional[str] = None, message_id: Optional[str] = None,
credentials: Optional[InvokeCredentials] = None,
) -> Generator[AgentInvokeMessage, None, None]: ) -> Generator[AgentInvokeMessage, None, None]:
""" """
Invoke the agent strategy. Invoke the agent strategy.
""" """
yield from self._invoke(params, user_id, conversation_id, app_id, message_id) yield from self._invoke(params, user_id, conversation_id, app_id, message_id, credentials)
def get_parameters(self) -> Sequence[AgentStrategyParameter]: def get_parameters(self) -> Sequence[AgentStrategyParameter]:
""" """
@ -38,5 +40,6 @@ class BaseAgentStrategy(ABC):
conversation_id: Optional[str] = None, conversation_id: Optional[str] = None,
app_id: Optional[str] = None, app_id: Optional[str] = None,
message_id: Optional[str] = None, message_id: Optional[str] = None,
credentials: Optional[InvokeCredentials] = None,
) -> Generator[AgentInvokeMessage, None, None]: ) -> Generator[AgentInvokeMessage, None, None]:
pass pass

@ -4,6 +4,7 @@ from typing import Any, Optional
from core.agent.entities import AgentInvokeMessage from core.agent.entities import AgentInvokeMessage
from core.agent.plugin_entities import AgentStrategyEntity, AgentStrategyParameter from core.agent.plugin_entities import AgentStrategyEntity, AgentStrategyParameter
from core.agent.strategy.base import BaseAgentStrategy from core.agent.strategy.base import BaseAgentStrategy
from core.plugin.entities.request import InvokeCredentials, PluginInvokeContext
from core.plugin.impl.agent import PluginAgentClient from core.plugin.impl.agent import PluginAgentClient
from core.plugin.utils.converter import convert_parameters_to_plugin_format from core.plugin.utils.converter import convert_parameters_to_plugin_format
@ -40,6 +41,7 @@ class PluginAgentStrategy(BaseAgentStrategy):
conversation_id: Optional[str] = None, conversation_id: Optional[str] = None,
app_id: Optional[str] = None, app_id: Optional[str] = None,
message_id: Optional[str] = None, message_id: Optional[str] = None,
credentials: Optional[InvokeCredentials] = None,
) -> Generator[AgentInvokeMessage, None, None]: ) -> Generator[AgentInvokeMessage, None, None]:
""" """
Invoke the agent strategy. Invoke the agent strategy.
@ -58,4 +60,7 @@ class PluginAgentStrategy(BaseAgentStrategy):
conversation_id=conversation_id, conversation_id=conversation_id,
app_id=app_id, app_id=app_id,
message_id=message_id, message_id=message_id,
context=PluginInvokeContext(
credentials=credentials or InvokeCredentials()
),
) )

@ -1,5 +1,5 @@
from collections.abc import Generator from collections.abc import Generator
from typing import Any from typing import Any, Optional
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
from core.plugin.backwards_invocation.base import BaseBackwardsInvocation from core.plugin.backwards_invocation.base import BaseBackwardsInvocation
@ -23,6 +23,7 @@ class PluginToolBackwardsInvocation(BaseBackwardsInvocation):
provider: str, provider: str,
tool_name: str, tool_name: str,
tool_parameters: dict[str, Any], tool_parameters: dict[str, Any],
credential_id: Optional[str] = None,
) -> Generator[ToolInvokeMessage, None, None]: ) -> Generator[ToolInvokeMessage, None, None]:
""" """
invoke tool invoke tool
@ -30,7 +31,7 @@ class PluginToolBackwardsInvocation(BaseBackwardsInvocation):
# get tool runtime # get tool runtime
try: try:
tool_runtime = ToolManager.get_tool_runtime_from_plugin( tool_runtime = ToolManager.get_tool_runtime_from_plugin(
tool_type, tenant_id, provider, tool_name, tool_parameters tool_type, tenant_id, provider, tool_name, tool_parameters, credential_id
) )
response = ToolEngine.generic_invoke( response = ToolEngine.generic_invoke(
tool_runtime, tool_parameters, user_id, DifyWorkflowCallbackHandler(), workflow_call_depth=1 tool_runtime, tool_parameters, user_id, DifyWorkflowCallbackHandler(), workflow_call_depth=1

@ -27,6 +27,20 @@ from core.workflow.nodes.question_classifier.entities import (
) )
class InvokeCredentials(BaseModel):
tool_credentials: dict[str, str] = Field(
default_factory=dict,
description="Map of tool provider to credential id, used to store the credential id for the tool provider.",
)
class PluginInvokeContext(BaseModel):
credentials: Optional[InvokeCredentials] = Field(
default_factory=InvokeCredentials,
description="Credentials context for the plugin invocation or backward invocation.",
)
class RequestInvokeTool(BaseModel): class RequestInvokeTool(BaseModel):
""" """
Request to invoke a tool Request to invoke a tool
@ -36,6 +50,7 @@ class RequestInvokeTool(BaseModel):
provider: str provider: str
tool: str tool: str
tool_parameters: dict tool_parameters: dict
credential_id: Optional[str] = None
class BaseRequestInvokeModel(BaseModel): class BaseRequestInvokeModel(BaseModel):

@ -6,6 +6,7 @@ from core.plugin.entities.plugin import GenericProviderID
from core.plugin.entities.plugin_daemon import ( from core.plugin.entities.plugin_daemon import (
PluginAgentProviderEntity, PluginAgentProviderEntity,
) )
from core.plugin.entities.request import PluginInvokeContext
from core.plugin.impl.base import BasePluginClient from core.plugin.impl.base import BasePluginClient
@ -83,6 +84,7 @@ class PluginAgentClient(BasePluginClient):
conversation_id: Optional[str] = None, conversation_id: Optional[str] = None,
app_id: Optional[str] = None, app_id: Optional[str] = None,
message_id: Optional[str] = None, message_id: Optional[str] = None,
context: Optional[PluginInvokeContext] = None,
) -> Generator[AgentInvokeMessage, None, None]: ) -> Generator[AgentInvokeMessage, None, None]:
""" """
Invoke the agent with the given tenant, user, plugin, provider, name and parameters. Invoke the agent with the given tenant, user, plugin, provider, name and parameters.
@ -99,6 +101,7 @@ class PluginAgentClient(BasePluginClient):
"conversation_id": conversation_id, "conversation_id": conversation_id,
"app_id": app_id, "app_id": app_id,
"message_id": message_id, "message_id": message_id,
"context": context.model_dump() if context else {},
"data": { "data": {
"agent_strategy_provider": agent_provider_id.provider_name, "agent_strategy_provider": agent_provider_id.provider_name,
"agent_strategy": agent_strategy, "agent_strategy": agent_strategy,

@ -446,6 +446,7 @@ class ToolSelector(BaseModel):
options: Optional[list[PluginParameterOption]] = None options: Optional[list[PluginParameterOption]] = None
provider_id: str = Field(..., description="The id of the provider") provider_id: str = Field(..., description="The id of the provider")
credential_id: Optional[str] = Field(default=None, description="The id of the credential")
tool_name: str = Field(..., description="The name of the tool") tool_name: str = Field(..., description="The name of the tool")
tool_description: str = Field(..., description="The description of the tool") tool_description: str = Field(..., description="The description of the tool")
tool_configuration: Mapping[str, Any] = Field(..., description="Configuration, type form") tool_configuration: Mapping[str, Any] = Field(..., description="Configuration, type form")

@ -321,6 +321,7 @@ class ToolManager:
tenant_id=tenant_id, tenant_id=tenant_id,
invoke_from=invoke_from, invoke_from=invoke_from,
tool_invoke_from=ToolInvokeFrom.AGENT, tool_invoke_from=ToolInvokeFrom.AGENT,
credential_id=agent_tool.credential_id,
) )
runtime_parameters = {} runtime_parameters = {}
parameters = tool_entity.get_merged_runtime_parameters() parameters = tool_entity.get_merged_runtime_parameters()
@ -393,6 +394,7 @@ class ToolManager:
provider: str, provider: str,
tool_name: str, tool_name: str,
tool_parameters: dict[str, Any], tool_parameters: dict[str, Any],
credential_id: Optional[str] = None,
) -> Tool: ) -> Tool:
""" """
get tool runtime from plugin get tool runtime from plugin
@ -404,6 +406,7 @@ class ToolManager:
tenant_id=tenant_id, tenant_id=tenant_id,
invoke_from=InvokeFrom.SERVICE_API, invoke_from=InvokeFrom.SERVICE_API,
tool_invoke_from=ToolInvokeFrom.PLUGIN, tool_invoke_from=ToolInvokeFrom.PLUGIN,
credential_id=credential_id,
) )
runtime_parameters = {} runtime_parameters = {}
parameters = tool_entity.get_merged_runtime_parameters() parameters = tool_entity.get_merged_runtime_parameters()

@ -4,6 +4,7 @@ from collections.abc import Generator, Mapping, Sequence
from typing import Any, Optional, cast from typing import Any, Optional, cast
from packaging.version import Version from packaging.version import Version
from pydantic import ValidationError
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@ -13,10 +14,16 @@ from core.agent.strategy.plugin import PluginAgentStrategy
from core.memory.token_buffer_memory import TokenBufferMemory from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance, ModelManager from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.model_entities import AIModelEntity, ModelType from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
from core.plugin.entities.request import InvokeCredentials
from core.plugin.impl.exc import PluginDaemonClientSideError from core.plugin.impl.exc import PluginDaemonClientSideError
from core.plugin.impl.plugin import PluginInstaller from core.plugin.impl.plugin import PluginInstaller
from core.provider_manager import ProviderManager from core.provider_manager import ProviderManager
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolProviderType from core.tools.entities.tool_entities import (
ToolIdentity,
ToolInvokeMessage,
ToolParameter,
ToolProviderType,
)
from core.tools.tool_manager import ToolManager from core.tools.tool_manager import ToolManager
from core.variables.segments import StringSegment from core.variables.segments import StringSegment
from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.node_entities import NodeRunResult
@ -84,6 +91,7 @@ class AgentNode(ToolNode):
for_log=True, for_log=True,
strategy=strategy, strategy=strategy,
) )
credentials = self._generate_credentials(parameters=parameters)
# get conversation id # get conversation id
conversation_id = self.graph_runtime_state.variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID]) conversation_id = self.graph_runtime_state.variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID])
@ -94,6 +102,7 @@ class AgentNode(ToolNode):
user_id=self.user_id, user_id=self.user_id,
app_id=self.app_id, app_id=self.app_id,
conversation_id=conversation_id.text if conversation_id else None, conversation_id=conversation_id.text if conversation_id else None,
credentials=credentials,
) )
except Exception as e: except Exception as e:
yield RunCompletedEvent( yield RunCompletedEvent(
@ -246,6 +255,7 @@ class AgentNode(ToolNode):
tool_name=tool.get("tool_name", ""), tool_name=tool.get("tool_name", ""),
tool_parameters=parameters, tool_parameters=parameters,
plugin_unique_identifier=tool.get("plugin_unique_identifier", None), plugin_unique_identifier=tool.get("plugin_unique_identifier", None),
credential_id=tool.get("credential_id", None),
) )
extra = tool.get("extra", {}) extra = tool.get("extra", {})
@ -276,6 +286,7 @@ class AgentNode(ToolNode):
{ {
**tool_runtime.entity.model_dump(mode="json"), **tool_runtime.entity.model_dump(mode="json"),
"runtime_parameters": runtime_parameters, "runtime_parameters": runtime_parameters,
"credential_id": tool.get("credential_id", None),
"provider_type": provider_type.value, "provider_type": provider_type.value,
} }
) )
@ -305,6 +316,27 @@ class AgentNode(ToolNode):
return result return result
def _generate_credentials(
self,
parameters: dict[str, Any],
) -> InvokeCredentials:
"""
Generate credentials based on the given agent parameters.
"""
credentials = InvokeCredentials()
# generate credentials for tools selector
credentials.tool_credentials = {}
for tool in parameters.get("tools", []):
if tool.get("credential_id"):
try:
identity = ToolIdentity.model_validate(tool.get("identity", {}))
credentials.tool_credentials[identity.provider] = tool.get("credential_id", None)
except ValidationError:
continue
return credentials
@classmethod @classmethod
def _extract_variable_selector_to_variable_mapping( def _extract_variable_selector_to_variable_mapping(
cls, cls,

Loading…
Cancel
Save