feat/datasource
jyong 11 months ago
parent 0f10852b6b
commit ec1c4efca9

@ -20,7 +20,7 @@ class DatasourcePlugin(ABC):
self.runtime = runtime self.runtime = runtime
@abstractmethod @abstractmethod
def datasource_provider_type(self) -> DatasourceProviderType: def datasource_provider_type(self) -> str:
""" """
returns the type of the datasource provider returns the type of the datasource provider
""" """

@ -9,10 +9,10 @@ from core.tools.errors import ToolProviderCredentialValidationError
class DatasourcePluginProviderController(ABC): class DatasourcePluginProviderController(ABC):
entity: DatasourceProviderEntityWithPlugin entity: DatasourceProviderEntityWithPlugin | None
tenant_id: str tenant_id: str
def __init__(self, entity: DatasourceProviderEntityWithPlugin, tenant_id: str) -> None: def __init__(self, entity: DatasourceProviderEntityWithPlugin | None, tenant_id: str) -> None:
self.entity = entity self.entity = entity
self.tenant_id = tenant_id self.tenant_id = tenant_id

@ -24,5 +24,5 @@ class LocalFileDatasourcePlugin(DatasourcePlugin):
self.icon = icon self.icon = icon
self.plugin_unique_identifier = plugin_unique_identifier self.plugin_unique_identifier = plugin_unique_identifier
def datasource_provider_type(self) -> DatasourceProviderType: def datasource_provider_type(self) -> str:
return DatasourceProviderType.LOCAL_FILE return DatasourceProviderType.LOCAL_FILE

@ -69,5 +69,5 @@ class OnlineDocumentDatasourcePlugin(DatasourcePlugin):
provider_type=provider_type, provider_type=provider_type,
) )
def datasource_provider_type(self) -> DatasourceProviderType: def datasource_provider_type(self) -> str:
return DatasourceProviderType.ONLINE_DOCUMENT return DatasourceProviderType.ONLINE_DOCUMENT

@ -49,5 +49,5 @@ class WebsiteCrawlDatasourcePlugin(DatasourcePlugin):
provider_type=provider_type, provider_type=provider_type,
) )
def datasource_provider_type(self) -> DatasourceProviderType: def datasource_provider_type(self) -> str:
return DatasourceProviderType.WEBSITE_CRAWL return DatasourceProviderType.WEBSITE_CRAWL

