refactor: tool
parent
3c1d32e3ac
commit
91cb80f795
@ -0,0 +1,36 @@
|
|||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
from openai import BaseModel
|
||||||
|
from pydantic import Field
|
||||||
|
|
||||||
|
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||||
|
from core.tools.entities.tool_entities import ToolInvokeFrom
|
||||||
|
|
||||||
|
|
||||||
|
class ToolRuntime(BaseModel):
|
||||||
|
"""
|
||||||
|
Meta data of a tool call processing
|
||||||
|
"""
|
||||||
|
|
||||||
|
tenant_id: str
|
||||||
|
tool_id: Optional[str] = None
|
||||||
|
invoke_from: Optional[InvokeFrom] = None
|
||||||
|
tool_invoke_from: Optional[ToolInvokeFrom] = None
|
||||||
|
credentials: Optional[dict[str, Any]] = None
|
||||||
|
runtime_parameters: dict[str, Any] = Field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
class FakeToolRuntime(ToolRuntime):
|
||||||
|
"""
|
||||||
|
Fake tool runtime for testing
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(
|
||||||
|
tenant_id="fake_tenant_id",
|
||||||
|
tool_id="fake_tool_id",
|
||||||
|
invoke_from=InvokeFrom.DEBUGGER,
|
||||||
|
tool_invoke_from=ToolInvokeFrom.AGENT,
|
||||||
|
credentials={},
|
||||||
|
runtime_parameters={},
|
||||||
|
)
|
||||||
@ -1,13 +1,8 @@
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from core.tools.builtin_tool.provider import BuiltinToolProviderController
|
from core.tools.builtin_tool.provider import BuiltinToolProviderController
|
||||||
from core.tools.builtin_tool.providers.qrcode.tools.qrcode_generator import QRCodeGeneratorTool
|
|
||||||
from core.tools.errors import ToolProviderCredentialValidationError
|
|
||||||
|
|
||||||
|
|
||||||
class QRCodeProvider(BuiltinToolProviderController):
|
class QRCodeProvider(BuiltinToolProviderController):
|
||||||
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
|
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
|
||||||
try:
|
pass
|
||||||
QRCodeGeneratorTool().invoke(user_id="", tool_parameters={"content": "Dify 123 😊"})
|
|
||||||
except Exception as e:
|
|
||||||
raise ToolProviderCredentialValidationError(str(e))
|
|
||||||
|
|||||||
@ -1,16 +1,8 @@
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from core.tools.builtin_tool.provider import BuiltinToolProviderController
|
from core.tools.builtin_tool.provider import BuiltinToolProviderController
|
||||||
from core.tools.builtin_tool.providers.time.tools.current_time import CurrentTimeTool
|
|
||||||
from core.tools.errors import ToolProviderCredentialValidationError
|
|
||||||
|
|
||||||
|
|
||||||
class WikiPediaProvider(BuiltinToolProviderController):
|
class WikiPediaProvider(BuiltinToolProviderController):
|
||||||
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
|
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
|
||||||
try:
|
pass
|
||||||
CurrentTimeTool().invoke(
|
|
||||||
user_id="",
|
|
||||||
tool_parameters={},
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
raise ToolProviderCredentialValidationError(str(e))
|
|
||||||
|
|||||||
@ -1,207 +0,0 @@
|
|||||||
from collections.abc import Mapping
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from pydantic import Field
|
|
||||||
|
|
||||||
from core.app.app_config.entities import VariableEntity, VariableEntityType
|
|
||||||
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
|
|
||||||
from core.tools.__base.tool_provider import ToolProviderController
|
|
||||||
from core.tools.entities.common_entities import I18nObject
|
|
||||||
from core.tools.entities.tool_entities import (
|
|
||||||
ToolDescription,
|
|
||||||
ToolIdentity,
|
|
||||||
ToolParameter,
|
|
||||||
ToolParameterOption,
|
|
||||||
ToolProviderType,
|
|
||||||
)
|
|
||||||
from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils
|
|
||||||
from core.tools.workflow_as_tool.tool import WorkflowTool
|
|
||||||
from extensions.ext_database import db
|
|
||||||
from models.model import App, AppMode
|
|
||||||
from models.tools import WorkflowToolProvider
|
|
||||||
from models.workflow import Workflow
|
|
||||||
|
|
||||||
VARIABLE_TO_PARAMETER_TYPE_MAPPING = {
|
|
||||||
VariableEntityType.TEXT_INPUT: ToolParameter.ToolParameterType.STRING,
|
|
||||||
VariableEntityType.PARAGRAPH: ToolParameter.ToolParameterType.STRING,
|
|
||||||
VariableEntityType.SELECT: ToolParameter.ToolParameterType.SELECT,
|
|
||||||
VariableEntityType.NUMBER: ToolParameter.ToolParameterType.NUMBER,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class WorkflowToolProviderController(ToolProviderController):
|
|
||||||
provider_id: str
|
|
||||||
tools: list[WorkflowTool] = Field(default_factory=list)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_db(cls, db_provider: WorkflowToolProvider) -> "WorkflowToolProviderController":
|
|
||||||
app = db_provider.app
|
|
||||||
|
|
||||||
if not app:
|
|
||||||
raise ValueError("app not found")
|
|
||||||
|
|
||||||
controller = WorkflowToolProviderController(
|
|
||||||
**{
|
|
||||||
"identity": {
|
|
||||||
"author": db_provider.user.name if db_provider.user_id and db_provider.user else "",
|
|
||||||
"name": db_provider.label,
|
|
||||||
"label": {"en_US": db_provider.label, "zh_Hans": db_provider.label},
|
|
||||||
"description": {"en_US": db_provider.description, "zh_Hans": db_provider.description},
|
|
||||||
"icon": db_provider.icon,
|
|
||||||
},
|
|
||||||
"credentials_schema": {},
|
|
||||||
"provider_id": db_provider.id or "",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
# init tools
|
|
||||||
|
|
||||||
controller.tools = [controller._get_db_provider_tool(db_provider, app)]
|
|
||||||
|
|
||||||
return controller
|
|
||||||
|
|
||||||
@property
|
|
||||||
def provider_type(self) -> ToolProviderType:
|
|
||||||
return ToolProviderType.WORKFLOW
|
|
||||||
|
|
||||||
def _get_db_provider_tool(self, db_provider: WorkflowToolProvider, app: App) -> WorkflowTool:
|
|
||||||
"""
|
|
||||||
get db provider tool
|
|
||||||
:param db_provider: the db provider
|
|
||||||
:param app: the app
|
|
||||||
:return: the tool
|
|
||||||
"""
|
|
||||||
workflow: Workflow | None = db.session.query(Workflow).filter(
|
|
||||||
Workflow.app_id == db_provider.app_id,
|
|
||||||
Workflow.version == db_provider.version
|
|
||||||
).first()
|
|
||||||
|
|
||||||
if not workflow:
|
|
||||||
raise ValueError("workflow not found")
|
|
||||||
|
|
||||||
# fetch start node
|
|
||||||
graph: Mapping = workflow.graph_dict
|
|
||||||
features_dict: Mapping = workflow.features_dict
|
|
||||||
features = WorkflowAppConfigManager.convert_features(
|
|
||||||
config_dict=features_dict,
|
|
||||||
app_mode=AppMode.WORKFLOW
|
|
||||||
)
|
|
||||||
|
|
||||||
parameters = db_provider.parameter_configurations
|
|
||||||
variables = WorkflowToolConfigurationUtils.get_workflow_graph_variables(graph)
|
|
||||||
|
|
||||||
def fetch_workflow_variable(variable_name: str) -> VariableEntity | None:
|
|
||||||
return next(filter(lambda x: x.variable == variable_name, variables), None)
|
|
||||||
|
|
||||||
user = db_provider.user
|
|
||||||
|
|
||||||
workflow_tool_parameters = []
|
|
||||||
for parameter in parameters:
|
|
||||||
variable = fetch_workflow_variable(parameter.name)
|
|
||||||
if variable:
|
|
||||||
parameter_type = None
|
|
||||||
options = []
|
|
||||||
if variable.type not in VARIABLE_TO_PARAMETER_TYPE_MAPPING:
|
|
||||||
raise ValueError(f"unsupported variable type {variable.type}")
|
|
||||||
parameter_type = VARIABLE_TO_PARAMETER_TYPE_MAPPING[variable.type]
|
|
||||||
|
|
||||||
if variable.type == VariableEntityType.SELECT and variable.options:
|
|
||||||
options = [
|
|
||||||
ToolParameterOption(value=option, label=I18nObject(en_US=option, zh_Hans=option))
|
|
||||||
for option in variable.options
|
|
||||||
]
|
|
||||||
|
|
||||||
workflow_tool_parameters.append(
|
|
||||||
ToolParameter(
|
|
||||||
name=parameter.name,
|
|
||||||
label=I18nObject(en_US=variable.label, zh_Hans=variable.label),
|
|
||||||
human_description=I18nObject(en_US=parameter.description, zh_Hans=parameter.description),
|
|
||||||
type=parameter_type,
|
|
||||||
form=parameter.form,
|
|
||||||
llm_description=parameter.description,
|
|
||||||
required=variable.required,
|
|
||||||
options=options,
|
|
||||||
default=variable.default,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
elif features.file_upload:
|
|
||||||
workflow_tool_parameters.append(
|
|
||||||
ToolParameter(
|
|
||||||
name=parameter.name,
|
|
||||||
label=I18nObject(en_US=parameter.name, zh_Hans=parameter.name),
|
|
||||||
human_description=I18nObject(en_US=parameter.description, zh_Hans=parameter.description),
|
|
||||||
type=ToolParameter.ToolParameterType.FILE,
|
|
||||||
llm_description=parameter.description,
|
|
||||||
required=False,
|
|
||||||
form=parameter.form,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
raise ValueError("variable not found")
|
|
||||||
|
|
||||||
return WorkflowTool(
|
|
||||||
identity=ToolIdentity(
|
|
||||||
author=user.name if user else "",
|
|
||||||
name=db_provider.name,
|
|
||||||
label=I18nObject(en_US=db_provider.label, zh_Hans=db_provider.label),
|
|
||||||
provider=self.provider_id,
|
|
||||||
icon=db_provider.icon,
|
|
||||||
),
|
|
||||||
description=ToolDescription(
|
|
||||||
human=I18nObject(en_US=db_provider.description, zh_Hans=db_provider.description),
|
|
||||||
llm=db_provider.description,
|
|
||||||
),
|
|
||||||
parameters=workflow_tool_parameters,
|
|
||||||
is_team_authorization=True,
|
|
||||||
workflow_app_id=app.id,
|
|
||||||
workflow_entities={
|
|
||||||
"app": app,
|
|
||||||
"workflow": workflow,
|
|
||||||
},
|
|
||||||
version=db_provider.version,
|
|
||||||
workflow_call_depth=0,
|
|
||||||
label=db_provider.label,
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_tools(self, tenant_id: str) -> list[WorkflowTool]:
|
|
||||||
"""
|
|
||||||
fetch tools from database
|
|
||||||
|
|
||||||
:param user_id: the user id
|
|
||||||
:param tenant_id: the tenant id
|
|
||||||
:return: the tools
|
|
||||||
"""
|
|
||||||
if self.tools is not None:
|
|
||||||
return self.tools
|
|
||||||
|
|
||||||
db_providers: WorkflowToolProvider | None = db.session.query(WorkflowToolProvider).filter(
|
|
||||||
WorkflowToolProvider.tenant_id == tenant_id,
|
|
||||||
WorkflowToolProvider.app_id == self.provider_id,
|
|
||||||
).first()
|
|
||||||
|
|
||||||
if not db_providers:
|
|
||||||
return []
|
|
||||||
|
|
||||||
app = db_providers.app
|
|
||||||
if not app:
|
|
||||||
raise ValueError("can not read app of workflow")
|
|
||||||
|
|
||||||
self.tools = [self._get_db_provider_tool(db_providers, app)]
|
|
||||||
|
|
||||||
return self.tools
|
|
||||||
|
|
||||||
def get_tool(self, tool_name: str) -> Optional[WorkflowTool]:
|
|
||||||
"""
|
|
||||||
get tool by name
|
|
||||||
|
|
||||||
:param tool_name: the name of the tool
|
|
||||||
:return: the tool
|
|
||||||
"""
|
|
||||||
if self.tools is None:
|
|
||||||
return None
|
|
||||||
|
|
||||||
for tool in self.tools:
|
|
||||||
if tool.identity.name == tool_name:
|
|
||||||
return tool
|
|
||||||
|
|
||||||
return None
|
|
||||||
Loading…
Reference in New Issue