feat: webcrawl

feat/datasource
Harry 10 months ago
parent b2b95412b9
commit 0908f310fc

@ -1,6 +1,5 @@
import logging import logging
import yaml
from flask import request from flask import request
from flask_restful import Resource, reqparse from flask_restful import Resource, reqparse
from sqlalchemy.orm import Session from sqlalchemy.orm import Session

@ -15,7 +15,7 @@ from core.plugin.entities.parameters import (
init_frontend_parameter, init_frontend_parameter,
) )
from core.tools.entities.common_entities import I18nObject from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolLabelEnum, ToolInvokeMessage from core.tools.entities.tool_entities import ToolLabelEnum
class DatasourceProviderType(enum.StrEnum): class DatasourceProviderType(enum.StrEnum):
@ -290,40 +290,13 @@ class WebSiteInfo(BaseModel):
""" """
Website info Website info
""" """
job_id: str = Field(..., description="The job id") status: Optional[str] = Field(..., description="crawl job status")
status: str = Field(..., description="The status of the job")
web_info_list: Optional[list[WebSiteInfoDetail]] = [] web_info_list: Optional[list[WebSiteInfoDetail]] = []
total: Optional[int] = Field(default=0, description="The total number of websites")
completed: Optional[int] = Field(default=0, description="The number of completed websites")
class WebsiteCrawlMessage(BaseModel):
class GetWebsiteCrawlResponse(BaseModel):
""" """
Get website crawl response Get website crawl response
""" """
result: WebSiteInfo = WebSiteInfo(status="", web_info_list=[], total=0, completed=0)
result: WebSiteInfo = WebSiteInfo(job_id="", status="", web_info_list=[])
class DatasourceInvokeMessage(ToolInvokeMessage):
"""
Datasource Invoke Message.
"""
class WebsiteCrawlMessage(BaseModel):
"""
Website crawl message
"""
job_id: str = Field(..., description="The job id")
status: str = Field(..., description="The status of the job")
web_info_list: Optional[list[WebSiteInfoDetail]] = []
class OnlineDocumentMessage(BaseModel):
"""
Online document message
"""
workspace_id: str = Field(..., description="The workspace id")
workspace_name: str = Field(..., description="The workspace name")
workspace_icon: str = Field(..., description="The workspace icon")
total: int = Field(..., description="The total number of documents")
pages: list[OnlineDocumentPage] = Field(..., description="The pages of the online document")

