feat/r2
jyong 10 months ago
parent e23d7e39ec
commit 81b07dc3be

@ -1051,11 +1051,12 @@ class DocumentPipelineExecutionLogApi(DocumentResource):
.first() .first()
) )
if not log: if not log:
return {"datasource_info": None, return {
"datasource_type": None, "datasource_info": None,
"input_data": None, "datasource_type": None,
"datasource_node_id": None, "input_data": None,
}, 200 "datasource_node_id": None,
}, 200
return { return {
"datasource_info": json.loads(log.datasource_info), "datasource_info": json.loads(log.datasource_info),
"datasource_type": log.datasource_type, "datasource_type": log.datasource_type,
@ -1086,5 +1087,6 @@ api.add_resource(DocumentRetryApi, "/datasets/<uuid:dataset_id>/retry")
api.add_resource(DocumentRenameApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/rename") api.add_resource(DocumentRenameApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/rename")
api.add_resource(WebsiteDocumentSyncApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/website-sync") api.add_resource(WebsiteDocumentSyncApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/website-sync")
api.add_resource(DocumentPipelineExecutionLogApi, api.add_resource(
"/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/pipeline-execution-log") DocumentPipelineExecutionLogApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/pipeline-execution-log"
)

@ -96,7 +96,7 @@ class DatasourceAuth(Resource):
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("provider", type=str, required=True, nullable=False, location="json") parser.add_argument("provider", type=str, required=True, nullable=False, location="json")
parser.add_argument("name", type=str, required=True, nullable=False, location="json") parser.add_argument("name", type=str, required=False, nullable=False, location="json", default="test")
parser.add_argument("plugin_id", type=str, required=True, nullable=False, location="json") parser.add_argument("plugin_id", type=str, required=True, nullable=False, location="json")
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
args = parser.parse_args() args = parser.parse_args()

@ -48,7 +48,8 @@ class DataSourceContentPreviewApi(Resource):
) )
return preview_content, 200 return preview_content, 200
api.add_resource( api.add_resource(
DataSourceContentPreviewApi, DataSourceContentPreviewApi,
"/rag/pipelines/<uuid:pipeline_id>/workflows/published/datasource/nodes/<string:node_id>/preview" "/rag/pipelines/<uuid:pipeline_id>/workflows/published/datasource/nodes/<string:node_id>/preview",
) )

@ -1,4 +1,3 @@
from ast import Str
from collections.abc import Sequence from collections.abc import Sequence
from enum import Enum, StrEnum from enum import Enum, StrEnum
from typing import Any, Literal, Optional from typing import Any, Literal, Optional
@ -128,14 +127,17 @@ class VariableEntity(BaseModel):
def convert_none_options(cls, v: Any) -> Sequence[str]: def convert_none_options(cls, v: Any) -> Sequence[str]:
return v or [] return v or []
class RagPipelineVariableEntity(VariableEntity): class RagPipelineVariableEntity(VariableEntity):
""" """
Rag Pipeline Variable Entity. Rag Pipeline Variable Entity.
""" """
tooltips: Optional[str] = None tooltips: Optional[str] = None
placeholder: Optional[str] = None placeholder: Optional[str] = None
belong_to_node_id: str belong_to_node_id: str
class ExternalDataVariableEntity(BaseModel): class ExternalDataVariableEntity(BaseModel):
""" """
External Data Variable Entity. External Data Variable Entity.

@ -1,5 +1,3 @@
from typing import Any
from core.app.app_config.entities import RagPipelineVariableEntity, VariableEntity from core.app.app_config.entities import RagPipelineVariableEntity, VariableEntity
from models.workflow import Workflow from models.workflow import Workflow

@ -13,6 +13,7 @@ class PipelineConfig(WorkflowUIBasedAppConfig):
""" """
Pipeline Config Entity. Pipeline Config Entity.
""" """
rag_pipeline_variables: list[RagPipelineVariableEntity] = [] rag_pipeline_variables: list[RagPipelineVariableEntity] = []
pass pass

