refactor: use typed retrieval context

Signed-off-by: -LAN- <laipz8200@outlook.com>
pull/20460/head
-LAN- 12 months ago
parent d595f74a3d
commit 2f70398b3d
No known key found for this signature in database
GPG Key ID: 6BA0D108DED011FF

@ -1,4 +1,3 @@
import json
import logging import logging
import time import time
from collections.abc import Generator, Mapping from collections.abc import Generator, Mapping
@ -60,7 +59,6 @@ from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTas
from core.app.task_pipeline.message_cycle_manager import MessageCycleManager from core.app.task_pipeline.message_cycle_manager import MessageCycleManager
from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
from core.model_runtime.entities.llm_entities import LLMUsage from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.utils.encoders import jsonable_encoder
from core.ops.ops_trace_manager import TraceQueueManager from core.ops.ops_trace_manager import TraceQueueManager
from core.workflow.entities.workflow_execution import WorkflowExecutionStatus, WorkflowType from core.workflow.entities.workflow_execution import WorkflowExecutionStatus, WorkflowType
from core.workflow.enums import SystemVariableKey from core.workflow.enums import SystemVariableKey
@ -609,18 +607,14 @@ class AdvancedChatAppGenerateTaskPipeline:
with Session(db.engine, expire_on_commit=False) as session: with Session(db.engine, expire_on_commit=False) as session:
message = self._get_message(session=session) message = self._get_message(session=session)
message.message_metadata = ( message.message_metadata = self._task_state.metadata.model_dump_json()
json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
)
session.commit() session.commit()
elif isinstance(event, QueueAnnotationReplyEvent): elif isinstance(event, QueueAnnotationReplyEvent):
self._message_cycle_manager.handle_annotation_reply(event) self._message_cycle_manager.handle_annotation_reply(event)
with Session(db.engine, expire_on_commit=False) as session: with Session(db.engine, expire_on_commit=False) as session:
message = self._get_message(session=session) message = self._get_message(session=session)
message.message_metadata = ( message.message_metadata = self._task_state.metadata.model_dump_json()
json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
)
session.commit() session.commit()
elif isinstance(event, QueueTextChunkEvent): elif isinstance(event, QueueTextChunkEvent):
delta_text = event.text delta_text = event.text
@ -683,9 +677,7 @@ class AdvancedChatAppGenerateTaskPipeline:
message = self._get_message(session=session) message = self._get_message(session=session)
message.answer = self._task_state.answer message.answer = self._task_state.answer
message.provider_response_latency = time.perf_counter() - self._base_task_pipeline._start_at message.provider_response_latency = time.perf_counter() - self._base_task_pipeline._start_at
message.message_metadata = ( message.message_metadata = self._task_state.metadata.model_dump_json()
json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
)
message_files = [ message_files = [
MessageFile( MessageFile(
message_id=message.id, message_id=message.id,
@ -713,9 +705,9 @@ class AdvancedChatAppGenerateTaskPipeline:
message.answer_price_unit = usage.completion_price_unit message.answer_price_unit = usage.completion_price_unit
message.total_price = usage.total_price message.total_price = usage.total_price
message.currency = usage.currency message.currency = usage.currency
self._task_state.metadata["usage"] = jsonable_encoder(usage) self._task_state.metadata.usage = usage
else: else:
self._task_state.metadata["usage"] = jsonable_encoder(LLMUsage.empty_usage()) self._task_state.metadata.usage = LLMUsage.empty_usage()
message_was_created.send( message_was_created.send(
message, message,
application_generate_entity=self._application_generate_entity, application_generate_entity=self._application_generate_entity,
@ -726,18 +718,16 @@ class AdvancedChatAppGenerateTaskPipeline:
Message end to stream response. Message end to stream response.
:return: :return:
""" """
extras = {} extras = self._task_state.metadata.model_dump()
if self._task_state.metadata:
extras["metadata"] = self._task_state.metadata.copy()
if "annotation_reply" in extras["metadata"]: if self._task_state.metadata.annotation_reply:
del extras["metadata"]["annotation_reply"] del extras["annotation_reply"]
return MessageEndStreamResponse( return MessageEndStreamResponse(
task_id=self._application_generate_entity.task_id, task_id=self._application_generate_entity.task_id,
id=self._message_id, id=self._message_id,
files=self._recorded_files, files=self._recorded_files,
metadata=extras.get("metadata", {}), metadata=extras,
) )
def _handle_output_moderation_chunk(self, text: str) -> bool: def _handle_output_moderation_chunk(self, text: str) -> bool:

@ -1,4 +1,4 @@
from collections.abc import Mapping from collections.abc import Mapping, Sequence
from datetime import datetime from datetime import datetime
from enum import Enum, StrEnum from enum import Enum, StrEnum
from typing import Any, Optional from typing import Any, Optional
@ -6,6 +6,7 @@ from typing import Any, Optional
from pydantic import BaseModel from pydantic import BaseModel
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.workflow.entities.node_entities import AgentNodeStrategyInit from core.workflow.entities.node_entities import AgentNodeStrategyInit
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
@ -283,7 +284,7 @@ class QueueRetrieverResourcesEvent(AppQueueEvent):
""" """
event: QueueEvent = QueueEvent.RETRIEVER_RESOURCES event: QueueEvent = QueueEvent.RETRIEVER_RESOURCES
retriever_resources: list[dict] retriever_resources: Sequence[RetrievalSourceMetadata]
in_iteration_id: Optional[str] = None in_iteration_id: Optional[str] = None
"""iteration id if node is in iteration""" """iteration id if node is in iteration"""
in_loop_id: Optional[str] = None in_loop_id: Optional[str] = None

@ -2,20 +2,37 @@ from collections.abc import Mapping, Sequence
from enum import Enum from enum import Enum
from typing import Any, Optional from typing import Any, Optional
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict, Field
from core.model_runtime.entities.llm_entities import LLMResult from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from core.model_runtime.utils.encoders import jsonable_encoder from core.model_runtime.utils.encoders import jsonable_encoder
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.workflow.entities.node_entities import AgentNodeStrategyInit from core.workflow.entities.node_entities import AgentNodeStrategyInit
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
class AnnotationReplyAccount(BaseModel):
id: str
name: str
class AnnotationReply(BaseModel):
id: str
account: AnnotationReplyAccount
class TaskStateMetadata(BaseModel):
annotation_reply: AnnotationReply | None = None
retriever_resources: Sequence[RetrievalSourceMetadata] = Field(default_factory=list)
usage: LLMUsage | None = None
class TaskState(BaseModel): class TaskState(BaseModel):
""" """
TaskState entity TaskState entity
""" """
metadata: dict = {} metadata: TaskStateMetadata = Field(default_factory=TaskStateMetadata)
class EasyUITaskState(TaskState): class EasyUITaskState(TaskState):

@ -1,4 +1,3 @@
import json
import logging import logging
import time import time
from collections.abc import Generator from collections.abc import Generator
@ -51,7 +50,6 @@ from core.model_runtime.entities.message_entities import (
AssistantPromptMessage, AssistantPromptMessage,
) )
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.utils.encoders import jsonable_encoder
from core.ops.entities.trace_entity import TraceTaskName from core.ops.entities.trace_entity import TraceTaskName
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
from core.prompt.utils.prompt_message_util import PromptMessageUtil from core.prompt.utils.prompt_message_util import PromptMessageUtil
@ -141,9 +139,9 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
if isinstance(stream_response, ErrorStreamResponse): if isinstance(stream_response, ErrorStreamResponse):
raise stream_response.err raise stream_response.err
elif isinstance(stream_response, MessageEndStreamResponse): elif isinstance(stream_response, MessageEndStreamResponse):
extras = {"usage": jsonable_encoder(self._task_state.llm_result.usage)} extras = {"usage": self._task_state.llm_result.usage.model_dump()}
if self._task_state.metadata: if self._task_state.metadata:
extras["metadata"] = self._task_state.metadata extras["metadata"] = self._task_state.metadata.model_dump()
response: Union[ChatbotAppBlockingResponse, CompletionAppBlockingResponse] response: Union[ChatbotAppBlockingResponse, CompletionAppBlockingResponse]
if self._conversation_mode == AppMode.COMPLETION.value: if self._conversation_mode == AppMode.COMPLETION.value:
response = CompletionAppBlockingResponse( response = CompletionAppBlockingResponse(
@ -379,9 +377,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
message.provider_response_latency = time.perf_counter() - self._start_at message.provider_response_latency = time.perf_counter() - self._start_at
message.total_price = usage.total_price message.total_price = usage.total_price
message.currency = usage.currency message.currency = usage.currency
message.message_metadata = ( message.message_metadata = self._task_state.metadata.model_dump_json()
json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
)
if trace_manager: if trace_manager:
trace_manager.add_trace_task( trace_manager.add_trace_task(
@ -430,16 +426,12 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
Message end to stream response. Message end to stream response.
:return: :return:
""" """
self._task_state.metadata["usage"] = jsonable_encoder(self._task_state.llm_result.usage) self._task_state.metadata.usage = self._task_state.llm_result.usage
metadata_dict = self._task_state.metadata.model_dump()
extras = {}
if self._task_state.metadata:
extras["metadata"] = self._task_state.metadata
return MessageEndStreamResponse( return MessageEndStreamResponse(
task_id=self._application_generate_entity.task_id, task_id=self._application_generate_entity.task_id,
id=self._message_id, id=self._message_id,
metadata=extras.get("metadata", {}), metadata=metadata_dict,
) )
def _agent_message_to_stream_response(self, answer: str, message_id: str) -> AgentMessageStreamResponse: def _agent_message_to_stream_response(self, answer: str, message_id: str) -> AgentMessageStreamResponse:

@ -17,6 +17,8 @@ from core.app.entities.queue_entities import (
QueueRetrieverResourcesEvent, QueueRetrieverResourcesEvent,
) )
from core.app.entities.task_entities import ( from core.app.entities.task_entities import (
AnnotationReply,
AnnotationReplyAccount,
EasyUITaskState, EasyUITaskState,
MessageFileStreamResponse, MessageFileStreamResponse,
MessageReplaceStreamResponse, MessageReplaceStreamResponse,
@ -111,10 +113,13 @@ class MessageCycleManager:
annotation = AppAnnotationService.get_annotation_by_id(event.message_annotation_id) annotation = AppAnnotationService.get_annotation_by_id(event.message_annotation_id)
if annotation: if annotation:
account = annotation.account account = annotation.account
self._task_state.metadata["annotation_reply"] = { self._task_state.metadata.annotation_reply = AnnotationReply(
"id": annotation.id, id=annotation.id,
"account": {"id": annotation.account_id, "name": account.name if account else "Dify user"}, account=AnnotationReplyAccount(
} id=annotation.account_id,
name=account.name if account else "Dify user",
),
)
return annotation return annotation
@ -127,7 +132,7 @@ class MessageCycleManager:
:return: :return:
""" """
if self._application_generate_entity.app_config.additional_features.show_retrieve_source: if self._application_generate_entity.app_config.additional_features.show_retrieve_source:
self._task_state.metadata["retriever_resources"] = event.retriever_resources self._task_state.metadata.retriever_resources = event.retriever_resources
def message_file_to_stream_response(self, event: QueueMessageFileEvent) -> Optional[MessageFileStreamResponse]: def message_file_to_stream_response(self, event: QueueMessageFileEvent) -> Optional[MessageFileStreamResponse]:
""" """

@ -1,8 +1,10 @@
import logging import logging
from collections.abc import Sequence
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.queue_entities import QueueRetrieverResourcesEvent from core.app.entities.queue_entities import QueueRetrieverResourcesEvent
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.rag.index_processor.constant.index_type import IndexType from core.rag.index_processor.constant.index_type import IndexType
from core.rag.models.document import Document from core.rag.models.document import Document
from extensions.ext_database import db from extensions.ext_database import db
@ -85,7 +87,8 @@ class DatasetIndexToolCallbackHandler:
db.session.commit() db.session.commit()
def return_retriever_resource_info(self, resource: list): # TODO(-LAN-): Improve type check
def return_retriever_resource_info(self, resource: Sequence[RetrievalSourceMetadata]):
"""Handle return_retriever_resource_info.""" """Handle return_retriever_resource_info."""
self._queue_manager.publish( self._queue_manager.publish(
QueueRetrieverResourcesEvent(retriever_resources=resource), PublishFrom.APPLICATION_MANAGER QueueRetrieverResourcesEvent(retriever_resources=resource), PublishFrom.APPLICATION_MANAGER

@ -0,0 +1,23 @@
from typing import Any, Optional
from pydantic import BaseModel
class RetrievalSourceMetadata(BaseModel):
position: Optional[int] = None
dataset_id: Optional[str] = None
dataset_name: Optional[str] = None
document_id: Optional[str] = None
document_name: Optional[str] = None
data_source_type: Optional[str] = None
segment_id: Optional[str] = None
retriever_from: Optional[str] = None
score: Optional[float] = None
hit_count: Optional[int] = None
word_count: Optional[int] = None
segment_position: Optional[int] = None
index_node_hash: Optional[str] = None
content: Optional[str] = None
page: Optional[int] = None
doc_metadata: Optional[dict[str, Any]] = None
title: Optional[str] = None

@ -35,6 +35,7 @@ from core.prompt.simple_prompt_transform import ModelMode
from core.rag.data_post_processor.data_post_processor import DataPostProcessor from core.rag.data_post_processor.data_post_processor import DataPostProcessor
from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler
from core.rag.datasource.retrieval_service import RetrievalService from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.rag.entities.context_entities import DocumentContext from core.rag.entities.context_entities import DocumentContext
from core.rag.entities.metadata_entities import Condition, MetadataCondition from core.rag.entities.metadata_entities import Condition, MetadataCondition
from core.rag.index_processor.constant.index_type import IndexType from core.rag.index_processor.constant.index_type import IndexType
@ -198,21 +199,21 @@ class DatasetRetrieval:
dify_documents = [item for item in all_documents if item.provider == "dify"] dify_documents = [item for item in all_documents if item.provider == "dify"]
external_documents = [item for item in all_documents if item.provider == "external"] external_documents = [item for item in all_documents if item.provider == "external"]
document_context_list = [] document_context_list: list[DocumentContext] = []
retrieval_resource_list = [] retrieval_resource_list: list[RetrievalSourceMetadata] = []
# deal with external documents # deal with external documents
for item in external_documents: for item in external_documents:
document_context_list.append(DocumentContext(content=item.page_content, score=item.metadata.get("score"))) document_context_list.append(DocumentContext(content=item.page_content, score=item.metadata.get("score")))
source = { source = RetrievalSourceMetadata(
"dataset_id": item.metadata.get("dataset_id"), dataset_id=item.metadata.get("dataset_id"),
"dataset_name": item.metadata.get("dataset_name"), dataset_name=item.metadata.get("dataset_name"),
"document_id": item.metadata.get("document_id") or item.metadata.get("title"), document_id=item.metadata.get("document_id") or item.metadata.get("title"),
"document_name": item.metadata.get("title"), document_name=item.metadata.get("title"),
"data_source_type": "external", data_source_type="external",
"retriever_from": invoke_from.to_source(), retriever_from=invoke_from.to_source(),
"score": item.metadata.get("score"), score=item.metadata.get("score"),
"content": item.page_content, content=item.page_content,
} )
retrieval_resource_list.append(source) retrieval_resource_list.append(source)
# deal with dify documents # deal with dify documents
if dify_documents: if dify_documents:
@ -248,32 +249,32 @@ class DatasetRetrieval:
.first() .first()
) )
if dataset and document: if dataset and document:
source = { source = RetrievalSourceMetadata(
"dataset_id": dataset.id, dataset_id=dataset.id,
"dataset_name": dataset.name, dataset_name=dataset.name,
"document_id": document.id, document_id=document.id,
"document_name": document.name, document_name=document.name,
"data_source_type": document.data_source_type, data_source_type=document.data_source_type,
"segment_id": segment.id, segment_id=segment.id,
"retriever_from": invoke_from.to_source(), retriever_from=invoke_from.to_source(),
"score": record.score or 0.0, score=record.score or 0.0,
"doc_metadata": document.doc_metadata, doc_metadata=document.doc_metadata,
} )
if invoke_from.to_source() == "dev": if invoke_from.to_source() == "dev":
source["hit_count"] = segment.hit_count source.hit_count = segment.hit_count
source["word_count"] = segment.word_count source.word_count = segment.word_count
source["segment_position"] = segment.position source.segment_position = segment.position
source["index_node_hash"] = segment.index_node_hash source.index_node_hash = segment.index_node_hash
if segment.answer: if segment.answer:
source["content"] = f"question:{segment.content} \nanswer:{segment.answer}" source.content = f"question:{segment.content} \nanswer:{segment.answer}"
else: else:
source["content"] = segment.content source.content = segment.content
retrieval_resource_list.append(source) retrieval_resource_list.append(source)
if hit_callback and retrieval_resource_list: if hit_callback and retrieval_resource_list:
retrieval_resource_list = sorted(retrieval_resource_list, key=lambda x: x.get("score") or 0.0, reverse=True) retrieval_resource_list = sorted(retrieval_resource_list, key=lambda x: x.score or 0.0, reverse=True)
for position, item in enumerate(retrieval_resource_list, start=1): for position, item in enumerate(retrieval_resource_list, start=1):
item["position"] = position item.position = position
hit_callback.return_retriever_resource_info(retrieval_resource_list) hit_callback.return_retriever_resource_info(retrieval_resource_list)
if document_context_list: if document_context_list:
document_context_list = sorted(document_context_list, key=lambda x: x.score or 0.0, reverse=True) document_context_list = sorted(document_context_list, key=lambda x: x.score or 0.0, reverse=True)

@ -8,6 +8,7 @@ from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCa
from core.model_manager import ModelManager from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.entities.model_entities import ModelType
from core.rag.datasource.retrieval_service import RetrievalService from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.rag.models.document import Document as RagDocument from core.rag.models.document import Document as RagDocument
from core.rag.rerank.rerank_model import RerankModelRunner from core.rag.rerank.rerank_model import RerankModelRunner
from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.rag.retrieval.retrieval_methods import RetrievalMethod
@ -107,7 +108,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
else: else:
document_context_list.append(segment.get_sign_content()) document_context_list.append(segment.get_sign_content())
if self.return_resource: if self.return_resource:
context_list = [] context_list: list[RetrievalSourceMetadata] = []
resource_number = 1 resource_number = 1
for segment in sorted_segments: for segment in sorted_segments:
dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first() dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first()
@ -121,28 +122,28 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
.first() .first()
) )
if dataset and document: if dataset and document:
source = { source = RetrievalSourceMetadata(
"position": resource_number, position=resource_number,
"dataset_id": dataset.id, dataset_id=dataset.id,
"dataset_name": dataset.name, dataset_name=dataset.name,
"document_id": document.id, document_id=document.id,
"document_name": document.name, document_name=document.name,
"data_source_type": document.data_source_type, data_source_type=document.data_source_type,
"segment_id": segment.id, segment_id=segment.id,
"retriever_from": self.retriever_from, retriever_from=self.retriever_from,
"score": document_score_list.get(segment.index_node_id, None), score=document_score_list.get(segment.index_node_id, None),
"doc_metadata": document.doc_metadata, doc_metadata=document.doc_metadata,
} )
if self.retriever_from == "dev": if self.retriever_from == "dev":
source["hit_count"] = segment.hit_count source.hit_count = segment.hit_count
source["word_count"] = segment.word_count source.word_count = segment.word_count
source["segment_position"] = segment.position source.segment_position = segment.position
source["index_node_hash"] = segment.index_node_hash source.index_node_hash = segment.index_node_hash
if segment.answer: if segment.answer:
source["content"] = f"question:{segment.content} \nanswer:{segment.answer}" source.content = f"question:{segment.content} \nanswer:{segment.answer}"
else: else:
source["content"] = segment.content source.content = segment.content
context_list.append(source) context_list.append(source)
resource_number += 1 resource_number += 1

@ -4,6 +4,7 @@ from pydantic import BaseModel, Field
from core.app.app_config.entities import DatasetRetrieveConfigEntity, ModelConfig from core.app.app_config.entities import DatasetRetrieveConfigEntity, ModelConfig
from core.rag.datasource.retrieval_service import RetrievalService from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.rag.entities.context_entities import DocumentContext from core.rag.entities.context_entities import DocumentContext
from core.rag.models.document import Document as RetrievalDocument from core.rag.models.document import Document as RetrievalDocument
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
@ -14,7 +15,7 @@ from models.dataset import Dataset
from models.dataset import Document as DatasetDocument from models.dataset import Document as DatasetDocument
from services.external_knowledge_service import ExternalDatasetService from services.external_knowledge_service import ExternalDatasetService
default_retrieval_model = { default_retrieval_model: dict[str, Any] = {
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value, "search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
"reranking_enable": False, "reranking_enable": False,
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
@ -79,7 +80,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
else: else:
document_ids_filter = None document_ids_filter = None
if dataset.provider == "external": if dataset.provider == "external":
results = [] results: list[RetrievalDocument] = []
external_documents = ExternalDatasetService.fetch_external_knowledge_retrieval( external_documents = ExternalDatasetService.fetch_external_knowledge_retrieval(
tenant_id=dataset.tenant_id, tenant_id=dataset.tenant_id,
dataset_id=dataset.id, dataset_id=dataset.id,
@ -100,21 +101,21 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
document.metadata["dataset_name"] = dataset.name document.metadata["dataset_name"] = dataset.name
results.append(document) results.append(document)
# deal with external documents # deal with external documents
context_list = [] context_list: list[RetrievalSourceMetadata] = []
for position, item in enumerate(results, start=1): for position, item in enumerate(results, start=1):
if item.metadata is not None: if item.metadata is not None:
source = { source = RetrievalSourceMetadata(
"position": position, position=position,
"dataset_id": item.metadata.get("dataset_id"), dataset_id=item.metadata.get("dataset_id"),
"dataset_name": item.metadata.get("dataset_name"), dataset_name=item.metadata.get("dataset_name"),
"document_id": item.metadata.get("document_id") or item.metadata.get("title"), document_id=item.metadata.get("document_id") or item.metadata.get("title"),
"document_name": item.metadata.get("title"), document_name=item.metadata.get("title"),
"data_source_type": "external", data_source_type="external",
"retriever_from": self.retriever_from, retriever_from=self.retriever_from,
"score": item.metadata.get("score"), score=item.metadata.get("score"),
"title": item.metadata.get("title"), title=item.metadata.get("title"),
"content": item.page_content, content=item.page_content,
} )
context_list.append(source) context_list.append(source)
for hit_callback in self.hit_callbacks: for hit_callback in self.hit_callbacks:
hit_callback.return_retriever_resource_info(context_list) hit_callback.return_retriever_resource_info(context_list)
@ -125,7 +126,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
return "" return ""
# get retrieval model , if the model is not setting , using default # get retrieval model , if the model is not setting , using default
retrieval_model: dict[str, Any] = dataset.retrieval_model or default_retrieval_model retrieval_model: dict[str, Any] = dataset.retrieval_model or default_retrieval_model
retrieval_resource_list = [] retrieval_resource_list: list[RetrievalSourceMetadata] = []
if dataset.indexing_technique == "economy": if dataset.indexing_technique == "economy":
# use keyword table query # use keyword table query
documents = RetrievalService.retrieve( documents = RetrievalService.retrieve(
@ -163,7 +164,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
for item in documents: for item in documents:
if item.metadata is not None and item.metadata.get("score"): if item.metadata is not None and item.metadata.get("score"):
document_score_list[item.metadata["doc_id"]] = item.metadata["score"] document_score_list[item.metadata["doc_id"]] = item.metadata["score"]
document_context_list = [] document_context_list: list[DocumentContext] = []
records = RetrievalService.format_retrieval_documents(documents) records = RetrievalService.format_retrieval_documents(documents)
if records: if records:
for record in records: for record in records:
@ -197,37 +198,37 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
.first() .first()
) )
if dataset and document: if dataset and document:
source = { source = RetrievalSourceMetadata(
"dataset_id": dataset.id, dataset_id=dataset.id,
"dataset_name": dataset.name, dataset_name=dataset.name,
"document_id": document.id, # type: ignore document_id=document.id, # type: ignore
"document_name": document.name, # type: ignore document_name=document.name, # type: ignore
"data_source_type": document.data_source_type, # type: ignore data_source_type=document.data_source_type, # type: ignore
"segment_id": segment.id, segment_id=segment.id,
"retriever_from": self.retriever_from, retriever_from=self.retriever_from,
"score": record.score or 0.0, score=record.score or 0.0,
"doc_metadata": document.doc_metadata, # type: ignore doc_metadata=document.doc_metadata, # type: ignore
} )
if self.retriever_from == "dev": if self.retriever_from == "dev":
source["hit_count"] = segment.hit_count source.hit_count = segment.hit_count
source["word_count"] = segment.word_count source.word_count = segment.word_count
source["segment_position"] = segment.position source.segment_position = segment.position
source["index_node_hash"] = segment.index_node_hash source.index_node_hash = segment.index_node_hash
if segment.answer: if segment.answer:
source["content"] = f"question:{segment.content} \nanswer:{segment.answer}" source.content = f"question:{segment.content} \nanswer:{segment.answer}"
else: else:
source["content"] = segment.content source.content = segment.content
retrieval_resource_list.append(source) retrieval_resource_list.append(source)
if self.return_resource and retrieval_resource_list: if self.return_resource and retrieval_resource_list:
retrieval_resource_list = sorted( retrieval_resource_list = sorted(
retrieval_resource_list, retrieval_resource_list,
key=lambda x: x.get("score") or 0.0, key=lambda x: x.score or 0.0,
reverse=True, reverse=True,
) )
for position, item in enumerate(retrieval_resource_list, start=1): # type: ignore for position, item in enumerate(retrieval_resource_list, start=1): # type: ignore
item["position"] = position # type: ignore item.position = position # type: ignore
for hit_callback in self.hit_callbacks: for hit_callback in self.hit_callbacks:
hit_callback.return_retriever_resource_info(retrieval_resource_list) hit_callback.return_retriever_resource_info(retrieval_resource_list)
if document_context_list: if document_context_list:

@ -1,9 +1,10 @@
from collections.abc import Mapping from collections.abc import Mapping, Sequence
from datetime import datetime from datetime import datetime
from typing import Any, Optional from typing import Any, Optional
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.workflow.entities.node_entities import AgentNodeStrategyInit from core.workflow.entities.node_entities import AgentNodeStrategyInit
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
from core.workflow.nodes import NodeType from core.workflow.nodes import NodeType
@ -82,7 +83,7 @@ class NodeRunStreamChunkEvent(BaseNodeEvent):
class NodeRunRetrieverResourceEvent(BaseNodeEvent): class NodeRunRetrieverResourceEvent(BaseNodeEvent):
retriever_resources: list[dict] = Field(..., description="retriever resources") retriever_resources: Sequence[RetrievalSourceMetadata] = Field(..., description="retriever resources")
context: str = Field(..., description="context") context: str = Field(..., description="context")

@ -1,8 +1,10 @@
from collections.abc import Sequence
from datetime import datetime from datetime import datetime
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from core.model_runtime.entities.llm_entities import LLMUsage from core.model_runtime.entities.llm_entities import LLMUsage
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
@ -17,7 +19,7 @@ class RunStreamChunkEvent(BaseModel):
class RunRetrieverResourceEvent(BaseModel): class RunRetrieverResourceEvent(BaseModel):
retriever_resources: list[dict] = Field(..., description="retriever resources") retriever_resources: Sequence[RetrievalSourceMetadata] = Field(..., description="retriever resources")
context: str = Field(..., description="context") context: str = Field(..., description="context")

@ -43,6 +43,7 @@ from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.entities.plugin import ModelProviderID from core.plugin.entities.plugin import ModelProviderID
from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
from core.prompt.utils.prompt_message_util import PromptMessageUtil from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.variables import ( from core.variables import (
ArrayAnySegment, ArrayAnySegment,
ArrayFileSegment, ArrayFileSegment,
@ -474,7 +475,7 @@ class LLMNode(BaseNode[LLMNodeData]):
yield RunRetrieverResourceEvent(retriever_resources=[], context=context_value_variable.value) yield RunRetrieverResourceEvent(retriever_resources=[], context=context_value_variable.value)
elif isinstance(context_value_variable, ArraySegment): elif isinstance(context_value_variable, ArraySegment):
context_str = "" context_str = ""
original_retriever_resource = [] original_retriever_resource: list[RetrievalSourceMetadata] = []
for item in context_value_variable.value: for item in context_value_variable.value:
if isinstance(item, str): if isinstance(item, str):
context_str += item + "\n" context_str += item + "\n"
@ -492,7 +493,7 @@ class LLMNode(BaseNode[LLMNodeData]):
retriever_resources=original_retriever_resource, context=context_str.strip() retriever_resources=original_retriever_resource, context=context_str.strip()
) )
def _convert_to_original_retriever_resource(self, context_dict: dict) -> Optional[dict]: def _convert_to_original_retriever_resource(self, context_dict: dict):
if ( if (
"metadata" in context_dict "metadata" in context_dict
and "_source" in context_dict["metadata"] and "_source" in context_dict["metadata"]
@ -500,24 +501,24 @@ class LLMNode(BaseNode[LLMNodeData]):
): ):
metadata = context_dict.get("metadata", {}) metadata = context_dict.get("metadata", {})
source = { source = RetrievalSourceMetadata(
"position": metadata.get("position"), position=metadata.get("position"),
"dataset_id": metadata.get("dataset_id"), dataset_id=metadata.get("dataset_id"),
"dataset_name": metadata.get("dataset_name"), dataset_name=metadata.get("dataset_name"),
"document_id": metadata.get("document_id"), document_id=metadata.get("document_id"),
"document_name": metadata.get("document_name"), document_name=metadata.get("document_name"),
"data_source_type": metadata.get("data_source_type"), data_source_type=metadata.get("data_source_type"),
"segment_id": metadata.get("segment_id"), segment_id=metadata.get("segment_id"),
"retriever_from": metadata.get("retriever_from"), retriever_from=metadata.get("retriever_from"),
"score": metadata.get("score"), score=metadata.get("score"),
"hit_count": metadata.get("segment_hit_count"), hit_count=metadata.get("segment_hit_count"),
"word_count": metadata.get("segment_word_count"), word_count=metadata.get("segment_word_count"),
"segment_position": metadata.get("segment_position"), segment_position=metadata.get("segment_position"),
"index_node_hash": metadata.get("segment_index_node_hash"), index_node_hash=metadata.get("segment_index_node_hash"),
"content": context_dict.get("content"), content=context_dict.get("content"),
"page": metadata.get("page"), page=metadata.get("page"),
"doc_metadata": metadata.get("doc_metadata"), doc_metadata=metadata.get("doc_metadata"),
} )
return source return source

Loading…
Cancel
Save