feat/datasource
jyong 11 months ago
parent 64d997fdb0
commit 42fcda3dc8

@ -50,8 +50,8 @@ class PipelineTemplateDetailApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
@enterprise_license_required @enterprise_license_required
def get(self, pipeline_id: str): def get(self, template_id: str):
pipeline_template = RagPipelineService.get_pipeline_template_detail(pipeline_id) pipeline_template = RagPipelineService.get_pipeline_template_detail(template_id)
return pipeline_template, 200 return pipeline_template, 200
@ -120,7 +120,7 @@ api.add_resource(
) )
api.add_resource( api.add_resource(
PipelineTemplateDetailApi, PipelineTemplateDetailApi,
"/rag/pipeline/templates/<string:pipeline_id>", "/rag/pipeline/templates/<string:template_id>",
) )
api.add_resource( api.add_resource(
CustomizedPipelineTemplateApi, CustomizedPipelineTemplateApi,

@ -4,6 +4,8 @@ from typing import Any, Optional
from pydantic import BaseModel, Field, ValidationInfo, field_validator from pydantic import BaseModel, Field, ValidationInfo, field_validator
from core.entities.provider_entities import ProviderConfig
from core.plugin.entities.oauth import OAuthSchema
from core.plugin.entities.parameters import ( from core.plugin.entities.parameters import (
PluginParameter, PluginParameter,
PluginParameterOption, PluginParameterOption,
@ -13,7 +15,7 @@ from core.plugin.entities.parameters import (
init_frontend_parameter, init_frontend_parameter,
) )
from core.tools.entities.common_entities import I18nObject from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolProviderEntity from core.tools.entities.tool_entities import ToolLabelEnum, ToolProviderEntity
class DatasourceProviderType(enum.StrEnum): class DatasourceProviderType(enum.StrEnum):
@ -118,29 +120,36 @@ class DatasourceIdentity(BaseModel):
icon: Optional[str] = None icon: Optional[str] = None
class DatasourceDescription(BaseModel):
human: I18nObject = Field(..., description="The description presented to the user")
llm: str = Field(..., description="The description presented to the LLM")
class DatasourceEntity(BaseModel): class DatasourceEntity(BaseModel):
identity: DatasourceIdentity identity: DatasourceIdentity
parameters: list[DatasourceParameter] = Field(default_factory=list) parameters: list[DatasourceParameter] = Field(default_factory=list)
description: Optional[DatasourceDescription] = None description: I18nObject = Field(..., description="The label of the datasource")
output_schema: Optional[dict] = None output_schema: Optional[dict] = None
has_runtime_parameters: bool = Field(default=False, description="Whether the tool has runtime parameters")
@field_validator("parameters", mode="before") @field_validator("parameters", mode="before")
@classmethod @classmethod
def set_parameters(cls, v, validation_info: ValidationInfo) -> list[DatasourceParameter]: def set_parameters(cls, v, validation_info: ValidationInfo) -> list[DatasourceParameter]:
return v or [] return v or []
class DatasourceProviderIdentity(BaseModel):
author: str = Field(..., description="The author of the tool")
name: str = Field(..., description="The name of the tool")
description: I18nObject = Field(..., description="The description of the tool")
icon: str = Field(..., description="The icon of the tool")
label: I18nObject = Field(..., description="The label of the tool")
tags: Optional[list[ToolLabelEnum]] = Field(
default=[],
description="The tags of the tool",
)
class DatasourceProviderEntity(ToolProviderEntity): class DatasourceProviderEntity(BaseModel):
""" """
Datasource provider entity Datasource provider entity
""" """
identity: DatasourceProviderIdentity
credentials_schema: list[ProviderConfig] = Field(default_factory=list)
oauth_schema: Optional[OAuthSchema] = None
provider_type: DatasourceProviderType provider_type: DatasourceProviderType
@ -202,7 +211,6 @@ class GetOnlineDocumentPagesRequest(BaseModel):
Get online document pages request Get online document pages request
""" """
tenant_id: str = Field(..., description="The tenant id")
class OnlineDocumentPageIcon(BaseModel): class OnlineDocumentPageIcon(BaseModel):
@ -276,8 +284,6 @@ class GetWebsiteCrawlRequest(BaseModel):
""" """
Get website crawl request Get website crawl request
""" """
url: str = Field(..., description="The url of the website")
crawl_parameters: dict = Field(..., description="The crawl parameters") crawl_parameters: dict = Field(..., description="The crawl parameters")
@ -297,4 +303,4 @@ class GetWebsiteCrawlResponse(BaseModel):
Get website crawl response Get website crawl response
""" """
result: WebSiteInfo result: list[WebSiteInfo]

@ -1,3 +1,4 @@
from typing import Any, Mapping
from core.datasource.__base.datasource_plugin import DatasourcePlugin from core.datasource.__base.datasource_plugin import DatasourcePlugin
from core.datasource.__base.datasource_runtime import DatasourceRuntime from core.datasource.__base.datasource_runtime import DatasourceRuntime
from core.datasource.entities.datasource_entities import ( from core.datasource.entities.datasource_entities import (
@ -34,7 +35,7 @@ class OnlineDocumentDatasourcePlugin(DatasourcePlugin):
def _get_online_document_pages( def _get_online_document_pages(
self, self,
user_id: str, user_id: str,
datasource_parameters: GetOnlineDocumentPagesRequest, datasource_parameters: Mapping[str, Any],
provider_type: str, provider_type: str,
) -> GetOnlineDocumentPagesResponse: ) -> GetOnlineDocumentPagesResponse:
manager = PluginDatasourceManager() manager = PluginDatasourceManager()

@ -1,3 +1,4 @@
from typing import Any, Mapping
from core.datasource.__base.datasource_plugin import DatasourcePlugin from core.datasource.__base.datasource_plugin import DatasourcePlugin
from core.datasource.__base.datasource_runtime import DatasourceRuntime from core.datasource.__base.datasource_runtime import DatasourceRuntime
from core.datasource.entities.datasource_entities import ( from core.datasource.entities.datasource_entities import (
@ -32,7 +33,7 @@ class WebsiteCrawlDatasourcePlugin(DatasourcePlugin):
def _get_website_crawl( def _get_website_crawl(
self, self,
user_id: str, user_id: str,
datasource_parameters: GetWebsiteCrawlRequest, datasource_parameters: Mapping[str, Any],
provider_type: str, provider_type: str,
) -> GetWebsiteCrawlResponse: ) -> GetWebsiteCrawlResponse:
manager = PluginDatasourceManager() manager = PluginDatasourceManager()

@ -52,7 +52,6 @@ class PluginDatasourceProviderEntity(BaseModel):
provider: str provider: str
plugin_unique_identifier: str plugin_unique_identifier: str
plugin_id: str plugin_id: str
author: str
declaration: DatasourceProviderEntityWithPlugin declaration: DatasourceProviderEntityWithPlugin

@ -1,12 +1,10 @@
from typing import Any from typing import Any, Mapping
from core.datasource.entities.api_entities import DatasourceProviderApiEntity from core.datasource.entities.api_entities import DatasourceProviderApiEntity
from core.datasource.entities.datasource_entities import ( from core.datasource.entities.datasource_entities import (
GetOnlineDocumentPageContentRequest, GetOnlineDocumentPageContentRequest,
GetOnlineDocumentPageContentResponse, GetOnlineDocumentPageContentResponse,
GetOnlineDocumentPagesRequest,
GetOnlineDocumentPagesResponse, GetOnlineDocumentPagesResponse,
GetWebsiteCrawlRequest,
GetWebsiteCrawlResponse, GetWebsiteCrawlResponse,
) )
from core.plugin.entities.plugin import GenericProviderID, ToolProviderID from core.plugin.entities.plugin import GenericProviderID, ToolProviderID
@ -86,7 +84,7 @@ class PluginDatasourceManager(BasePluginClient):
datasource_provider: str, datasource_provider: str,
datasource_name: str, datasource_name: str,
credentials: dict[str, Any], credentials: dict[str, Any],
datasource_parameters: GetWebsiteCrawlRequest, datasource_parameters: Mapping[str, Any],
provider_type: str, provider_type: str,
) -> GetWebsiteCrawlResponse: ) -> GetWebsiteCrawlResponse:
""" """
@ -125,7 +123,7 @@ class PluginDatasourceManager(BasePluginClient):
datasource_provider: str, datasource_provider: str,
datasource_name: str, datasource_name: str,
credentials: dict[str, Any], credentials: dict[str, Any],
datasource_parameters: GetOnlineDocumentPagesRequest, datasource_parameters: Mapping[str, Any],
provider_type: str, provider_type: str,
) -> GetOnlineDocumentPagesResponse: ) -> GetOnlineDocumentPagesResponse:
""" """

@ -67,15 +67,15 @@ class RagPipelineService:
return result.get("pipeline_templates") return result.get("pipeline_templates")
@classmethod @classmethod
def get_pipeline_template_detail(cls, pipeline_id: str) -> Optional[dict]: def get_pipeline_template_detail(cls, template_id: str) -> Optional[dict]:
""" """
Get pipeline template detail. Get pipeline template detail.
:param pipeline_id: pipeline id :param template_id: template id
:return: :return:
""" """
mode = dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_MODE mode = dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_MODE
retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode) retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)()
result: Optional[dict] = retrieval_instance.get_pipeline_template_detail(pipeline_id) result: Optional[dict] = retrieval_instance.get_pipeline_template_detail(template_id)
return result return result
@classmethod @classmethod
@ -427,7 +427,7 @@ class RagPipelineService:
online_document_result: GetOnlineDocumentPagesResponse = ( online_document_result: GetOnlineDocumentPagesResponse = (
datasource_runtime._get_online_document_pages( datasource_runtime._get_online_document_pages(
user_id=account.id, user_id=account.id,
datasource_parameters=GetOnlineDocumentPagesRequest(tenant_id=pipeline.tenant_id), datasource_parameters=user_inputs,
provider_type=datasource_runtime.datasource_provider_type(), provider_type=datasource_runtime.datasource_provider_type(),
) )
) )
@ -440,11 +440,11 @@ class RagPipelineService:
datasource_runtime = cast(WebsiteCrawlDatasourcePlugin, datasource_runtime) datasource_runtime = cast(WebsiteCrawlDatasourcePlugin, datasource_runtime)
website_crawl_result: GetWebsiteCrawlResponse = datasource_runtime._get_website_crawl( website_crawl_result: GetWebsiteCrawlResponse = datasource_runtime._get_website_crawl(
user_id=account.id, user_id=account.id,
datasource_parameters=GetWebsiteCrawlRequest(**user_inputs), datasource_parameters=user_inputs,
provider_type=datasource_runtime.datasource_provider_type(), provider_type=datasource_runtime.datasource_provider_type(),
) )
return { return {
"result": website_crawl_result.result.model_dump(), "result": [result.model_dump() for result in website_crawl_result.result],
"provider_type": datasource_node_data.get("provider_type"), "provider_type": datasource_node_data.get("provider_type"),
} }
else: else:

Loading…
Cancel
Save