@ -47,6 +47,7 @@ class PipelineRunner(WorkflowBasedAppRunner):
def _get_app_id(self) -> str: def _get_app_id(self) -> str:
return self.application_generate_entity.app_config.app_id return self.application_generate_entity.app_config.app_id
def run(self) -> None: def run(self) -> None:
""" """
Run application Run application
@ -114,9 +115,9 @@ class PipelineRunner(WorkflowBasedAppRunner):
for v in workflow.rag_pipeline_variables: for v in workflow.rag_pipeline_variables:
rag_pipeline_variable = RAGPipelineVariable(**v) rag_pipeline_variable = RAGPipelineVariable(**v)
if ( if (
(rag_pipeline_variable.belong_to_node_id == self.application_generate_entity.start_node_id or rag_pipeline_variable.belong_to_node_id == "shared") rag_pipeline_variable.belong_to_node_id
and rag_pipeline_variable.variable in inputs in (self.application_generate_entity.start_node_id, "shared")
): ) and rag_pipeline_variable.variable in inputs:
rag_pipeline_variables.append( rag_pipeline_variables.append(
RAGPipelineVariableInput( RAGPipelineVariableInput(
variable=rag_pipeline_variable, variable=rag_pipeline_variable,

@ -10,8 +10,12 @@ from core.variables import Segment, SegmentGroup, Variable
from core.variables.consts import MIN_SELECTORS_LENGTH from core.variables.consts import MIN_SELECTORS_LENGTH
from core.variables.segments import FileSegment, NoneSegment from core.variables.segments import FileSegment, NoneSegment
from core.variables.variables import RAGPipelineVariableInput from core.variables.variables import RAGPipelineVariableInput
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, \ from core.workflow.constants import (
SYSTEM_VARIABLE_NODE_ID, RAG_PIPELINE_VARIABLE_NODE_ID CONVERSATION_VARIABLE_NODE_ID,
ENVIRONMENT_VARIABLE_NODE_ID,
RAG_PIPELINE_VARIABLE_NODE_ID,
SYSTEM_VARIABLE_NODE_ID,
)
from core.workflow.enums import SystemVariableKey from core.workflow.enums import SystemVariableKey
from factories import variable_factory from factories import variable_factory

@ -462,6 +462,7 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
inputs=parameters_for_log, inputs=parameters_for_log,
) )
) )
@classmethod @classmethod
def version(cls) -> str: def version(cls) -> str:
return "1" return "1"

@ -323,13 +323,11 @@ class Workflow(Base):
return variables return variables
def rag_pipeline_user_input_form(self) -> list: def rag_pipeline_user_input_form(self) -> list:
# get user_input_form from start node # get user_input_form from start node
variables: list[Any] = self.rag_pipeline_variables variables: list[Any] = self.rag_pipeline_variables
return variables return variables
@property @property
def unique_hash(self) -> str: def unique_hash(self) -> str:
""" """

