Merge remote-tracking branch 'origin/feat/r2' into feat/r2

# Conflicts:
#	api/core/datasource/website_crawl/website_crawl_plugin.py
#	api/services/rag_pipeline/rag_pipeline.py
feat/datasource
jyong 10 months ago
commit 8d47d8ce4f

@ -1,6 +1,5 @@
import logging import logging
import yaml
from flask import request from flask import request
from flask_restful import Resource, reqparse from flask_restful import Resource, reqparse
from sqlalchemy.orm import Session from sqlalchemy.orm import Session

@ -15,7 +15,7 @@ from core.plugin.entities.parameters import (
init_frontend_parameter, init_frontend_parameter,
) )
from core.tools.entities.common_entities import I18nObject from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolLabelEnum, ToolInvokeMessage from core.tools.entities.tool_entities import ToolInvokeMessage, ToolLabelEnum
class DatasourceProviderType(enum.StrEnum): class DatasourceProviderType(enum.StrEnum):
@ -207,12 +207,6 @@ class DatasourceInvokeFrom(Enum):
RAG_PIPELINE = "rag_pipeline" RAG_PIPELINE = "rag_pipeline"
class GetOnlineDocumentPagesRequest(BaseModel):
"""
Get online document pages request
"""
class OnlineDocumentPage(BaseModel): class OnlineDocumentPage(BaseModel):
""" """
Online document page Online document page
@ -237,7 +231,7 @@ class OnlineDocumentInfo(BaseModel):
pages: list[OnlineDocumentPage] = Field(..., description="The pages of the online document") pages: list[OnlineDocumentPage] = Field(..., description="The pages of the online document")
class GetOnlineDocumentPagesResponse(BaseModel): class OnlineDocumentPagesMessage(BaseModel):
""" """
Get online document pages response Get online document pages response
""" """
@ -290,17 +284,19 @@ class WebSiteInfo(BaseModel):
""" """
Website info Website info
""" """
job_id: str = Field(..., description="The job id") status: Optional[str] = Field(..., description="crawl job status")
status: str = Field(..., description="The status of the job")
web_info_list: Optional[list[WebSiteInfoDetail]] = [] web_info_list: Optional[list[WebSiteInfoDetail]] = []
total: Optional[int] = Field(default=0, description="The total number of websites")
completed: Optional[int] = Field(default=0, description="The number of completed websites")
class WebsiteCrawlMessage(BaseModel):
class GetWebsiteCrawlResponse(BaseModel):
""" """
Get website crawl response Get website crawl response
""" """
result: WebSiteInfo = WebSiteInfo(status="", web_info_list=[], total=0, completed=0)
result: WebSiteInfo = WebSiteInfo(job_id="", status="", web_info_list=[]) class DatasourceMessage(ToolInvokeMessage):
pass
class DatasourceInvokeMessage(ToolInvokeMessage): class DatasourceInvokeMessage(ToolInvokeMessage):
@ -326,4 +322,4 @@ class DatasourceInvokeMessage(ToolInvokeMessage):
workspace_name: str = Field(..., description="The workspace name") workspace_name: str = Field(..., description="The workspace name")
workspace_icon: str = Field(..., description="The workspace icon") workspace_icon: str = Field(..., description="The workspace icon")
total: int = Field(..., description="The total number of documents") total: int = Field(..., description="The total number of documents")
pages: list[OnlineDocumentPage] = Field(..., description="The pages of the online document") pages: list[OnlineDocumentPage] = Field(..., description="The pages of the online document")

