feat(datasource): change datasource result type to event-stream

feat/datasource
Dongyu Li 11 months ago
parent e51d308312
commit 224111081b

@ -414,16 +414,19 @@ class RagPipelinePublishedDatasourceNodeRunApi(Resource):
raise ValueError("missing datasource_type") raise ValueError("missing datasource_type")
rag_pipeline_service = RagPipelineService() rag_pipeline_service = RagPipelineService()
result = rag_pipeline_service.run_datasource_workflow_node( return helper.compact_generate_response(
PipelineGenerator.convert_to_event_stream(
rag_pipeline_service.run_datasource_workflow_node(
pipeline=pipeline, pipeline=pipeline,
node_id=node_id, node_id=node_id,
user_inputs=inputs, user_inputs=inputs,
account=current_user, account=current_user,
datasource_type=datasource_type, datasource_type=datasource_type,
is_published=True, is_published=False,
)
)
) )
return result
class RagPipelineDraftDatasourceNodeRunApi(Resource): class RagPipelineDraftDatasourceNodeRunApi(Resource):
@ -455,7 +458,6 @@ class RagPipelineDraftDatasourceNodeRunApi(Resource):
raise ValueError("missing datasource_type") raise ValueError("missing datasource_type")
rag_pipeline_service = RagPipelineService() rag_pipeline_service = RagPipelineService()
try:
return helper.compact_generate_response( return helper.compact_generate_response(
PipelineGenerator.convert_to_event_stream( PipelineGenerator.convert_to_event_stream(
rag_pipeline_service.run_datasource_workflow_node( rag_pipeline_service.run_datasource_workflow_node(
@ -468,8 +470,6 @@ class RagPipelineDraftDatasourceNodeRunApi(Resource):
) )
) )
) )
except Exception as e:
print(e)
class RagPipelinePublishedNodeRunApi(Resource): class RagPipelinePublishedNodeRunApi(Resource):

@ -2,14 +2,14 @@ import contextvars
import datetime import datetime
import json import json
import logging import logging
import random import secrets
import threading import threading
import time import time
import uuid import uuid
from collections.abc import Generator, Mapping from collections.abc import Generator, Mapping
from typing import Any, Literal, Optional, Union, overload from typing import Any, Literal, Optional, Union, overload
from flask import Flask, copy_current_request_context, current_app, has_request_context from flask import Flask, current_app
from pydantic import ValidationError from pydantic import ValidationError
from sqlalchemy.orm import sessionmaker from sqlalchemy.orm import sessionmaker
@ -110,7 +110,7 @@ class PipelineGenerator(BaseAppGenerator):
start_node_id: str = args["start_node_id"] start_node_id: str = args["start_node_id"]
datasource_type: str = args["datasource_type"] datasource_type: str = args["datasource_type"]
datasource_info_list: list[Mapping[str, Any]] = args["datasource_info_list"] datasource_info_list: list[Mapping[str, Any]] = args["datasource_info_list"]
batch = time.strftime("%Y%m%d%H%M%S") + str(random.randint(100000, 999999)) batch = time.strftime("%Y%m%d%H%M%S") + str(secrets.randbelow(900000) + 100000)
documents = [] documents = []
if invoke_from == InvokeFrom.PUBLISHED: if invoke_from == InvokeFrom.PUBLISHED:
for datasource_info in datasource_info_list: for datasource_info in datasource_info_list:

@ -19,9 +19,9 @@ class BaseDatasourceEvent(BaseModel):
class DatasourceCompletedEvent(BaseDatasourceEvent): class DatasourceCompletedEvent(BaseDatasourceEvent):
event: str = DatasourceStreamEvent.COMPLETED.value event: str = DatasourceStreamEvent.COMPLETED.value
data: Mapping[str,Any] | list = Field(..., description="result") data: Mapping[str,Any] | list = Field(..., description="result")
total: Optional[int] = Field(..., description="total") total: Optional[int] = Field(default=0, description="total")
completed: Optional[int] = Field(..., description="completed") completed: Optional[int] = Field(default=0, description="completed")
time_consuming: Optional[float] = Field(..., description="time consuming") time_consuming: Optional[float] = Field(default=0.0, description="time consuming")
class DatasourceProcessingEvent(BaseDatasourceEvent): class DatasourceProcessingEvent(BaseDatasourceEvent):
event: str = DatasourceStreamEvent.PROCESSING.value event: str = DatasourceStreamEvent.PROCESSING.value

@ -558,6 +558,7 @@ class RagPipelineService:
provider_type=datasource_runtime.datasource_provider_type(), provider_type=datasource_runtime.datasource_provider_type(),
) )
start_time = time.time() start_time = time.time()
try:
for message in website_crawl_result: for message in website_crawl_result:
end_time = time.time() end_time = time.time()
if message.result.status == "completed": if message.result.status == "completed":
@ -573,6 +574,8 @@ class RagPipelineService:
completed=message.result.completed, completed=message.result.completed,
) )
yield crawl_event.model_dump() yield crawl_event.model_dump()
except Exception as e:
print(str(e))
case _: case _:
raise ValueError(f"Unsupported datasource provider: {datasource_runtime.datasource_provider_type}") raise ValueError(f"Unsupported datasource provider: {datasource_runtime.datasource_provider_type}")

Loading…
Cancel
Save