@ -10,7 +10,7 @@ class WebsiteCrawlDatasourcePluginProviderController(DatasourcePluginProviderCon
plugin_unique_identifier: str plugin_unique_identifier: str
def __init__( def __init__(
self, entity: DatasourceProviderEntityWithPlugin, plugin_id: str, plugin_unique_identifier: str, tenant_id: str self, entity: DatasourceProviderEntityWithPlugin | None, plugin_id: str, plugin_unique_identifier: str, tenant_id: str
) -> None: ) -> None:
super().__init__(entity, tenant_id) super().__init__(entity, tenant_id)
self.plugin_id = plugin_id self.plugin_id = plugin_id

@ -6,7 +6,7 @@ from core.datasource.entities.datasource_entities import (
GetOnlineDocumentPageContentRequest, GetOnlineDocumentPageContentRequest,
GetOnlineDocumentPageContentResponse, GetOnlineDocumentPageContentResponse,
GetOnlineDocumentPagesResponse, GetOnlineDocumentPagesResponse,
GetWebsiteCrawlResponse, GetWebsiteCrawlResponse, DatasourceProviderEntity,
) )
from core.plugin.entities.plugin import GenericProviderID, ToolProviderID from core.plugin.entities.plugin import GenericProviderID, ToolProviderID
from core.plugin.entities.plugin_daemon import ( from core.plugin.entities.plugin_daemon import (
@ -17,7 +17,7 @@ from core.plugin.impl.base import BasePluginClient
class PluginDatasourceManager(BasePluginClient): class PluginDatasourceManager(BasePluginClient):
def fetch_datasource_providers(self, tenant_id: str) -> list[DatasourceProviderApiEntity]: def fetch_datasource_providers(self, tenant_id: str) -> list[PluginDatasourceProviderEntity]:
""" """
Fetch datasource providers for the given tenant. Fetch datasource providers for the given tenant.
""" """
@ -46,12 +46,15 @@ class PluginDatasourceManager(BasePluginClient):
# for datasource in provider.declaration.datasources: # for datasource in provider.declaration.datasources:
# datasource.identity.provider = provider.declaration.identity.name # datasource.identity.provider = provider.declaration.identity.name
return [DatasourceProviderApiEntity(**self._get_local_file_datasource_provider())] return [PluginDatasourceProviderEntity(**self._get_local_file_datasource_provider())]
def fetch_datasource_provider(self, tenant_id: str, provider: str) -> PluginDatasourceProviderEntity: def fetch_datasource_provider(self, tenant_id: str, provider: str) -> PluginDatasourceProviderEntity:
""" """
Fetch datasource provider for the given tenant and plugin. Fetch datasource provider for the given tenant and plugin.
""" """
if provider == "langgenius/file/file":
return PluginDatasourceProviderEntity(**self._get_local_file_datasource_provider())
tool_provider_id = ToolProviderID(provider) tool_provider_id = ToolProviderID(provider)
def transformer(json_response: dict[str, Any]) -> dict: def transformer(json_response: dict[str, Any]) -> dict:
@ -218,6 +221,7 @@ class PluginDatasourceManager(BasePluginClient):
"X-Plugin-ID": tool_provider_id.plugin_id, "X-Plugin-ID": tool_provider_id.plugin_id,
"Content-Type": "application/json", "Content-Type": "application/json",
}, },
) )
for resp in response: for resp in response:
@ -228,27 +232,48 @@ class PluginDatasourceManager(BasePluginClient):
def _get_local_file_datasource_provider(self) -> dict[str, Any]: def _get_local_file_datasource_provider(self) -> dict[str, Any]:
return { return {
"id": "langgenius/file/file", "id": "langgenius/file/file",
"author": "langgenius",
"name": "langgenius/file/file",
"plugin_id": "langgenius/file", "plugin_id": "langgenius/file",
"provider": "langgenius",
"plugin_unique_identifier": "langgenius/file:0.0.1@dify", "plugin_unique_identifier": "langgenius/file:0.0.1@dify",
"description": {"zh_Hans": "File", "en_US": "File", "pt_BR": "File", "ja_JP": "File"}, "declaration": {
"icon": "https://cloud.dify.ai/console/api/workspaces/current/plugin/icon?tenant_id=945b4365-9d99-48c1-8c47-90593fe8b9c9&filename=13d9312f6b1352d3939b90a5257de58ff3cd619d5be4f5b266ff0298935ac328.svg", "identity": {
"label": {"zh_Hans": "File", "en_US": "File", "pt_BR": "File", "ja_JP": "File"},
"type": "datasource",
"team_credentials": {},
"is_team_authorization": False,
"allow_delete": True,
"datasources": [
{
"author": "langgenius", "author": "langgenius",
"name": "upload_file", "name": "langgenius/file/file",
"label": {"en_US": "File", "zh_Hans": "File", "pt_BR": "File", "ja_JP": "File"}, "label": {
"description": {"en_US": "File", "zh_Hans": "File", "pt_BR": "File", "ja_JP": "File."}, "zh_Hans": "File",
"en_US": "File",
"pt_BR": "File",
"ja_JP": "File"
},
"icon": "https://cloud.dify.ai/console/api/workspaces/current/plugin/icon?tenant_id=945b4365-9d99-48c1-8c47-90593fe8b9c9&filename=13d9312f6b1352d3939b90a5257de58ff3cd619d5be4f5b266ff0298935ac328.svg",
"description": {
"zh_Hans": "File",
"en_US": "File",
"pt_BR": "File",
"ja_JP": "File"
}
},
"credentials_schema": [],
"provider_type": "local_file",
"datasources": [{
"identity": {
"author": "langgenius",
"name": "local_file",
"provider": "langgenius",
"label": {
"zh_Hans": "File",
"en_US": "File",
"pt_BR": "File",
"ja_JP": "File"
}
},
"parameters": [], "parameters": [],
"labels": ["search"], "description": {
"output_schema": None, "zh_Hans": "File",
} "en_US": "File",
], "pt_BR": "File",
"labels": ["search"], "ja_JP": "File"
}
}]
}
} }