@ -344,10 +344,10 @@ class DatasetService:
dataset = DatasetService.get_dataset(dataset_id) dataset = DatasetService.get_dataset(dataset_id)
if not dataset: if not dataset:
raise ValueError("Dataset not found") raise ValueError("Dataset not found")
# check if dataset name is exists # check if dataset name is exists
if ( if (
db.session.query(Dataset) db.session.query(Dataset)
.filter( .filter(
Dataset.id != dataset_id, Dataset.id != dataset_id,
Dataset.name == data.get("name", dataset.name), Dataset.name == data.get("name", dataset.name),
Dataset.tenant_id == dataset.tenant_id, Dataset.tenant_id == dataset.tenant_id,
@ -470,7 +470,7 @@ class DatasetService:
filtered_data["updated_at"] = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) filtered_data["updated_at"] = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
# update Retrieval model # update Retrieval model
filtered_data["retrieval_model"] = data["retrieval_model"] filtered_data["retrieval_model"] = data["retrieval_model"]
# update icon info # update icon info
if data.get("icon_info"): if data.get("icon_info"):
filtered_data["icon_info"] = data.get("icon_info") filtered_data["icon_info"] = data.get("icon_info")

@ -32,14 +32,10 @@ class DatasourceProviderService:
:param credentials: :param credentials:
""" """
# check name is exist # check name is exist
datasource_provider = ( datasource_provider = db.session.query(DatasourceProvider).filter_by(tenant_id=tenant_id, name=name).first()
db.session.query(DatasourceProvider)
.filter_by(tenant_id=tenant_id, name=name)
.first()
)
if datasource_provider: if datasource_provider:
raise ValueError("Authorization name is already exists") raise ValueError("Authorization name is already exists")
credential_valid = self.provider_manager.validate_provider_credentials( credential_valid = self.provider_manager.validate_provider_credentials(
tenant_id=tenant_id, tenant_id=tenant_id,
user_id=current_user.id, user_id=current_user.id,

@ -20,9 +20,12 @@ from core.datasource.entities.datasource_entities import (
DatasourceProviderType, DatasourceProviderType,
GetOnlineDocumentPageContentRequest, GetOnlineDocumentPageContentRequest,
OnlineDocumentPagesMessage, OnlineDocumentPagesMessage,
OnlineDriveBrowseFilesRequest,
OnlineDriveBrowseFilesResponse,
WebsiteCrawlMessage, WebsiteCrawlMessage,
) )
from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
from core.datasource.online_drive.online_drive_plugin import OnlineDriveDatasourcePlugin
from core.datasource.website_crawl.website_crawl_plugin import WebsiteCrawlDatasourcePlugin from core.datasource.website_crawl.website_crawl_plugin import WebsiteCrawlDatasourcePlugin
from core.rag.entities.event import ( from core.rag.entities.event import (
BaseDatasourceEvent, BaseDatasourceEvent,
@ -31,8 +34,9 @@ from core.rag.entities.event import (
DatasourceProcessingEvent, DatasourceProcessingEvent,
) )
from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository
from core.variables.variables import Variable from core.variables.variables import RAGPipelineVariable, RAGPipelineVariableInput, Variable
from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import ( from core.workflow.entities.workflow_node_execution import (
WorkflowNodeExecution, WorkflowNodeExecution,
WorkflowNodeExecutionStatus, WorkflowNodeExecutionStatus,
@ -381,6 +385,17 @@ class RagPipelineService:
# run draft workflow node # run draft workflow node
start_at = time.perf_counter() start_at = time.perf_counter()
rag_pipeline_variables = []
if draft_workflow.rag_pipeline_variables:
for v in draft_workflow.rag_pipeline_variables:
rag_pipeline_variable = RAGPipelineVariable(**v)
if rag_pipeline_variable.variable in user_inputs:
rag_pipeline_variables.append(
RAGPipelineVariableInput(
variable=rag_pipeline_variable,
value=user_inputs[rag_pipeline_variable.variable],
)
)
workflow_node_execution = self._handle_node_run_result( workflow_node_execution = self._handle_node_run_result(
getter=lambda: WorkflowEntry.single_step_run( getter=lambda: WorkflowEntry.single_step_run(
@ -388,6 +403,12 @@ class RagPipelineService:
node_id=node_id, node_id=node_id,
user_inputs=user_inputs, user_inputs=user_inputs,
user_id=account.id, user_id=account.id,
variable_pool=VariablePool(
user_inputs=user_inputs,
environment_variables=draft_workflow.environment_variables,
conversation_variables=draft_workflow.conversation_variables,
rag_pipeline_variables=rag_pipeline_variables,
),
), ),
start_at=start_at, start_at=start_at,
tenant_id=pipeline.tenant_id, tenant_id=pipeline.tenant_id,
@ -413,6 +434,17 @@ class RagPipelineService:
# run draft workflow node # run draft workflow node
start_at = time.perf_counter() start_at = time.perf_counter()
rag_pipeline_variables = []
if published_workflow.rag_pipeline_variables:
for v in published_workflow.rag_pipeline_variables:
rag_pipeline_variable = RAGPipelineVariable(**v)
if rag_pipeline_variable.variable in user_inputs:
rag_pipeline_variables.append(
RAGPipelineVariableInput(
variable=rag_pipeline_variable,
value=user_inputs[rag_pipeline_variable.variable],
)
)
workflow_node_execution = self._handle_node_run_result( workflow_node_execution = self._handle_node_run_result(
getter=lambda: WorkflowEntry.single_step_run( getter=lambda: WorkflowEntry.single_step_run(
@ -420,6 +452,12 @@ class RagPipelineService:
node_id=node_id, node_id=node_id,
user_inputs=user_inputs, user_inputs=user_inputs,
user_id=account.id, user_id=account.id,
variable_pool=VariablePool(
user_inputs=user_inputs,
environment_variables=published_workflow.environment_variables,
conversation_variables=published_workflow.conversation_variables,
rag_pipeline_variables=rag_pipeline_variables,
),
), ),
start_at=start_at, start_at=start_at,
tenant_id=pipeline.tenant_id, tenant_id=pipeline.tenant_id,
@ -511,6 +549,33 @@ class RagPipelineService:
except Exception as e: except Exception as e:
logger.exception("Error during online document.") logger.exception("Error during online document.")
yield DatasourceErrorEvent(error=str(e)).model_dump() yield DatasourceErrorEvent(error=str(e)).model_dump()
case DatasourceProviderType.ONLINE_DRIVE:
datasource_runtime = cast(OnlineDriveDatasourcePlugin, datasource_runtime)
online_drive_result: Generator[OnlineDriveBrowseFilesResponse, None, None] = datasource_runtime.online_drive_browse_files(
user_id=account.id,
request=OnlineDriveBrowseFilesRequest(
bucket=user_inputs.get("bucket"),
prefix=user_inputs.get("prefix"),
max_keys=user_inputs.get("max_keys", 20),
start_after=user_inputs.get("start_after"),
),
provider_type=datasource_runtime.datasource_provider_type(),
)
start_time = time.time()
start_event = DatasourceProcessingEvent(
total=0,
completed=0,
)
yield start_event.model_dump()
for message in online_drive_result:
end_time = time.time()
online_drive_event = DatasourceCompletedEvent(
data=message.result,
time_consuming=round(end_time - start_time, 2),
total=None,
completed=None,
)
yield online_drive_event.model_dump()
case DatasourceProviderType.WEBSITE_CRAWL: case DatasourceProviderType.WEBSITE_CRAWL:
datasource_runtime = cast(WebsiteCrawlDatasourcePlugin, datasource_runtime) datasource_runtime = cast(WebsiteCrawlDatasourcePlugin, datasource_runtime)
website_crawl_result: Generator[WebsiteCrawlMessage, None, None] = ( website_crawl_result: Generator[WebsiteCrawlMessage, None, None] = (
@ -631,7 +696,7 @@ class RagPipelineService:
except Exception as e: except Exception as e:
logger.exception("Error during get online document content.") logger.exception("Error during get online document content.")
raise RuntimeError(str(e)) raise RuntimeError(str(e))
#TODO Online Drive # TODO Online Drive
case _: case _:
raise ValueError(f"Unsupported datasource provider: {datasource_runtime.datasource_provider_type}") raise ValueError(f"Unsupported datasource provider: {datasource_runtime.datasource_provider_type}")
except Exception as e: except Exception as e:

@ -86,8 +86,9 @@ class ToolTransformService:
) )
else: else:
provider.declaration.identity.icon = ToolTransformService.get_tool_provider_icon_url( provider.declaration.identity.icon = ToolTransformService.get_tool_provider_icon_url(
provider_type=provider.type.value, provider_name=provider.name, provider_type=provider.type.value,
icon=provider.declaration.identity.icon provider_name=provider.name,
icon=provider.declaration.identity.icon,
) )
@classmethod @classmethod

Loading…
Cancel
Save