From 04763ce6b636a2c08e028b440eae2da3dfd3f851 Mon Sep 17 00:00:00 2001 From: IthacaDream Date: Mon, 5 May 2025 16:57:07 +0800 Subject: [PATCH] fix mypy checks --- api/core/agent/base_agent_runner.py | 2 +- api/core/rag/datasource/retrieval_service.py | 3 +-- api/core/rag/retrieval/dataset_retrieval.py | 2 +- .../dataset_retriever/dataset_retriever_tool.py | 13 ++++++------- api/core/tools/utils/dataset_retriever_tool.py | 4 ++-- 5 files changed, 11 insertions(+), 13 deletions(-) diff --git a/api/core/agent/base_agent_runner.py b/api/core/agent/base_agent_runner.py index b19a52f7ce..6998e4d29a 100644 --- a/api/core/agent/base_agent_runner.py +++ b/api/core/agent/base_agent_runner.py @@ -92,7 +92,7 @@ class BaseAgentRunner(AppRunner): invoke_from=application_generate_entity.invoke_from, hit_callback=hit_callback, user_id=user_id, - inputs=application_generate_entity.inputs, + inputs=cast(dict, application_generate_entity.inputs), ) # get how many agent thoughts have been created self.agent_thought_count = ( diff --git a/api/core/rag/datasource/retrieval_service.py b/api/core/rag/datasource/retrieval_service.py index ff645f3e37..01f74b4a22 100644 --- a/api/core/rag/datasource/retrieval_service.py +++ b/api/core/rag/datasource/retrieval_service.py @@ -6,7 +6,6 @@ from flask import Flask, current_app from sqlalchemy.orm import load_only from configs import dify_config -from core.app.app_config.entities import MetadataFilteringCondition from core.rag.data_post_processor.data_post_processor import DataPostProcessor from core.rag.datasource.keyword.keyword_factory import Keyword from core.rag.datasource.vdb.vector_factory import Vector @@ -126,7 +125,7 @@ class RetrievalService: dataset_id: str, query: str, external_retrieval_model: Optional[dict] = None, - metadata_filtering_conditions: Optional[MetadataFilteringCondition] = None, + metadata_filtering_conditions: Optional[dict] = None, ): dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() if not dataset: diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index 3b4553f748..8f71d96bb0 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -650,7 +650,7 @@ class DatasetRetrieval: invoke_from: InvokeFrom, hit_callback: DatasetIndexToolCallbackHandler, user_id: str, - inputs: Mapping[str, Any], + inputs: dict, ) -> Optional[list[DatasetRetrieverBaseTool]]: """ A dataset tool is a tool that can be used to retrieve information from a dataset diff --git a/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py b/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py index 5d24a6826c..ed97b44f95 100644 --- a/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py +++ b/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py @@ -1,9 +1,8 @@ -from collections.abc import Mapping -from typing import Any, Optional +from typing import Any, Optional, cast from pydantic import BaseModel, Field -from core.app.app_config.entities import DatasetRetrieveConfigEntity +from core.app.app_config.entities import DatasetRetrieveConfigEntity, ModelConfig from core.rag.datasource.retrieval_service import RetrievalService from core.rag.entities.context_entities import DocumentContext from core.rag.models.document import Document as RetrievalDocument @@ -37,8 +36,8 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool): description: str = "use this to retrieve a dataset. " dataset_id: str user_id: Optional[str] = None - retrieve_config: Optional[DatasetRetrieveConfigEntity] = None - inputs: Optional[Mapping[str, Any]] = None + retrieve_config: DatasetRetrieveConfigEntity + inputs: dict @classmethod def from_dataset(cls, dataset: Dataset, **kwargs): @@ -70,8 +69,8 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool): query, self.tenant_id, self.user_id or "unknown", - self.retrieve_config.metadata_filtering_mode, - self.retrieve_config.metadata_model_config, + cast(str, self.retrieve_config.metadata_filtering_mode), + cast(ModelConfig, self.retrieve_config.metadata_model_config), self.retrieve_config.metadata_filtering_conditions, self.inputs, ) diff --git a/api/core/tools/utils/dataset_retriever_tool.py b/api/core/tools/utils/dataset_retriever_tool.py index 5a1aca9e1e..ec0575f6c3 100644 --- a/api/core/tools/utils/dataset_retriever_tool.py +++ b/api/core/tools/utils/dataset_retriever_tool.py @@ -1,4 +1,4 @@ -from collections.abc import Generator, Mapping +from collections.abc import Generator from typing import Any, Optional from core.app.app_config.entities import DatasetRetrieveConfigEntity @@ -35,7 +35,7 @@ class DatasetRetrieverTool(Tool): invoke_from: InvokeFrom, hit_callback: DatasetIndexToolCallbackHandler, user_id: str, - inputs: Mapping[str, Any], + inputs: dict, ) -> list["DatasetRetrieverTool"]: """ get dataset tool