@ -1,3 +1,3 @@
from .tool_node import ToolNode from .datasource_node import DatasourceNode
__all__ = ["DatasourceNode"] __all__ = ["DatasourceNode"]

@ -40,14 +40,19 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
node_data = cast(DatasourceNodeData, self.node_data) node_data = cast(DatasourceNodeData, self.node_data)
variable_pool = self.graph_runtime_state.variable_pool variable_pool = self.graph_runtime_state.variable_pool
datasource_type = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_TYPE.value])
if not datasource_type:
raise DatasourceNodeError("Datasource type is not set")
datasource_type = datasource_type.value
datasource_info = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_INFO.value])
if not datasource_info:
raise DatasourceNodeError("Datasource info is not set")
datasource_info = datasource_info.value
# get datasource runtime # get datasource runtime
try: try:
from core.datasource.datasource_manager import DatasourceManager from core.datasource.datasource_manager import DatasourceManager
datasource_type = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_TYPE.value])
datasource_info = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_INFO.value])
if datasource_type is None: if datasource_type is None:
raise DatasourceNodeError("Datasource type is not set") raise DatasourceNodeError("Datasource type is not set")
@ -84,47 +89,55 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
) )
try: try:
if datasource_runtime.datasource_provider_type() == DatasourceProviderType.ONLINE_DOCUMENT: match datasource_type:
datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime) case DatasourceProviderType.ONLINE_DOCUMENT:
online_document_result: GetOnlineDocumentPageContentResponse = ( datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime)
datasource_runtime._get_online_document_page_content( online_document_result: GetOnlineDocumentPageContentResponse = (
user_id=self.user_id, datasource_runtime._get_online_document_page_content(
datasource_parameters=GetOnlineDocumentPageContentRequest(**parameters), user_id=self.user_id,
provider_type=datasource_runtime.datasource_provider_type(), datasource_parameters=GetOnlineDocumentPageContentRequest(**parameters),
provider_type=datasource_type,
)
) )
) yield RunCompletedEvent(
yield RunCompletedEvent( run_result=NodeRunResult(
run_result=NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED,
status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=parameters_for_log,
inputs=parameters_for_log, metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info},
metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info}, outputs={
outputs={ "online_document": online_document_result.result.model_dump(),
"online_document": online_document_result.result.model_dump(), "datasource_type": datasource_type,
"datasource_type": datasource_runtime.datasource_provider_type, },
}, )
) )
) case DatasourceProviderType.WEBSITE_CRAWL | DatasourceProviderType.LOCAL_FILE:
elif ( yield RunCompletedEvent(
datasource_runtime.datasource_provider_type in ( run_result=NodeRunResult(
DatasourceProviderType.WEBSITE_CRAWL, status=WorkflowNodeExecutionStatus.SUCCEEDED,
DatasourceProviderType.LOCAL_FILE, inputs=parameters_for_log,
) metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info},
): outputs={
yield RunCompletedEvent( "website": datasource_info,
run_result=NodeRunResult( "datasource_type": datasource_type,
status=WorkflowNodeExecutionStatus.SUCCEEDED, },
inputs=parameters_for_log, )
metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info}, )
outputs={ case DatasourceProviderType.LOCAL_FILE:
"website": datasource_info, yield RunCompletedEvent(
"datasource_type": datasource_runtime.datasource_provider_type, run_result=NodeRunResult(
}, status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=parameters_for_log,
metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info},
outputs={
"file": datasource_info,
"datasource_type": datasource_runtime.datasource_provider_type,
},
)
)
case _:
raise DatasourceNodeError(
f"Unsupported datasource provider: {datasource_runtime.datasource_provider_type}"
) )
)
else:
raise DatasourceNodeError(
f"Unsupported datasource provider: {datasource_runtime.datasource_provider_type}"
)
except PluginDaemonClientSideError as e: except PluginDaemonClientSideError as e:
yield RunCompletedEvent( yield RunCompletedEvent(
run_result=NodeRunResult( run_result=NodeRunResult(
@ -170,23 +183,24 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
datasource_parameters_dictionary = {parameter.name: parameter for parameter in datasource_parameters} datasource_parameters_dictionary = {parameter.name: parameter for parameter in datasource_parameters}
result: dict[str, Any] = {} result: dict[str, Any] = {}
for parameter_name in node_data.datasource_parameters: if node_data.datasource_parameters:
parameter = datasource_parameters_dictionary.get(parameter_name) for parameter_name in node_data.datasource_parameters:
if not parameter: parameter = datasource_parameters_dictionary.get(parameter_name)
result[parameter_name] = None if not parameter:
continue result[parameter_name] = None
datasource_input = node_data.datasource_parameters[parameter_name] continue
if datasource_input.type == "variable": datasource_input = node_data.datasource_parameters[parameter_name]
variable = variable_pool.get(datasource_input.value) if datasource_input.type == "variable":
if variable is None: variable = variable_pool.get(datasource_input.value)
raise DatasourceParameterError(f"Variable {datasource_input.value} does not exist") if variable is None:
parameter_value = variable.value raise DatasourceParameterError(f"Variable {datasource_input.value} does not exist")
elif datasource_input.type in {"mixed", "constant"}: parameter_value = variable.value
segment_group = variable_pool.convert_template(str(datasource_input.value)) elif datasource_input.type in {"mixed", "constant"}:
parameter_value = segment_group.log if for_log else segment_group.text segment_group = variable_pool.convert_template(str(datasource_input.value))
else: parameter_value = segment_group.log if for_log else segment_group.text
raise DatasourceParameterError(f"Unknown datasource input type '{datasource_input.type}'") else:
result[parameter_name] = parameter_value raise DatasourceParameterError(f"Unknown datasource input type '{datasource_input.type}'")
result[parameter_name] = parameter_value
return result return result

@ -1,4 +1,4 @@
from typing import Any, Literal, Union from typing import Any, Literal, Union, Optional
from pydantic import BaseModel, field_validator from pydantic import BaseModel, field_validator
from pydantic_core.core_schema import ValidationInfo from pydantic_core.core_schema import ValidationInfo
@ -9,30 +9,17 @@ from core.workflow.nodes.base.entities import BaseNodeData
class DatasourceEntity(BaseModel): class DatasourceEntity(BaseModel):
provider_id: str provider_id: str
provider_name: str # redundancy provider_name: str # redundancy
datasource_name: str provider_type: str
tool_label: str # redundancy datasource_name: Optional[str] = "local_file"
datasource_configurations: dict[str, Any] datasource_configurations: dict[str, Any] | None = None
plugin_unique_identifier: str | None = None # redundancy plugin_unique_identifier: str | None = None # redundancy
@field_validator("tool_configurations", mode="before")
@classmethod
def validate_tool_configurations(cls, value, values: ValidationInfo):
if not isinstance(value, dict):
raise ValueError("tool_configurations must be a dictionary")
for key in values.data.get("tool_configurations", {}):
value = values.data.get("tool_configurations", {}).get(key)
if not isinstance(value, str | int | float | bool):
raise ValueError(f"{key} must be a string")
return value
class DatasourceNodeData(BaseNodeData, DatasourceEntity): class DatasourceNodeData(BaseNodeData, DatasourceEntity):
class DatasourceInput(BaseModel): class DatasourceInput(BaseModel):
# TODO: check this type # TODO: check this type
value: Union[Any, list[str]] value: Optional[Union[Any, list[str]]] = None
type: Literal["mixed", "variable", "constant"] type: Optional[Literal["mixed", "variable", "constant"]] = None
@field_validator("type", mode="before") @field_validator("type", mode="before")
@classmethod @classmethod
@ -51,4 +38,4 @@ class DatasourceNodeData(BaseNodeData, DatasourceEntity):
raise ValueError("value must be a string, int, float, or bool") raise ValueError("value must be a string, int, float, or bool")
return typ return typ
datasource_parameters: dict[str, DatasourceInput] datasource_parameters: dict[str, DatasourceInput] | None = None

@ -19,6 +19,7 @@ from .entities import KnowledgeIndexNodeData
from .exc import ( from .exc import (
KnowledgeIndexNodeError, KnowledgeIndexNodeError,
) )
from ..base import BaseNode
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -31,7 +32,7 @@ default_retrieval_model = {
} }
class KnowledgeIndexNode(LLMNode): class KnowledgeIndexNode(BaseNode[KnowledgeIndexNodeData]):
_node_data_cls = KnowledgeIndexNodeData # type: ignore _node_data_cls = KnowledgeIndexNodeData # type: ignore
_node_type = NodeType.KNOWLEDGE_INDEX _node_type = NodeType.KNOWLEDGE_INDEX
@ -44,7 +45,7 @@ class KnowledgeIndexNode(LLMNode):
return NodeRunResult( return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED, status=WorkflowNodeExecutionStatus.FAILED,
inputs={}, inputs={},
error="Query variable is not object type.", error="Index chunk variable is not object type.",
) )
chunks = variable.value chunks = variable.value
variables = {"chunks": chunks} variables = {"chunks": chunks}

@ -4,12 +4,14 @@ from core.workflow.nodes.agent.agent_node import AgentNode
from core.workflow.nodes.answer import AnswerNode from core.workflow.nodes.answer import AnswerNode
from core.workflow.nodes.base import BaseNode from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.code import CodeNode from core.workflow.nodes.code import CodeNode
from core.workflow.nodes.datasource.datasource_node import DatasourceNode
from core.workflow.nodes.document_extractor import DocumentExtractorNode from core.workflow.nodes.document_extractor import DocumentExtractorNode
from core.workflow.nodes.end import EndNode from core.workflow.nodes.end import EndNode
from core.workflow.nodes.enums import NodeType from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.http_request import HttpRequestNode from core.workflow.nodes.http_request import HttpRequestNode
from core.workflow.nodes.if_else import IfElseNode from core.workflow.nodes.if_else import IfElseNode
from core.workflow.nodes.iteration import IterationNode, IterationStartNode from core.workflow.nodes.iteration import IterationNode, IterationStartNode
from core.workflow.nodes.knowledge_index import KnowledgeIndexNode
from core.workflow.nodes.knowledge_retrieval import KnowledgeRetrievalNode from core.workflow.nodes.knowledge_retrieval import KnowledgeRetrievalNode
from core.workflow.nodes.list_operator import ListOperatorNode from core.workflow.nodes.list_operator import ListOperatorNode
from core.workflow.nodes.llm import LLMNode from core.workflow.nodes.llm import LLMNode
@ -119,4 +121,12 @@ NODE_TYPE_CLASSES_MAPPING: Mapping[NodeType, Mapping[str, type[BaseNode]]] = {
LATEST_VERSION: AgentNode, LATEST_VERSION: AgentNode,
"1": AgentNode, "1": AgentNode,
}, },
NodeType.DATASOURCE: {
LATEST_VERSION: DatasourceNode,
"1": DatasourceNode,
},
NodeType.KNOWLEDGE_INDEX: {
LATEST_VERSION: KnowledgeIndexNode,
"1": KnowledgeIndexNode,
},
} }

Loading…
Cancel
Save