From 2f70398b3d6087159bfacc37a5118fe3a6277253 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Fri, 30 May 2025 06:16:03 +0800 Subject: [PATCH] refactor: use typed retrieval context Signed-off-by: -LAN- --- .../advanced_chat/generate_task_pipeline.py | 28 +++---- api/core/app/entities/queue_entities.py | 5 +- api/core/app/entities/task_entities.py | 23 +++++- .../easy_ui_based_generate_task_pipeline.py | 20 ++--- .../task_pipeline/message_cycle_manager.py | 15 ++-- .../index_tool_callback_handler.py | 5 +- api/core/rag/entities/citation_metadata.py | 23 ++++++ api/core/rag/retrieval/dataset_retrieval.py | 63 ++++++++-------- .../dataset_multi_retriever_tool.py | 39 +++++----- .../dataset_retriever_tool.py | 73 ++++++++++--------- .../workflow/graph_engine/entities/event.py | 5 +- api/core/workflow/nodes/event/event.py | 4 +- api/core/workflow/nodes/llm/node.py | 41 ++++++----- 13 files changed, 191 insertions(+), 153 deletions(-) create mode 100644 api/core/rag/entities/citation_metadata.py diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index 237089df45..8722562280 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -1,4 +1,3 @@ -import json import logging import time from collections.abc import Generator, Mapping @@ -60,7 +59,6 @@ from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTas from core.app.task_pipeline.message_cycle_manager import MessageCycleManager from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk from core.model_runtime.entities.llm_entities import LLMUsage -from core.model_runtime.utils.encoders import jsonable_encoder from core.ops.ops_trace_manager import TraceQueueManager from core.workflow.entities.workflow_execution import WorkflowExecutionStatus, WorkflowType from core.workflow.enums import SystemVariableKey @@ -609,18 +607,14 @@ class AdvancedChatAppGenerateTaskPipeline: with Session(db.engine, expire_on_commit=False) as session: message = self._get_message(session=session) - message.message_metadata = ( - json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None - ) + message.message_metadata = self._task_state.metadata.model_dump_json() session.commit() elif isinstance(event, QueueAnnotationReplyEvent): self._message_cycle_manager.handle_annotation_reply(event) with Session(db.engine, expire_on_commit=False) as session: message = self._get_message(session=session) - message.message_metadata = ( - json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None - ) + message.message_metadata = self._task_state.metadata.model_dump_json() session.commit() elif isinstance(event, QueueTextChunkEvent): delta_text = event.text @@ -683,9 +677,7 @@ class AdvancedChatAppGenerateTaskPipeline: message = self._get_message(session=session) message.answer = self._task_state.answer message.provider_response_latency = time.perf_counter() - self._base_task_pipeline._start_at - message.message_metadata = ( - json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None - ) + message.message_metadata = self._task_state.metadata.model_dump_json() message_files = [ MessageFile( message_id=message.id, @@ -713,9 +705,9 @@ class AdvancedChatAppGenerateTaskPipeline: message.answer_price_unit = usage.completion_price_unit message.total_price = usage.total_price message.currency = usage.currency - self._task_state.metadata["usage"] = jsonable_encoder(usage) + self._task_state.metadata.usage = usage else: - self._task_state.metadata["usage"] = jsonable_encoder(LLMUsage.empty_usage()) + self._task_state.metadata.usage = LLMUsage.empty_usage() message_was_created.send( message, application_generate_entity=self._application_generate_entity, @@ -726,18 +718,16 @@ class AdvancedChatAppGenerateTaskPipeline: Message end to stream response. :return: """ - extras = {} - if self._task_state.metadata: - extras["metadata"] = self._task_state.metadata.copy() + extras = self._task_state.metadata.model_dump() - if "annotation_reply" in extras["metadata"]: - del extras["metadata"]["annotation_reply"] + if self._task_state.metadata.annotation_reply: + del extras["annotation_reply"] return MessageEndStreamResponse( task_id=self._application_generate_entity.task_id, id=self._message_id, files=self._recorded_files, - metadata=extras.get("metadata", {}), + metadata=extras, ) def _handle_output_moderation_chunk(self, text: str) -> bool: diff --git a/api/core/app/entities/queue_entities.py b/api/core/app/entities/queue_entities.py index bc8573013a..42e6a1519c 100644 --- a/api/core/app/entities/queue_entities.py +++ b/api/core/app/entities/queue_entities.py @@ -1,4 +1,4 @@ -from collections.abc import Mapping +from collections.abc import Mapping, Sequence from datetime import datetime from enum import Enum, StrEnum from typing import Any, Optional @@ -6,6 +6,7 @@ from typing import Any, Optional from pydantic import BaseModel from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk +from core.rag.entities.citation_metadata import RetrievalSourceMetadata from core.workflow.entities.node_entities import AgentNodeStrategyInit from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState @@ -283,7 +284,7 @@ class QueueRetrieverResourcesEvent(AppQueueEvent): """ event: QueueEvent = QueueEvent.RETRIEVER_RESOURCES - retriever_resources: list[dict] + retriever_resources: Sequence[RetrievalSourceMetadata] in_iteration_id: Optional[str] = None """iteration id if node is in iteration""" in_loop_id: Optional[str] = None diff --git a/api/core/app/entities/task_entities.py b/api/core/app/entities/task_entities.py index 70748b085d..25c889e922 100644 --- a/api/core/app/entities/task_entities.py +++ b/api/core/app/entities/task_entities.py @@ -2,20 +2,37 @@ from collections.abc import Mapping, Sequence from enum import Enum from typing import Any, Optional -from pydantic import BaseModel, ConfigDict +from pydantic import BaseModel, ConfigDict, Field -from core.model_runtime.entities.llm_entities import LLMResult +from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage from core.model_runtime.utils.encoders import jsonable_encoder +from core.rag.entities.citation_metadata import RetrievalSourceMetadata from core.workflow.entities.node_entities import AgentNodeStrategyInit from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus +class AnnotationReplyAccount(BaseModel): + id: str + name: str + + +class AnnotationReply(BaseModel): + id: str + account: AnnotationReplyAccount + + +class TaskStateMetadata(BaseModel): + annotation_reply: AnnotationReply | None = None + retriever_resources: Sequence[RetrievalSourceMetadata] = Field(default_factory=list) + usage: LLMUsage | None = None + + class TaskState(BaseModel): """ TaskState entity """ - metadata: dict = {} + metadata: TaskStateMetadata = Field(default_factory=TaskStateMetadata) class EasyUITaskState(TaskState): diff --git a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py index 6156b2973e..1ea50a5778 100644 --- a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py +++ b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py @@ -1,4 +1,3 @@ -import json import logging import time from collections.abc import Generator @@ -51,7 +50,6 @@ from core.model_runtime.entities.message_entities import ( AssistantPromptMessage, ) from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from core.model_runtime.utils.encoders import jsonable_encoder from core.ops.entities.trace_entity import TraceTaskName from core.ops.ops_trace_manager import TraceQueueManager, TraceTask from core.prompt.utils.prompt_message_util import PromptMessageUtil @@ -141,9 +139,9 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline): if isinstance(stream_response, ErrorStreamResponse): raise stream_response.err elif isinstance(stream_response, MessageEndStreamResponse): - extras = {"usage": jsonable_encoder(self._task_state.llm_result.usage)} + extras = {"usage": self._task_state.llm_result.usage.model_dump()} if self._task_state.metadata: - extras["metadata"] = self._task_state.metadata + extras["metadata"] = self._task_state.metadata.model_dump() response: Union[ChatbotAppBlockingResponse, CompletionAppBlockingResponse] if self._conversation_mode == AppMode.COMPLETION.value: response = CompletionAppBlockingResponse( @@ -379,9 +377,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline): message.provider_response_latency = time.perf_counter() - self._start_at message.total_price = usage.total_price message.currency = usage.currency - message.message_metadata = ( - json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None - ) + message.message_metadata = self._task_state.metadata.model_dump_json() if trace_manager: trace_manager.add_trace_task( @@ -430,16 +426,12 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline): Message end to stream response. :return: """ - self._task_state.metadata["usage"] = jsonable_encoder(self._task_state.llm_result.usage) - - extras = {} - if self._task_state.metadata: - extras["metadata"] = self._task_state.metadata - + self._task_state.metadata.usage = self._task_state.llm_result.usage + metadata_dict = self._task_state.metadata.model_dump() return MessageEndStreamResponse( task_id=self._application_generate_entity.task_id, id=self._message_id, - metadata=extras.get("metadata", {}), + metadata=metadata_dict, ) def _agent_message_to_stream_response(self, answer: str, message_id: str) -> AgentMessageStreamResponse: diff --git a/api/core/app/task_pipeline/message_cycle_manager.py b/api/core/app/task_pipeline/message_cycle_manager.py index 8d762a6655..2343081eaf 100644 --- a/api/core/app/task_pipeline/message_cycle_manager.py +++ b/api/core/app/task_pipeline/message_cycle_manager.py @@ -17,6 +17,8 @@ from core.app.entities.queue_entities import ( QueueRetrieverResourcesEvent, ) from core.app.entities.task_entities import ( + AnnotationReply, + AnnotationReplyAccount, EasyUITaskState, MessageFileStreamResponse, MessageReplaceStreamResponse, @@ -111,10 +113,13 @@ class MessageCycleManager: annotation = AppAnnotationService.get_annotation_by_id(event.message_annotation_id) if annotation: account = annotation.account - self._task_state.metadata["annotation_reply"] = { - "id": annotation.id, - "account": {"id": annotation.account_id, "name": account.name if account else "Dify user"}, - } + self._task_state.metadata.annotation_reply = AnnotationReply( + id=annotation.id, + account=AnnotationReplyAccount( + id=annotation.account_id, + name=account.name if account else "Dify user", + ), + ) return annotation @@ -127,7 +132,7 @@ class MessageCycleManager: :return: """ if self._application_generate_entity.app_config.additional_features.show_retrieve_source: - self._task_state.metadata["retriever_resources"] = event.retriever_resources + self._task_state.metadata.retriever_resources = event.retriever_resources def message_file_to_stream_response(self, event: QueueMessageFileEvent) -> Optional[MessageFileStreamResponse]: """ diff --git a/api/core/callback_handler/index_tool_callback_handler.py b/api/core/callback_handler/index_tool_callback_handler.py index 13c22213c4..a3a7b4b812 100644 --- a/api/core/callback_handler/index_tool_callback_handler.py +++ b/api/core/callback_handler/index_tool_callback_handler.py @@ -1,8 +1,10 @@ import logging +from collections.abc import Sequence from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.queue_entities import QueueRetrieverResourcesEvent +from core.rag.entities.citation_metadata import RetrievalSourceMetadata from core.rag.index_processor.constant.index_type import IndexType from core.rag.models.document import Document from extensions.ext_database import db @@ -85,7 +87,8 @@ class DatasetIndexToolCallbackHandler: db.session.commit() - def return_retriever_resource_info(self, resource: list): + # TODO(-LAN-): Improve type check + def return_retriever_resource_info(self, resource: Sequence[RetrievalSourceMetadata]): """Handle return_retriever_resource_info.""" self._queue_manager.publish( QueueRetrieverResourcesEvent(retriever_resources=resource), PublishFrom.APPLICATION_MANAGER diff --git a/api/core/rag/entities/citation_metadata.py b/api/core/rag/entities/citation_metadata.py new file mode 100644 index 0000000000..00120425c9 --- /dev/null +++ b/api/core/rag/entities/citation_metadata.py @@ -0,0 +1,23 @@ +from typing import Any, Optional + +from pydantic import BaseModel + + +class RetrievalSourceMetadata(BaseModel): + position: Optional[int] = None + dataset_id: Optional[str] = None + dataset_name: Optional[str] = None + document_id: Optional[str] = None + document_name: Optional[str] = None + data_source_type: Optional[str] = None + segment_id: Optional[str] = None + retriever_from: Optional[str] = None + score: Optional[float] = None + hit_count: Optional[int] = None + word_count: Optional[int] = None + segment_position: Optional[int] = None + index_node_hash: Optional[str] = None + content: Optional[str] = None + page: Optional[int] = None + doc_metadata: Optional[dict[str, Any]] = None + title: Optional[str] = None diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index c4adf6de4d..f668148428 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -35,6 +35,7 @@ from core.prompt.simple_prompt_transform import ModelMode from core.rag.data_post_processor.data_post_processor import DataPostProcessor from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler from core.rag.datasource.retrieval_service import RetrievalService +from core.rag.entities.citation_metadata import RetrievalSourceMetadata from core.rag.entities.context_entities import DocumentContext from core.rag.entities.metadata_entities import Condition, MetadataCondition from core.rag.index_processor.constant.index_type import IndexType @@ -198,21 +199,21 @@ class DatasetRetrieval: dify_documents = [item for item in all_documents if item.provider == "dify"] external_documents = [item for item in all_documents if item.provider == "external"] - document_context_list = [] - retrieval_resource_list = [] + document_context_list: list[DocumentContext] = [] + retrieval_resource_list: list[RetrievalSourceMetadata] = [] # deal with external documents for item in external_documents: document_context_list.append(DocumentContext(content=item.page_content, score=item.metadata.get("score"))) - source = { - "dataset_id": item.metadata.get("dataset_id"), - "dataset_name": item.metadata.get("dataset_name"), - "document_id": item.metadata.get("document_id") or item.metadata.get("title"), - "document_name": item.metadata.get("title"), - "data_source_type": "external", - "retriever_from": invoke_from.to_source(), - "score": item.metadata.get("score"), - "content": item.page_content, - } + source = RetrievalSourceMetadata( + dataset_id=item.metadata.get("dataset_id"), + dataset_name=item.metadata.get("dataset_name"), + document_id=item.metadata.get("document_id") or item.metadata.get("title"), + document_name=item.metadata.get("title"), + data_source_type="external", + retriever_from=invoke_from.to_source(), + score=item.metadata.get("score"), + content=item.page_content, + ) retrieval_resource_list.append(source) # deal with dify documents if dify_documents: @@ -248,32 +249,32 @@ class DatasetRetrieval: .first() ) if dataset and document: - source = { - "dataset_id": dataset.id, - "dataset_name": dataset.name, - "document_id": document.id, - "document_name": document.name, - "data_source_type": document.data_source_type, - "segment_id": segment.id, - "retriever_from": invoke_from.to_source(), - "score": record.score or 0.0, - "doc_metadata": document.doc_metadata, - } + source = RetrievalSourceMetadata( + dataset_id=dataset.id, + dataset_name=dataset.name, + document_id=document.id, + document_name=document.name, + data_source_type=document.data_source_type, + segment_id=segment.id, + retriever_from=invoke_from.to_source(), + score=record.score or 0.0, + doc_metadata=document.doc_metadata, + ) if invoke_from.to_source() == "dev": - source["hit_count"] = segment.hit_count - source["word_count"] = segment.word_count - source["segment_position"] = segment.position - source["index_node_hash"] = segment.index_node_hash + source.hit_count = segment.hit_count + source.word_count = segment.word_count + source.segment_position = segment.position + source.index_node_hash = segment.index_node_hash if segment.answer: - source["content"] = f"question:{segment.content} \nanswer:{segment.answer}" + source.content = f"question:{segment.content} \nanswer:{segment.answer}" else: - source["content"] = segment.content + source.content = segment.content retrieval_resource_list.append(source) if hit_callback and retrieval_resource_list: - retrieval_resource_list = sorted(retrieval_resource_list, key=lambda x: x.get("score") or 0.0, reverse=True) + retrieval_resource_list = sorted(retrieval_resource_list, key=lambda x: x.score or 0.0, reverse=True) for position, item in enumerate(retrieval_resource_list, start=1): - item["position"] = position + item.position = position hit_callback.return_retriever_resource_info(retrieval_resource_list) if document_context_list: document_context_list = sorted(document_context_list, key=lambda x: x.score or 0.0, reverse=True) diff --git a/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py b/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py index 04437ea6d8..93d3fcc49d 100644 --- a/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py +++ b/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py @@ -8,6 +8,7 @@ from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCa from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelType from core.rag.datasource.retrieval_service import RetrievalService +from core.rag.entities.citation_metadata import RetrievalSourceMetadata from core.rag.models.document import Document as RagDocument from core.rag.rerank.rerank_model import RerankModelRunner from core.rag.retrieval.retrieval_methods import RetrievalMethod @@ -107,7 +108,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool): else: document_context_list.append(segment.get_sign_content()) if self.return_resource: - context_list = [] + context_list: list[RetrievalSourceMetadata] = [] resource_number = 1 for segment in sorted_segments: dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first() @@ -121,28 +122,28 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool): .first() ) if dataset and document: - source = { - "position": resource_number, - "dataset_id": dataset.id, - "dataset_name": dataset.name, - "document_id": document.id, - "document_name": document.name, - "data_source_type": document.data_source_type, - "segment_id": segment.id, - "retriever_from": self.retriever_from, - "score": document_score_list.get(segment.index_node_id, None), - "doc_metadata": document.doc_metadata, - } + source = RetrievalSourceMetadata( + position=resource_number, + dataset_id=dataset.id, + dataset_name=dataset.name, + document_id=document.id, + document_name=document.name, + data_source_type=document.data_source_type, + segment_id=segment.id, + retriever_from=self.retriever_from, + score=document_score_list.get(segment.index_node_id, None), + doc_metadata=document.doc_metadata, + ) if self.retriever_from == "dev": - source["hit_count"] = segment.hit_count - source["word_count"] = segment.word_count - source["segment_position"] = segment.position - source["index_node_hash"] = segment.index_node_hash + source.hit_count = segment.hit_count + source.word_count = segment.word_count + source.segment_position = segment.position + source.index_node_hash = segment.index_node_hash if segment.answer: - source["content"] = f"question:{segment.content} \nanswer:{segment.answer}" + source.content = f"question:{segment.content} \nanswer:{segment.answer}" else: - source["content"] = segment.content + source.content = segment.content context_list.append(source) resource_number += 1 diff --git a/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py b/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py index 7b6882ed52..ff1d9021ce 100644 --- a/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py +++ b/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py @@ -4,6 +4,7 @@ from pydantic import BaseModel, Field from core.app.app_config.entities import DatasetRetrieveConfigEntity, ModelConfig from core.rag.datasource.retrieval_service import RetrievalService +from core.rag.entities.citation_metadata import RetrievalSourceMetadata from core.rag.entities.context_entities import DocumentContext from core.rag.models.document import Document as RetrievalDocument from core.rag.retrieval.dataset_retrieval import DatasetRetrieval @@ -14,7 +15,7 @@ from models.dataset import Dataset from models.dataset import Document as DatasetDocument from services.external_knowledge_service import ExternalDatasetService -default_retrieval_model = { +default_retrieval_model: dict[str, Any] = { "search_method": RetrievalMethod.SEMANTIC_SEARCH.value, "reranking_enable": False, "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, @@ -79,7 +80,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool): else: document_ids_filter = None if dataset.provider == "external": - results = [] + results: list[RetrievalDocument] = [] external_documents = ExternalDatasetService.fetch_external_knowledge_retrieval( tenant_id=dataset.tenant_id, dataset_id=dataset.id, @@ -100,21 +101,21 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool): document.metadata["dataset_name"] = dataset.name results.append(document) # deal with external documents - context_list = [] + context_list: list[RetrievalSourceMetadata] = [] for position, item in enumerate(results, start=1): if item.metadata is not None: - source = { - "position": position, - "dataset_id": item.metadata.get("dataset_id"), - "dataset_name": item.metadata.get("dataset_name"), - "document_id": item.metadata.get("document_id") or item.metadata.get("title"), - "document_name": item.metadata.get("title"), - "data_source_type": "external", - "retriever_from": self.retriever_from, - "score": item.metadata.get("score"), - "title": item.metadata.get("title"), - "content": item.page_content, - } + source = RetrievalSourceMetadata( + position=position, + dataset_id=item.metadata.get("dataset_id"), + dataset_name=item.metadata.get("dataset_name"), + document_id=item.metadata.get("document_id") or item.metadata.get("title"), + document_name=item.metadata.get("title"), + data_source_type="external", + retriever_from=self.retriever_from, + score=item.metadata.get("score"), + title=item.metadata.get("title"), + content=item.page_content, + ) context_list.append(source) for hit_callback in self.hit_callbacks: hit_callback.return_retriever_resource_info(context_list) @@ -125,7 +126,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool): return "" # get retrieval model , if the model is not setting , using default retrieval_model: dict[str, Any] = dataset.retrieval_model or default_retrieval_model - retrieval_resource_list = [] + retrieval_resource_list: list[RetrievalSourceMetadata] = [] if dataset.indexing_technique == "economy": # use keyword table query documents = RetrievalService.retrieve( @@ -163,7 +164,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool): for item in documents: if item.metadata is not None and item.metadata.get("score"): document_score_list[item.metadata["doc_id"]] = item.metadata["score"] - document_context_list = [] + document_context_list: list[DocumentContext] = [] records = RetrievalService.format_retrieval_documents(documents) if records: for record in records: @@ -197,37 +198,37 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool): .first() ) if dataset and document: - source = { - "dataset_id": dataset.id, - "dataset_name": dataset.name, - "document_id": document.id, # type: ignore - "document_name": document.name, # type: ignore - "data_source_type": document.data_source_type, # type: ignore - "segment_id": segment.id, - "retriever_from": self.retriever_from, - "score": record.score or 0.0, - "doc_metadata": document.doc_metadata, # type: ignore - } + source = RetrievalSourceMetadata( + dataset_id=dataset.id, + dataset_name=dataset.name, + document_id=document.id, # type: ignore + document_name=document.name, # type: ignore + data_source_type=document.data_source_type, # type: ignore + segment_id=segment.id, + retriever_from=self.retriever_from, + score=record.score or 0.0, + doc_metadata=document.doc_metadata, # type: ignore + ) if self.retriever_from == "dev": - source["hit_count"] = segment.hit_count - source["word_count"] = segment.word_count - source["segment_position"] = segment.position - source["index_node_hash"] = segment.index_node_hash + source.hit_count = segment.hit_count + source.word_count = segment.word_count + source.segment_position = segment.position + source.index_node_hash = segment.index_node_hash if segment.answer: - source["content"] = f"question:{segment.content} \nanswer:{segment.answer}" + source.content = f"question:{segment.content} \nanswer:{segment.answer}" else: - source["content"] = segment.content + source.content = segment.content retrieval_resource_list.append(source) if self.return_resource and retrieval_resource_list: retrieval_resource_list = sorted( retrieval_resource_list, - key=lambda x: x.get("score") or 0.0, + key=lambda x: x.score or 0.0, reverse=True, ) for position, item in enumerate(retrieval_resource_list, start=1): # type: ignore - item["position"] = position # type: ignore + item.position = position # type: ignore for hit_callback in self.hit_callbacks: hit_callback.return_retriever_resource_info(retrieval_resource_list) if document_context_list: diff --git a/api/core/workflow/graph_engine/entities/event.py b/api/core/workflow/graph_engine/entities/event.py index 689a07c4f6..9a4939502e 100644 --- a/api/core/workflow/graph_engine/entities/event.py +++ b/api/core/workflow/graph_engine/entities/event.py @@ -1,9 +1,10 @@ -from collections.abc import Mapping +from collections.abc import Mapping, Sequence from datetime import datetime from typing import Any, Optional from pydantic import BaseModel, Field +from core.rag.entities.citation_metadata import RetrievalSourceMetadata from core.workflow.entities.node_entities import AgentNodeStrategyInit from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState from core.workflow.nodes import NodeType @@ -82,7 +83,7 @@ class NodeRunStreamChunkEvent(BaseNodeEvent): class NodeRunRetrieverResourceEvent(BaseNodeEvent): - retriever_resources: list[dict] = Field(..., description="retriever resources") + retriever_resources: Sequence[RetrievalSourceMetadata] = Field(..., description="retriever resources") context: str = Field(..., description="context") diff --git a/api/core/workflow/nodes/event/event.py b/api/core/workflow/nodes/event/event.py index f45919caf5..b72d111f49 100644 --- a/api/core/workflow/nodes/event/event.py +++ b/api/core/workflow/nodes/event/event.py @@ -1,8 +1,10 @@ +from collections.abc import Sequence from datetime import datetime from pydantic import BaseModel, Field from core.model_runtime.entities.llm_entities import LLMUsage +from core.rag.entities.citation_metadata import RetrievalSourceMetadata from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus @@ -17,7 +19,7 @@ class RunStreamChunkEvent(BaseModel): class RunRetrieverResourceEvent(BaseModel): - retriever_resources: list[dict] = Field(..., description="retriever resources") + retriever_resources: Sequence[RetrievalSourceMetadata] = Field(..., description="retriever resources") context: str = Field(..., description="context") diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index 38e4f7af01..0fd7c31ffb 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -43,6 +43,7 @@ from core.model_runtime.utils.encoders import jsonable_encoder from core.plugin.entities.plugin import ModelProviderID from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig from core.prompt.utils.prompt_message_util import PromptMessageUtil +from core.rag.entities.citation_metadata import RetrievalSourceMetadata from core.variables import ( ArrayAnySegment, ArrayFileSegment, @@ -474,7 +475,7 @@ class LLMNode(BaseNode[LLMNodeData]): yield RunRetrieverResourceEvent(retriever_resources=[], context=context_value_variable.value) elif isinstance(context_value_variable, ArraySegment): context_str = "" - original_retriever_resource = [] + original_retriever_resource: list[RetrievalSourceMetadata] = [] for item in context_value_variable.value: if isinstance(item, str): context_str += item + "\n" @@ -492,7 +493,7 @@ class LLMNode(BaseNode[LLMNodeData]): retriever_resources=original_retriever_resource, context=context_str.strip() ) - def _convert_to_original_retriever_resource(self, context_dict: dict) -> Optional[dict]: + def _convert_to_original_retriever_resource(self, context_dict: dict): if ( "metadata" in context_dict and "_source" in context_dict["metadata"] @@ -500,24 +501,24 @@ class LLMNode(BaseNode[LLMNodeData]): ): metadata = context_dict.get("metadata", {}) - source = { - "position": metadata.get("position"), - "dataset_id": metadata.get("dataset_id"), - "dataset_name": metadata.get("dataset_name"), - "document_id": metadata.get("document_id"), - "document_name": metadata.get("document_name"), - "data_source_type": metadata.get("data_source_type"), - "segment_id": metadata.get("segment_id"), - "retriever_from": metadata.get("retriever_from"), - "score": metadata.get("score"), - "hit_count": metadata.get("segment_hit_count"), - "word_count": metadata.get("segment_word_count"), - "segment_position": metadata.get("segment_position"), - "index_node_hash": metadata.get("segment_index_node_hash"), - "content": context_dict.get("content"), - "page": metadata.get("page"), - "doc_metadata": metadata.get("doc_metadata"), - } + source = RetrievalSourceMetadata( + position=metadata.get("position"), + dataset_id=metadata.get("dataset_id"), + dataset_name=metadata.get("dataset_name"), + document_id=metadata.get("document_id"), + document_name=metadata.get("document_name"), + data_source_type=metadata.get("data_source_type"), + segment_id=metadata.get("segment_id"), + retriever_from=metadata.get("retriever_from"), + score=metadata.get("score"), + hit_count=metadata.get("segment_hit_count"), + word_count=metadata.get("segment_word_count"), + segment_position=metadata.get("segment_position"), + index_node_hash=metadata.get("segment_index_node_hash"), + content=context_dict.get("content"), + page=metadata.get("page"), + doc_metadata=metadata.get("doc_metadata"), + ) return source