@ -1,5 +1,5 @@
from collections.abc import Mapping from collections.abc import Generator, Mapping
from typing import Any, Generator from typing import Any
from core.datasource.__base.datasource_plugin import DatasourcePlugin from core.datasource.__base.datasource_plugin import DatasourcePlugin
from core.datasource.__base.datasource_runtime import DatasourceRuntime from core.datasource.__base.datasource_runtime import DatasourceRuntime
@ -8,8 +8,7 @@ from core.datasource.entities.datasource_entities import (
DatasourceInvokeMessage, DatasourceInvokeMessage,
DatasourceProviderType, DatasourceProviderType,
GetOnlineDocumentPageContentRequest, GetOnlineDocumentPageContentRequest,
GetOnlineDocumentPageContentResponse, OnlineDocumentPagesMessage,
GetOnlineDocumentPagesResponse,
) )
from core.plugin.impl.datasource import PluginDatasourceManager from core.plugin.impl.datasource import PluginDatasourceManager
@ -39,7 +38,7 @@ class OnlineDocumentDatasourcePlugin(DatasourcePlugin):
user_id: str, user_id: str,
datasource_parameters: Mapping[str, Any], datasource_parameters: Mapping[str, Any],
provider_type: str, provider_type: str,
) -> Generator[DatasourceInvokeMessage, None, None]: ) -> Generator[OnlineDocumentPagesMessage, None, None]:
manager = PluginDatasourceManager() manager = PluginDatasourceManager()
return manager.get_online_document_pages( return manager.get_online_document_pages(

@ -1,5 +1,5 @@
from collections.abc import Mapping from collections.abc import Generator, Mapping
from typing import Any, Generator from typing import Any
from core.datasource.__base.datasource_plugin import DatasourcePlugin from core.datasource.__base.datasource_plugin import DatasourcePlugin
from core.datasource.__base.datasource_runtime import DatasourceRuntime from core.datasource.__base.datasource_runtime import DatasourceRuntime
@ -7,7 +7,7 @@ from core.datasource.entities.datasource_entities import (
DatasourceEntity, DatasourceEntity,
DatasourceInvokeMessage, DatasourceInvokeMessage,
DatasourceProviderType, DatasourceProviderType,
GetWebsiteCrawlResponse, WebsiteCrawlMessage,
) )
from core.plugin.impl.datasource import PluginDatasourceManager from core.plugin.impl.datasource import PluginDatasourceManager
@ -32,12 +32,12 @@ class WebsiteCrawlDatasourcePlugin(DatasourcePlugin):
self.icon = icon self.icon = icon
self.plugin_unique_identifier = plugin_unique_identifier self.plugin_unique_identifier = plugin_unique_identifier
def _get_website_crawl( def get_website_crawl(
self, self,
user_id: str, user_id: str,
datasource_parameters: Mapping[str, Any], datasource_parameters: Mapping[str, Any],
provider_type: str, provider_type: str,
) -> Generator[DatasourceInvokeMessage, None, None]: ) -> Generator[WebsiteCrawlMessage, None, None]:
manager = PluginDatasourceManager() manager = PluginDatasourceManager()
return manager.get_website_crawl( return manager.get_website_crawl(

@ -1,4 +1,3 @@
from core.datasource.__base import datasource_provider
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
from core.datasource.__base.datasource_runtime import DatasourceRuntime from core.datasource.__base.datasource_runtime import DatasourceRuntime
from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType

@ -1,12 +1,11 @@
from collections.abc import Mapping from collections.abc import Generator, Mapping
from typing import Any, Generator from typing import Any
from core.datasource.entities.datasource_entities import ( from core.datasource.entities.datasource_entities import (
DatasourceInvokeMessage, DatasourceInvokeMessage,
GetOnlineDocumentPageContentRequest, GetOnlineDocumentPageContentRequest,
GetOnlineDocumentPageContentResponse, OnlineDocumentPagesMessage,
GetOnlineDocumentPagesResponse, WebsiteCrawlMessage,
GetWebsiteCrawlResponse,
) )
from core.plugin.entities.plugin import GenericProviderID, ToolProviderID from core.plugin.entities.plugin import GenericProviderID, ToolProviderID
from core.plugin.entities.plugin_daemon import ( from core.plugin.entities.plugin_daemon import (
@ -94,17 +93,17 @@ class PluginDatasourceManager(BasePluginClient):
credentials: dict[str, Any], credentials: dict[str, Any],
datasource_parameters: Mapping[str, Any], datasource_parameters: Mapping[str, Any],
provider_type: str, provider_type: str,
) -> Generator[DatasourceInvokeMessage, None, None]: ) -> Generator[WebsiteCrawlMessage, None, None]:
""" """
Invoke the datasource with the given tenant, user, plugin, provider, name, credentials and parameters. Invoke the datasource with the given tenant, user, plugin, provider, name, credentials and parameters.
""" """
datasource_provider_id = GenericProviderID(datasource_provider) datasource_provider_id = GenericProviderID(datasource_provider)
response = self._request_with_plugin_daemon_response_stream( return self._request_with_plugin_daemon_response_stream(
"POST", "POST",
f"plugin/{tenant_id}/dispatch/datasource/get_website_crawl", f"plugin/{tenant_id}/dispatch/datasource/get_website_crawl",
DatasourceInvokeMessage, WebsiteCrawlMessage,
data={ data={
"user_id": user_id, "user_id": user_id,
"data": { "data": {
@ -119,7 +118,6 @@ class PluginDatasourceManager(BasePluginClient):
"Content-Type": "application/json", "Content-Type": "application/json",
}, },
) )
yield from response
def get_online_document_pages( def get_online_document_pages(
self, self,
@ -130,7 +128,7 @@ class PluginDatasourceManager(BasePluginClient):
credentials: dict[str, Any], credentials: dict[str, Any],
datasource_parameters: Mapping[str, Any], datasource_parameters: Mapping[str, Any],
provider_type: str, provider_type: str,
) -> Generator[DatasourceInvokeMessage, None, None]: ) -> Generator[OnlineDocumentPagesMessage, None, None]:
""" """
Invoke the datasource with the given tenant, user, plugin, provider, name, credentials and parameters. Invoke the datasource with the given tenant, user, plugin, provider, name, credentials and parameters.
""" """
@ -140,7 +138,7 @@ class PluginDatasourceManager(BasePluginClient):
response = self._request_with_plugin_daemon_response_stream( response = self._request_with_plugin_daemon_response_stream(
"POST", "POST",
f"plugin/{tenant_id}/dispatch/datasource/get_online_document_pages", f"plugin/{tenant_id}/dispatch/datasource/get_online_document_pages",
DatasourceInvokeMessage, OnlineDocumentPagesMessage,
data={ data={
"user_id": user_id, "user_id": user_id,
"data": { "data": {

@ -1,5 +1,5 @@
from collections.abc import Mapping, Sequence from collections.abc import Generator, Mapping, Sequence
from typing import Any, Generator, cast from typing import Any, cast
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@ -9,7 +9,6 @@ from core.datasource.entities.datasource_entities import (
DatasourceParameter, DatasourceParameter,
DatasourceProviderType, DatasourceProviderType,
GetOnlineDocumentPageContentRequest, GetOnlineDocumentPageContentRequest,
GetOnlineDocumentPageContentResponse,
) )
from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
from core.datasource.utils.message_transformer import DatasourceFileMessageTransformer from core.datasource.utils.message_transformer import DatasourceFileMessageTransformer

@ -1,7 +1,7 @@
import datetime import datetime
import logging import logging
from collections.abc import Mapping
import time import time
from collections.abc import Mapping
from typing import Any, cast from typing import Any, cast
from sqlalchemy import func from sqlalchemy import func

@ -17,12 +17,11 @@ from core.app.entities.app_invoke_entities import InvokeFrom
from core.datasource.entities.datasource_entities import ( from core.datasource.entities.datasource_entities import (
DatasourceInvokeMessage, DatasourceInvokeMessage,
DatasourceProviderType, DatasourceProviderType,
GetOnlineDocumentPagesResponse, OnlineDocumentPagesMessage,
GetWebsiteCrawlResponse, WebsiteCrawlMessage,
) )
from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
from core.datasource.website_crawl.website_crawl_plugin import WebsiteCrawlDatasourcePlugin from core.datasource.website_crawl.website_crawl_plugin import WebsiteCrawlDatasourcePlugin
from core.model_runtime.utils.encoders import jsonable_encoder
from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository
from core.variables.variables import Variable from core.variables.variables import Variable
from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.node_entities import NodeRunResult
@ -44,14 +43,14 @@ from extensions.ext_database import db
from libs.infinite_scroll_pagination import InfiniteScrollPagination from libs.infinite_scroll_pagination import InfiniteScrollPagination
from models.account import Account from models.account import Account
from models.dataset import Document, Pipeline, PipelineCustomizedTemplate # type: ignore from models.dataset import Document, Pipeline, PipelineCustomizedTemplate # type: ignore
from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom from models.enums import WorkflowRunTriggeredFrom
from models.model import EndUser from models.model import EndUser
from models.oauth import DatasourceProvider
from models.workflow import ( from models.workflow import (
Workflow, Workflow,
WorkflowNodeExecutionModel,
WorkflowNodeExecutionTriggeredFrom, WorkflowNodeExecutionTriggeredFrom,
WorkflowRun, WorkflowRun,
WorkflowType, WorkflowNodeExecutionModel, WorkflowType,
) )
from services.dataset_service import DatasetService from services.dataset_service import DatasetService
from services.datasource_provider_service import DatasourceProviderService from services.datasource_provider_service import DatasourceProviderService
@ -424,6 +423,65 @@ class RagPipelineService:
return workflow_node_execution return workflow_node_execution
def run_datasource_workflow_node_status(
self, pipeline: Pipeline, node_id: str, job_id: str, account: Account, datasource_type: str, is_published: bool
) -> dict:
"""
Run published workflow datasource
"""
if is_published:
# fetch published workflow by app_model
workflow = self.get_published_workflow(pipeline=pipeline)
else:
workflow = self.get_draft_workflow(pipeline=pipeline)
if not workflow:
raise ValueError("Workflow not initialized")
# run draft workflow node
datasource_node_data = None
start_at = time.perf_counter()
datasource_nodes = workflow.graph_dict.get("nodes", [])
for datasource_node in datasource_nodes:
if datasource_node.get("id") == node_id:
datasource_node_data = datasource_node.get("data", {})
break
if not datasource_node_data:
raise ValueError("Datasource node data not found")
from core.datasource.datasource_manager import DatasourceManager
datasource_runtime = DatasourceManager.get_datasource_runtime(
provider_id=f"{datasource_node_data.get('plugin_id')}/{datasource_node_data.get('provider_name')}",
datasource_name=datasource_node_data.get("datasource_name"),
tenant_id=pipeline.tenant_id,
datasource_type=DatasourceProviderType(datasource_type),
)
datasource_provider_service = DatasourceProviderService()
credentials = datasource_provider_service.get_real_datasource_credentials(
tenant_id=pipeline.tenant_id,
provider=datasource_node_data.get('provider_name'),
plugin_id=datasource_node_data.get('plugin_id'),
)
if credentials:
datasource_runtime.runtime.credentials = credentials[0].get("credentials")
match datasource_type:
case DatasourceProviderType.WEBSITE_CRAWL:
datasource_runtime = cast(WebsiteCrawlDatasourcePlugin, datasource_runtime)
website_crawl_results: list[WebsiteCrawlMessage] = []
for website_message in datasource_runtime.get_website_crawl(
user_id=account.id,
datasource_parameters={"job_id": job_id},
provider_type=datasource_runtime.datasource_provider_type(),
):
website_crawl_results.append(website_message)
return {
"result": [result for result in website_crawl_results.result],
"status": website_crawl_results.result.status,
"provider_type": datasource_node_data.get("provider_type"),
}
case _:
raise ValueError(f"Unsupported datasource provider: {datasource_runtime.datasource_provider_type}")
def run_datasource_workflow_node( def run_datasource_workflow_node(
self, pipeline: Pipeline, node_id: str, user_inputs: dict, account: Account, datasource_type: str, self, pipeline: Pipeline, node_id: str, user_inputs: dict, account: Account, datasource_type: str,

Loading…
Cancel
Save