feat/rag-2
jyong 11 months ago
parent f2960989c1
commit 44c2efcfe4

@ -22,7 +22,7 @@ class WorkflowVariablesConfigManager:
return variables return variables
@classmethod @classmethod
def convert_rag_pipeline_variable(cls, workflow: Workflow) -> list[RagPipelineVariableEntity]: def convert_rag_pipeline_variable(cls, workflow: Workflow, start_node_id: str) -> list[RagPipelineVariableEntity]:
""" """
Convert workflow start variables to variables Convert workflow start variables to variables
@ -31,8 +31,9 @@ class WorkflowVariablesConfigManager:
variables = [] variables = []
user_input_form = workflow.rag_pipeline_user_input_form() user_input_form = workflow.rag_pipeline_user_input_form()
# variables # filter variables by start_node_id
for variable in user_input_form: for variable in user_input_form:
variables.append(RagPipelineVariableEntity.model_validate(variable)) if variable.get("belong_to_node_id") == start_node_id or variable.get("belong_to_node_id") == "shared":
variables.append(RagPipelineVariableEntity.model_validate(variable))
return variables return variables

@ -20,13 +20,13 @@ class PipelineConfig(WorkflowUIBasedAppConfig):
class PipelineConfigManager(BaseAppConfigManager): class PipelineConfigManager(BaseAppConfigManager):
@classmethod @classmethod
def get_pipeline_config(cls, pipeline: Pipeline, workflow: Workflow) -> PipelineConfig: def get_pipeline_config(cls, pipeline: Pipeline, workflow: Workflow, start_node_id: str) -> PipelineConfig:
pipeline_config = PipelineConfig( pipeline_config = PipelineConfig(
tenant_id=pipeline.tenant_id, tenant_id=pipeline.tenant_id,
app_id=pipeline.id, app_id=pipeline.id,
app_mode=AppMode.RAG_PIPELINE, app_mode=AppMode.RAG_PIPELINE,
workflow_id=workflow.id, workflow_id=workflow.id,
rag_pipeline_variables=WorkflowVariablesConfigManager.convert_rag_pipeline_variable(workflow=workflow), rag_pipeline_variables=WorkflowVariablesConfigManager.convert_rag_pipeline_variable(workflow=workflow, start_node_id=start_node_id),
) )
return pipeline_config return pipeline_config

@ -29,6 +29,7 @@ from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.rag.index_processor.constant.built_in_field import BuiltInField from core.rag.index_processor.constant.built_in_field import BuiltInField
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository
from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from extensions.ext_database import db from extensions.ext_database import db
@ -97,11 +98,6 @@ class PipelineGenerator(BaseAppGenerator):
call_depth: int = 0, call_depth: int = 0,
workflow_thread_pool_id: Optional[str] = None, workflow_thread_pool_id: Optional[str] = None,
) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None], None]: ) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None], None]:
# convert to app config
pipeline_config = PipelineConfigManager.get_pipeline_config(
pipeline=pipeline,
workflow=workflow,
)
# Add null check for dataset # Add null check for dataset
dataset = pipeline.dataset dataset = pipeline.dataset
if not dataset: if not dataset:
@ -111,6 +107,12 @@ class PipelineGenerator(BaseAppGenerator):
datasource_type: str = args["datasource_type"] datasource_type: str = args["datasource_type"]
datasource_info_list: list[Mapping[str, Any]] = args["datasource_info_list"] datasource_info_list: list[Mapping[str, Any]] = args["datasource_info_list"]
batch = time.strftime("%Y%m%d%H%M%S") + str(secrets.randbelow(900000) + 100000) batch = time.strftime("%Y%m%d%H%M%S") + str(secrets.randbelow(900000) + 100000)
# convert to app config
pipeline_config = PipelineConfigManager.get_pipeline_config(
pipeline=pipeline,
workflow=workflow,
start_node_id=start_node_id
)
documents = [] documents = []
if invoke_from == InvokeFrom.PUBLISHED: if invoke_from == InvokeFrom.PUBLISHED:
for datasource_info in datasource_info_list: for datasource_info in datasource_info_list:
@ -308,6 +310,9 @@ class PipelineGenerator(BaseAppGenerator):
worker_thread.start() worker_thread.start()
draft_var_saver_factory = self._get_draft_var_saver_factory(
invoke_from,
)
# return response or stream generator # return response or stream generator
response = self._handle_response( response = self._handle_response(
application_generate_entity=application_generate_entity, application_generate_entity=application_generate_entity,
@ -317,6 +322,7 @@ class PipelineGenerator(BaseAppGenerator):
workflow_execution_repository=workflow_execution_repository, workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository, workflow_node_execution_repository=workflow_node_execution_repository,
stream=streaming, stream=streaming,
draft_var_saver_factory=draft_var_saver_factory,
) )
return WorkflowAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from) return WorkflowAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)
@ -347,7 +353,9 @@ class PipelineGenerator(BaseAppGenerator):
raise ValueError("inputs is required") raise ValueError("inputs is required")
# convert to app config # convert to app config
pipeline_config = PipelineConfigManager.get_pipeline_config(pipeline=pipeline, workflow=workflow) pipeline_config = PipelineConfigManager.get_pipeline_config(pipeline=pipeline,
workflow=workflow,
start_node_id=args.get("start_node_id","shared"))
dataset = pipeline.dataset dataset = pipeline.dataset
if not dataset: if not dataset:
@ -432,7 +440,9 @@ class PipelineGenerator(BaseAppGenerator):
raise ValueError("Pipeline dataset is required") raise ValueError("Pipeline dataset is required")
# convert to app config # convert to app config
pipeline_config = PipelineConfigManager.get_pipeline_config(pipeline=pipeline, workflow=workflow) pipeline_config = PipelineConfigManager.get_pipeline_config(pipeline=pipeline,
workflow=workflow,
start_node_id=args.get("start_node_id","shared"))
# init application generate entity # init application generate entity
application_generate_entity = RagPipelineGenerateEntity( application_generate_entity = RagPipelineGenerateEntity(
@ -476,7 +486,7 @@ class PipelineGenerator(BaseAppGenerator):
return self._generate( return self._generate(
flask_app=current_app._get_current_object(), # type: ignore flask_app=current_app._get_current_object(), # type: ignore
pipeline=pipeline, pipeline=pipeline,
workflow=workflow, workflow_id=workflow.id,
user=user, user=user,
invoke_from=InvokeFrom.DEBUGGER, invoke_from=InvokeFrom.DEBUGGER,
application_generate_entity=application_generate_entity, application_generate_entity=application_generate_entity,
@ -539,6 +549,7 @@ class PipelineGenerator(BaseAppGenerator):
user: Union[Account, EndUser], user: Union[Account, EndUser],
workflow_execution_repository: WorkflowExecutionRepository, workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository, workflow_node_execution_repository: WorkflowNodeExecutionRepository,
draft_var_saver_factory: DraftVariableSaverFactory,
stream: bool = False, stream: bool = False,
) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]: ) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
""" """
@ -560,6 +571,7 @@ class PipelineGenerator(BaseAppGenerator):
stream=stream, stream=stream,
workflow_node_execution_repository=workflow_node_execution_repository, workflow_node_execution_repository=workflow_node_execution_repository,
workflow_execution_repository=workflow_execution_repository, workflow_execution_repository=workflow_execution_repository,
draft_var_saver_factory=draft_var_saver_factory,
) )
try: try:

Loading…
Cancel
Save