fix mypy checks

pull/19208/head
IthacaDream 1 year ago
parent e58224bb7f
commit 04763ce6b6

@ -92,7 +92,7 @@ class BaseAgentRunner(AppRunner):
invoke_from=application_generate_entity.invoke_from, invoke_from=application_generate_entity.invoke_from,
hit_callback=hit_callback, hit_callback=hit_callback,
user_id=user_id, user_id=user_id,
inputs=application_generate_entity.inputs, inputs=cast(dict, application_generate_entity.inputs),
) )
# get how many agent thoughts have been created # get how many agent thoughts have been created
self.agent_thought_count = ( self.agent_thought_count = (

@ -6,7 +6,6 @@ from flask import Flask, current_app
from sqlalchemy.orm import load_only from sqlalchemy.orm import load_only
from configs import dify_config from configs import dify_config
from core.app.app_config.entities import MetadataFilteringCondition
from core.rag.data_post_processor.data_post_processor import DataPostProcessor from core.rag.data_post_processor.data_post_processor import DataPostProcessor
from core.rag.datasource.keyword.keyword_factory import Keyword from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.datasource.vdb.vector_factory import Vector from core.rag.datasource.vdb.vector_factory import Vector
@ -126,7 +125,7 @@ class RetrievalService:
dataset_id: str, dataset_id: str,
query: str, query: str,
external_retrieval_model: Optional[dict] = None, external_retrieval_model: Optional[dict] = None,
metadata_filtering_conditions: Optional[MetadataFilteringCondition] = None, metadata_filtering_conditions: Optional[dict] = None,
): ):
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
if not dataset: if not dataset:

@ -650,7 +650,7 @@ class DatasetRetrieval:
invoke_from: InvokeFrom, invoke_from: InvokeFrom,
hit_callback: DatasetIndexToolCallbackHandler, hit_callback: DatasetIndexToolCallbackHandler,
user_id: str, user_id: str,
inputs: Mapping[str, Any], inputs: dict,
) -> Optional[list[DatasetRetrieverBaseTool]]: ) -> Optional[list[DatasetRetrieverBaseTool]]:
""" """
A dataset tool is a tool that can be used to retrieve information from a dataset A dataset tool is a tool that can be used to retrieve information from a dataset

@ -1,9 +1,8 @@
from collections.abc import Mapping from typing import Any, Optional, cast
from typing import Any, Optional
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from core.app.app_config.entities import DatasetRetrieveConfigEntity from core.app.app_config.entities import DatasetRetrieveConfigEntity, ModelConfig
from core.rag.datasource.retrieval_service import RetrievalService from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.entities.context_entities import DocumentContext from core.rag.entities.context_entities import DocumentContext
from core.rag.models.document import Document as RetrievalDocument from core.rag.models.document import Document as RetrievalDocument
@ -37,8 +36,8 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
description: str = "use this to retrieve a dataset. " description: str = "use this to retrieve a dataset. "
dataset_id: str dataset_id: str
user_id: Optional[str] = None user_id: Optional[str] = None
retrieve_config: Optional[DatasetRetrieveConfigEntity] = None retrieve_config: DatasetRetrieveConfigEntity
inputs: Optional[Mapping[str, Any]] = None inputs: dict
@classmethod @classmethod
def from_dataset(cls, dataset: Dataset, **kwargs): def from_dataset(cls, dataset: Dataset, **kwargs):
@ -70,8 +69,8 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
query, query,
self.tenant_id, self.tenant_id,
self.user_id or "unknown", self.user_id or "unknown",
self.retrieve_config.metadata_filtering_mode, cast(str, self.retrieve_config.metadata_filtering_mode),
self.retrieve_config.metadata_model_config, cast(ModelConfig, self.retrieve_config.metadata_model_config),
self.retrieve_config.metadata_filtering_conditions, self.retrieve_config.metadata_filtering_conditions,
self.inputs, self.inputs,
) )

@ -1,4 +1,4 @@
from collections.abc import Generator, Mapping from collections.abc import Generator
from typing import Any, Optional from typing import Any, Optional
from core.app.app_config.entities import DatasetRetrieveConfigEntity from core.app.app_config.entities import DatasetRetrieveConfigEntity
@ -35,7 +35,7 @@ class DatasetRetrieverTool(Tool):
invoke_from: InvokeFrom, invoke_from: InvokeFrom,
hit_callback: DatasetIndexToolCallbackHandler, hit_callback: DatasetIndexToolCallbackHandler,
user_id: str, user_id: str,
inputs: Mapping[str, Any], inputs: dict,
) -> list["DatasetRetrieverTool"]: ) -> list["DatasetRetrieverTool"]:
""" """
get dataset tool get dataset tool

Loading…
Cancel
Save