@ -1,4 +1,4 @@
from collections.abc import Mapping from collections.abc import Generator, Mapping
from typing import Any from typing import Any
from core.datasource.__base.datasource_plugin import DatasourcePlugin from core.datasource.__base.datasource_plugin import DatasourcePlugin
@ -6,7 +6,7 @@ from core.datasource.__base.datasource_runtime import DatasourceRuntime
from core.datasource.entities.datasource_entities import ( from core.datasource.entities.datasource_entities import (
DatasourceEntity, DatasourceEntity,
DatasourceProviderType, DatasourceProviderType,
GetWebsiteCrawlResponse, WebsiteCrawlMessage,
) )
from core.plugin.impl.datasource import PluginDatasourceManager from core.plugin.impl.datasource import PluginDatasourceManager
@ -31,12 +31,12 @@ class WebsiteCrawlDatasourcePlugin(DatasourcePlugin):
self.icon = icon self.icon = icon
self.plugin_unique_identifier = plugin_unique_identifier self.plugin_unique_identifier = plugin_unique_identifier
def _get_website_crawl( def get_website_crawl(
self, self,
user_id: str, user_id: str,
datasource_parameters: Mapping[str, Any], datasource_parameters: Mapping[str, Any],
provider_type: str, provider_type: str,
) -> GetWebsiteCrawlResponse: ) -> Generator[WebsiteCrawlMessage, None, None]:
manager = PluginDatasourceManager() manager = PluginDatasourceManager()
return manager.get_website_crawl( return manager.get_website_crawl(

@ -1,4 +1,3 @@
from core.datasource.__base import datasource_provider
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
from core.datasource.__base.datasource_runtime import DatasourceRuntime from core.datasource.__base.datasource_runtime import DatasourceRuntime
from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType

@ -1,12 +1,12 @@
from collections.abc import Mapping from collections.abc import Generator, Mapping
from typing import Any, Generator from typing import Any
from core.datasource.entities.datasource_entities import ( from core.datasource.entities.datasource_entities import (
DatasourceInvokeMessage, DatasourceInvokeMessage,
GetOnlineDocumentPageContentRequest, GetOnlineDocumentPageContentRequest,
GetOnlineDocumentPageContentResponse, GetOnlineDocumentPageContentResponse,
GetOnlineDocumentPagesResponse, GetOnlineDocumentPagesResponse,
GetWebsiteCrawlResponse, WebsiteCrawlMessage,
) )
from core.plugin.entities.plugin import GenericProviderID, ToolProviderID from core.plugin.entities.plugin import GenericProviderID, ToolProviderID
from core.plugin.entities.plugin_daemon import ( from core.plugin.entities.plugin_daemon import (
@ -94,17 +94,17 @@ class PluginDatasourceManager(BasePluginClient):
credentials: dict[str, Any], credentials: dict[str, Any],
datasource_parameters: Mapping[str, Any], datasource_parameters: Mapping[str, Any],
provider_type: str, provider_type: str,
) -> Generator[DatasourceInvokeMessage, None, None]: ) -> Generator[WebsiteCrawlMessage, None, None]:
""" """
Invoke the datasource with the given tenant, user, plugin, provider, name, credentials and parameters. Invoke the datasource with the given tenant, user, plugin, provider, name, credentials and parameters.
""" """
datasource_provider_id = GenericProviderID(datasource_provider) datasource_provider_id = GenericProviderID(datasource_provider)
response = self._request_with_plugin_daemon_response_stream( return self._request_with_plugin_daemon_response_stream(
"POST", "POST",
f"plugin/{tenant_id}/dispatch/datasource/get_website_crawl", f"plugin/{tenant_id}/dispatch/datasource/get_website_crawl",
DatasourceInvokeMessage, WebsiteCrawlMessage,
data={ data={
"user_id": user_id, "user_id": user_id,
"data": { "data": {
@ -119,7 +119,6 @@ class PluginDatasourceManager(BasePluginClient):
"Content-Type": "application/json", "Content-Type": "application/json",
}, },
) )
yield from response
def get_online_document_pages( def get_online_document_pages(
self, self,

@ -1,7 +1,7 @@
import datetime import datetime
import logging import logging
from collections.abc import Mapping
import time import time
from collections.abc import Mapping
from typing import Any, cast from typing import Any, cast
from sqlalchemy import func from sqlalchemy import func

@ -17,11 +17,10 @@ from core.app.entities.app_invoke_entities import InvokeFrom
from core.datasource.entities.datasource_entities import ( from core.datasource.entities.datasource_entities import (
DatasourceProviderType, DatasourceProviderType,
GetOnlineDocumentPagesResponse, GetOnlineDocumentPagesResponse,
GetWebsiteCrawlResponse, WebsiteCrawlMessage,
) )
from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
from core.datasource.website_crawl.website_crawl_plugin import WebsiteCrawlDatasourcePlugin from core.datasource.website_crawl.website_crawl_plugin import WebsiteCrawlDatasourcePlugin
from core.model_runtime.utils.encoders import jsonable_encoder
from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository
from core.variables.variables import Variable from core.variables.variables import Variable
from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.node_entities import NodeRunResult
@ -43,14 +42,14 @@ from extensions.ext_database import db
from libs.infinite_scroll_pagination import InfiniteScrollPagination from libs.infinite_scroll_pagination import InfiniteScrollPagination
from models.account import Account from models.account import Account
from models.dataset import Document, Pipeline, PipelineCustomizedTemplate # type: ignore from models.dataset import Document, Pipeline, PipelineCustomizedTemplate # type: ignore
from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom from models.enums import WorkflowRunTriggeredFrom
from models.model import EndUser from models.model import EndUser
from models.oauth import DatasourceProvider
from models.workflow import ( from models.workflow import (
Workflow, Workflow,
WorkflowNodeExecutionModel,
WorkflowNodeExecutionTriggeredFrom, WorkflowNodeExecutionTriggeredFrom,
WorkflowRun, WorkflowRun,
WorkflowType, WorkflowNodeExecutionModel, WorkflowType,
) )
from services.dataset_service import DatasetService from services.dataset_service import DatasetService
from services.datasource_provider_service import DatasourceProviderService from services.datasource_provider_service import DatasourceProviderService
@ -468,15 +467,16 @@ class RagPipelineService:
case DatasourceProviderType.WEBSITE_CRAWL: case DatasourceProviderType.WEBSITE_CRAWL:
datasource_runtime = cast(WebsiteCrawlDatasourcePlugin, datasource_runtime) datasource_runtime = cast(WebsiteCrawlDatasourcePlugin, datasource_runtime)
website_crawl_result: GetWebsiteCrawlResponse = datasource_runtime._get_website_crawl( website_crawl_results: list[WebsiteCrawlMessage] = []
for website_message in datasource_runtime.get_website_crawl(
user_id=account.id, user_id=account.id,
datasource_parameters={"job_id": job_id}, datasource_parameters={"job_id": job_id},
provider_type=datasource_runtime.datasource_provider_type(), provider_type=datasource_runtime.datasource_provider_type(),
) ):
website_crawl_results.append(website_message)
return { return {
"result": [result for result in website_crawl_result.result], "result": [result for result in website_crawl_results.result],
"job_id": website_crawl_result.result.job_id, "status": website_crawl_results.result.status,
"status": website_crawl_result.result.status,
"provider_type": datasource_node_data.get("provider_type"), "provider_type": datasource_node_data.get("provider_type"),
} }
case _: case _:
@ -544,14 +544,15 @@ class RagPipelineService:
case DatasourceProviderType.WEBSITE_CRAWL: case DatasourceProviderType.WEBSITE_CRAWL:
datasource_runtime = cast(WebsiteCrawlDatasourcePlugin, datasource_runtime) datasource_runtime = cast(WebsiteCrawlDatasourcePlugin, datasource_runtime)
website_crawl_result: GetWebsiteCrawlResponse = datasource_runtime._get_website_crawl( website_crawl_results: list[WebsiteCrawlMessage] = []
for website_crawl_result in datasource_runtime.get_website_crawl(
user_id=account.id, user_id=account.id,
datasource_parameters=user_inputs, datasource_parameters=user_inputs,
provider_type=datasource_runtime.datasource_provider_type(), provider_type=datasource_runtime.datasource_provider_type(),
) ):
website_crawl_results.append(website_crawl_result)
return { return {
"result": [result.model_dump() for result in website_crawl_result.result.web_info_list] if website_crawl_result.result.web_info_list else [], "result": [result.model_dump() for result in website_crawl_result.result.web_info_list] if website_crawl_result.result.web_info_list else [],
"job_id": website_crawl_result.result.job_id,
"status": website_crawl_result.result.status, "status": website_crawl_result.result.status,
"provider_type": datasource_node_data.get("provider_type"), "provider_type": datasource_node_data.get("provider_type"),
} }

Loading…
Cancel
Save