|
|
|
|
@ -1,9 +1,8 @@
|
|
|
|
|
from collections.abc import Mapping
|
|
|
|
|
from typing import Any, Optional
|
|
|
|
|
from typing import Any, Optional, cast
|
|
|
|
|
|
|
|
|
|
from pydantic import BaseModel, Field
|
|
|
|
|
|
|
|
|
|
from core.app.app_config.entities import DatasetRetrieveConfigEntity
|
|
|
|
|
from core.app.app_config.entities import DatasetRetrieveConfigEntity, ModelConfig
|
|
|
|
|
from core.rag.datasource.retrieval_service import RetrievalService
|
|
|
|
|
from core.rag.entities.context_entities import DocumentContext
|
|
|
|
|
from core.rag.models.document import Document as RetrievalDocument
|
|
|
|
|
@ -37,8 +36,8 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
|
|
|
|
|
description: str = "use this to retrieve a dataset. "
|
|
|
|
|
dataset_id: str
|
|
|
|
|
user_id: Optional[str] = None
|
|
|
|
|
retrieve_config: Optional[DatasetRetrieveConfigEntity] = None
|
|
|
|
|
inputs: Optional[Mapping[str, Any]] = None
|
|
|
|
|
retrieve_config: DatasetRetrieveConfigEntity
|
|
|
|
|
inputs: dict
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def from_dataset(cls, dataset: Dataset, **kwargs):
|
|
|
|
|
@ -70,8 +69,8 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
|
|
|
|
|
query,
|
|
|
|
|
self.tenant_id,
|
|
|
|
|
self.user_id or "unknown",
|
|
|
|
|
self.retrieve_config.metadata_filtering_mode,
|
|
|
|
|
self.retrieve_config.metadata_model_config,
|
|
|
|
|
cast(str, self.retrieve_config.metadata_filtering_mode),
|
|
|
|
|
cast(ModelConfig, self.retrieve_config.metadata_model_config),
|
|
|
|
|
self.retrieve_config.metadata_filtering_conditions,
|
|
|
|
|
self.inputs,
|
|
|
|
|
)
|
|
|
|
|
|