fix mypy checks

pull/19208/head
IthacaDream 1 year ago
parent e58224bb7f
commit 04763ce6b6

@ -92,7 +92,7 @@ class BaseAgentRunner(AppRunner):
invoke_from=application_generate_entity.invoke_from,
hit_callback=hit_callback,
user_id=user_id,
inputs=application_generate_entity.inputs,
inputs=cast(dict, application_generate_entity.inputs),
)
# get how many agent thoughts have been created
self.agent_thought_count = (

@ -6,7 +6,6 @@ from flask import Flask, current_app
from sqlalchemy.orm import load_only
from configs import dify_config
from core.app.app_config.entities import MetadataFilteringCondition
from core.rag.data_post_processor.data_post_processor import DataPostProcessor
from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.datasource.vdb.vector_factory import Vector
@ -126,7 +125,7 @@ class RetrievalService:
dataset_id: str,
query: str,
external_retrieval_model: Optional[dict] = None,
metadata_filtering_conditions: Optional[MetadataFilteringCondition] = None,
metadata_filtering_conditions: Optional[dict] = None,
):
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
if not dataset:

@ -650,7 +650,7 @@ class DatasetRetrieval:
invoke_from: InvokeFrom,
hit_callback: DatasetIndexToolCallbackHandler,
user_id: str,
inputs: Mapping[str, Any],
inputs: dict,
) -> Optional[list[DatasetRetrieverBaseTool]]:
"""
A dataset tool is a tool that can be used to retrieve information from a dataset

@ -1,9 +1,8 @@
from collections.abc import Mapping
from typing import Any, Optional
from typing import Any, Optional, cast
from pydantic import BaseModel, Field
from core.app.app_config.entities import DatasetRetrieveConfigEntity
from core.app.app_config.entities import DatasetRetrieveConfigEntity, ModelConfig
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.entities.context_entities import DocumentContext
from core.rag.models.document import Document as RetrievalDocument
@ -37,8 +36,8 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
description: str = "use this to retrieve a dataset. "
dataset_id: str
user_id: Optional[str] = None
retrieve_config: Optional[DatasetRetrieveConfigEntity] = None
inputs: Optional[Mapping[str, Any]] = None
retrieve_config: DatasetRetrieveConfigEntity
inputs: dict
@classmethod
def from_dataset(cls, dataset: Dataset, **kwargs):
@ -70,8 +69,8 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
query,
self.tenant_id,
self.user_id or "unknown",
self.retrieve_config.metadata_filtering_mode,
self.retrieve_config.metadata_model_config,
cast(str, self.retrieve_config.metadata_filtering_mode),
cast(ModelConfig, self.retrieve_config.metadata_model_config),
self.retrieve_config.metadata_filtering_conditions,
self.inputs,
)

@ -1,4 +1,4 @@
from collections.abc import Generator, Mapping
from collections.abc import Generator
from typing import Any, Optional
from core.app.app_config.entities import DatasetRetrieveConfigEntity
@ -35,7 +35,7 @@ class DatasetRetrieverTool(Tool):
invoke_from: InvokeFrom,
hit_callback: DatasetIndexToolCallbackHandler,
user_id: str,
inputs: Mapping[str, Any],
inputs: dict,
) -> list["DatasetRetrieverTool"]:
"""
get dataset tool

Loading…
Cancel
Save