parent
a4f37220a0
commit
4588831bff
@ -0,0 +1,158 @@
|
|||||||
|
import json
|
||||||
|
from typing import Tuple, List, Any, Union, Sequence, Optional, cast
|
||||||
|
|
||||||
|
from langchain.agents import OpenAIFunctionsAgent, BaseSingleActionAgent
|
||||||
|
from langchain.agents.openai_functions_agent.base import _format_intermediate_steps, _parse_ai_message
|
||||||
|
from langchain.callbacks.base import BaseCallbackManager
|
||||||
|
from langchain.callbacks.manager import Callbacks
|
||||||
|
from langchain.prompts.chat import BaseMessagePromptTemplate
|
||||||
|
from langchain.schema import AgentAction, AgentFinish, SystemMessage, Generation, LLMResult, AIMessage
|
||||||
|
from langchain.schema.language_model import BaseLanguageModel
|
||||||
|
from langchain.tools import BaseTool
|
||||||
|
from pydantic import root_validator
|
||||||
|
|
||||||
|
from core.model_providers.models.entity.message import to_prompt_messages
|
||||||
|
from core.model_providers.models.llm.base import BaseLLM
|
||||||
|
from core.third_party.langchain.llms.fake import FakeLLM
|
||||||
|
from core.tool.dataset_retriever_tool import DatasetRetrieverTool
|
||||||
|
|
||||||
|
|
||||||
|
class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
|
||||||
|
"""
|
||||||
|
An Multi Dataset Retrieve Agent driven by Router.
|
||||||
|
"""
|
||||||
|
model_instance: BaseLLM
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
"""Configuration for this pydantic object."""
|
||||||
|
|
||||||
|
arbitrary_types_allowed = True
|
||||||
|
|
||||||
|
@root_validator
|
||||||
|
def validate_llm(cls, values: dict) -> dict:
|
||||||
|
return values
|
||||||
|
|
||||||
|
def should_use_agent(self, query: str):
|
||||||
|
"""
|
||||||
|
return should use agent
|
||||||
|
|
||||||
|
:param query:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
return True
|
||||||
|
|
||||||
|
def plan(
|
||||||
|
self,
|
||||||
|
intermediate_steps: List[Tuple[AgentAction, str]],
|
||||||
|
callbacks: Callbacks = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> Union[AgentAction, AgentFinish]:
|
||||||
|
"""Given input, decided what to do.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
intermediate_steps: Steps the LLM has taken to date, along with observations
|
||||||
|
**kwargs: User inputs.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Action specifying what tool to use.
|
||||||
|
"""
|
||||||
|
if len(self.tools) == 0:
|
||||||
|
return AgentFinish(return_values={"output": ''}, log='')
|
||||||
|
elif len(self.tools) == 1:
|
||||||
|
tool = next(iter(self.tools))
|
||||||
|
tool = cast(DatasetRetrieverTool, tool)
|
||||||
|
rst = tool.run(tool_input={'query': kwargs['input']})
|
||||||
|
# output = ''
|
||||||
|
# rst_json = json.loads(rst)
|
||||||
|
# for item in rst_json:
|
||||||
|
# output += f'{item["content"]}\n'
|
||||||
|
return AgentFinish(return_values={"output": rst}, log=rst)
|
||||||
|
|
||||||
|
if intermediate_steps:
|
||||||
|
_, observation = intermediate_steps[-1]
|
||||||
|
return AgentFinish(return_values={"output": observation}, log=observation)
|
||||||
|
|
||||||
|
try:
|
||||||
|
agent_decision = self.real_plan(intermediate_steps, callbacks, **kwargs)
|
||||||
|
if isinstance(agent_decision, AgentAction):
|
||||||
|
tool_inputs = agent_decision.tool_input
|
||||||
|
if isinstance(tool_inputs, dict) and 'query' in tool_inputs and 'chat_history' not in kwargs:
|
||||||
|
tool_inputs['query'] = kwargs['input']
|
||||||
|
agent_decision.tool_input = tool_inputs
|
||||||
|
else:
|
||||||
|
agent_decision.return_values['output'] = ''
|
||||||
|
return agent_decision
|
||||||
|
except Exception as e:
|
||||||
|
new_exception = self.model_instance.handle_exceptions(e)
|
||||||
|
raise new_exception
|
||||||
|
|
||||||
|
def real_plan(
|
||||||
|
self,
|
||||||
|
intermediate_steps: List[Tuple[AgentAction, str]],
|
||||||
|
callbacks: Callbacks = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> Union[AgentAction, AgentFinish]:
|
||||||
|
"""Given input, decided what to do.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
intermediate_steps: Steps the LLM has taken to date, along with observations
|
||||||
|
**kwargs: User inputs.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Action specifying what tool to use.
|
||||||
|
"""
|
||||||
|
agent_scratchpad = _format_intermediate_steps(intermediate_steps)
|
||||||
|
selected_inputs = {
|
||||||
|
k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad"
|
||||||
|
}
|
||||||
|
full_inputs = dict(**selected_inputs, agent_scratchpad=agent_scratchpad)
|
||||||
|
prompt = self.prompt.format_prompt(**full_inputs)
|
||||||
|
messages = prompt.to_messages()
|
||||||
|
prompt_messages = to_prompt_messages(messages)
|
||||||
|
result = self.model_instance.run(
|
||||||
|
messages=prompt_messages,
|
||||||
|
functions=self.functions,
|
||||||
|
)
|
||||||
|
|
||||||
|
ai_message = AIMessage(
|
||||||
|
content=result.content,
|
||||||
|
additional_kwargs={
|
||||||
|
'function_call': result.function_call
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
agent_decision = _parse_ai_message(ai_message)
|
||||||
|
return agent_decision
|
||||||
|
|
||||||
|
async def aplan(
|
||||||
|
self,
|
||||||
|
intermediate_steps: List[Tuple[AgentAction, str]],
|
||||||
|
callbacks: Callbacks = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> Union[AgentAction, AgentFinish]:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_llm_and_tools(
|
||||||
|
cls,
|
||||||
|
model_instance: BaseLLM,
|
||||||
|
tools: Sequence[BaseTool],
|
||||||
|
callback_manager: Optional[BaseCallbackManager] = None,
|
||||||
|
extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None,
|
||||||
|
system_message: Optional[SystemMessage] = SystemMessage(
|
||||||
|
content="You are a helpful AI assistant."
|
||||||
|
),
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> BaseSingleActionAgent:
|
||||||
|
prompt = cls.create_prompt(
|
||||||
|
extra_prompt_messages=extra_prompt_messages,
|
||||||
|
system_message=system_message,
|
||||||
|
)
|
||||||
|
return cls(
|
||||||
|
model_instance=model_instance,
|
||||||
|
llm=FakeLLM(response=''),
|
||||||
|
prompt=prompt,
|
||||||
|
tools=tools,
|
||||||
|
callback_manager=callback_manager,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
@ -0,0 +1,36 @@
|
|||||||
|
from abc import abstractmethod
|
||||||
|
from typing import Any, Optional, List
|
||||||
|
from langchain.schema import Document
|
||||||
|
|
||||||
|
from core.model_providers.models.base import BaseProviderModel
|
||||||
|
from core.model_providers.models.entity.model_params import ModelType
|
||||||
|
from core.model_providers.providers.base import BaseModelProvider
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class BaseReranking(BaseProviderModel):
|
||||||
|
name: str
|
||||||
|
type: ModelType = ModelType.RERANKING
|
||||||
|
|
||||||
|
def __init__(self, model_provider: BaseModelProvider, client: Any, name: str):
|
||||||
|
super().__init__(model_provider, client)
|
||||||
|
self.name = name
|
||||||
|
|
||||||
|
@property
|
||||||
|
def base_model_name(self) -> str:
|
||||||
|
"""
|
||||||
|
get base model name
|
||||||
|
|
||||||
|
:return: str
|
||||||
|
"""
|
||||||
|
return self.name
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def rerank(self, query: str, documents: List[Document], score_threshold: Optional[float], top_k: Optional[int]) -> Optional[List[Document]]:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def handle_exceptions(self, ex: Exception) -> Exception:
|
||||||
|
raise NotImplementedError
|
||||||
@ -0,0 +1,73 @@
|
|||||||
|
import logging
|
||||||
|
from typing import Optional, List
|
||||||
|
|
||||||
|
import cohere
|
||||||
|
import openai
|
||||||
|
from langchain.schema import Document
|
||||||
|
|
||||||
|
from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, \
|
||||||
|
LLMRateLimitError, LLMAuthorizationError
|
||||||
|
from core.model_providers.models.reranking.base import BaseReranking
|
||||||
|
from core.model_providers.providers.base import BaseModelProvider
|
||||||
|
|
||||||
|
|
||||||
|
class CohereReranking(BaseReranking):
|
||||||
|
|
||||||
|
def __init__(self, model_provider: BaseModelProvider, name: str):
|
||||||
|
self.credentials = model_provider.get_model_credentials(
|
||||||
|
model_name=name,
|
||||||
|
model_type=self.type
|
||||||
|
)
|
||||||
|
|
||||||
|
client = cohere.Client(self.credentials.get('api_key'))
|
||||||
|
|
||||||
|
super().__init__(model_provider, client, name)
|
||||||
|
|
||||||
|
def rerank(self, query: str, documents: List[Document], score_threshold: Optional[float], top_k: Optional[int]) -> Optional[List[Document]]:
|
||||||
|
docs = []
|
||||||
|
doc_id = []
|
||||||
|
for document in documents:
|
||||||
|
if document.metadata['doc_id'] not in doc_id:
|
||||||
|
doc_id.append(document.metadata['doc_id'])
|
||||||
|
docs.append(document.page_content)
|
||||||
|
results = self.client.rerank(query=query, documents=docs, model=self.name, top_n=top_k)
|
||||||
|
rerank_documents = []
|
||||||
|
|
||||||
|
for idx, result in enumerate(results):
|
||||||
|
# format document
|
||||||
|
rerank_document = Document(
|
||||||
|
page_content=result.document['text'],
|
||||||
|
metadata={
|
||||||
|
"doc_id": documents[result.index].metadata['doc_id'],
|
||||||
|
"doc_hash": documents[result.index].metadata['doc_hash'],
|
||||||
|
"document_id": documents[result.index].metadata['document_id'],
|
||||||
|
"dataset_id": documents[result.index].metadata['dataset_id'],
|
||||||
|
'score': result.relevance_score
|
||||||
|
}
|
||||||
|
)
|
||||||
|
# score threshold check
|
||||||
|
if score_threshold is not None:
|
||||||
|
if result.relevance_score >= score_threshold:
|
||||||
|
rerank_documents.append(rerank_document)
|
||||||
|
else:
|
||||||
|
rerank_documents.append(rerank_document)
|
||||||
|
return rerank_documents
|
||||||
|
|
||||||
|
def handle_exceptions(self, ex: Exception) -> Exception:
|
||||||
|
if isinstance(ex, openai.error.InvalidRequestError):
|
||||||
|
logging.warning("Invalid request to OpenAI API.")
|
||||||
|
return LLMBadRequestError(str(ex))
|
||||||
|
elif isinstance(ex, openai.error.APIConnectionError):
|
||||||
|
logging.warning("Failed to connect to OpenAI API.")
|
||||||
|
return LLMAPIConnectionError(ex.__class__.__name__ + ":" + str(ex))
|
||||||
|
elif isinstance(ex, (openai.error.APIError, openai.error.ServiceUnavailableError, openai.error.Timeout)):
|
||||||
|
logging.warning("OpenAI service unavailable.")
|
||||||
|
return LLMAPIUnavailableError(ex.__class__.__name__ + ":" + str(ex))
|
||||||
|
elif isinstance(ex, openai.error.RateLimitError):
|
||||||
|
return LLMRateLimitError(str(ex))
|
||||||
|
elif isinstance(ex, openai.error.AuthenticationError):
|
||||||
|
return LLMAuthorizationError(str(ex))
|
||||||
|
elif isinstance(ex, openai.error.OpenAIError):
|
||||||
|
return LLMBadRequestError(ex.__class__.__name__ + ":" + str(ex))
|
||||||
|
else:
|
||||||
|
return ex
|
||||||
@ -0,0 +1,152 @@
|
|||||||
|
import json
|
||||||
|
from json import JSONDecodeError
|
||||||
|
from typing import Type
|
||||||
|
|
||||||
|
from langchain.schema import HumanMessage
|
||||||
|
|
||||||
|
from core.helper import encrypter
|
||||||
|
from core.model_providers.models.base import BaseProviderModel
|
||||||
|
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType, ModelMode
|
||||||
|
from core.model_providers.models.reranking.cohere_reranking import CohereReranking
|
||||||
|
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
|
||||||
|
from models.provider import ProviderType
|
||||||
|
|
||||||
|
|
||||||
|
class CohereProvider(BaseModelProvider):
|
||||||
|
|
||||||
|
@property
|
||||||
|
def provider_name(self):
|
||||||
|
"""
|
||||||
|
Returns the name of a provider.
|
||||||
|
"""
|
||||||
|
return 'cohere'
|
||||||
|
|
||||||
|
def _get_text_generation_model_mode(self, model_name) -> str:
|
||||||
|
return ModelMode.CHAT.value
|
||||||
|
|
||||||
|
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
|
||||||
|
if model_type == ModelType.RERANKING:
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
'id': 'rerank-english-v2.0',
|
||||||
|
'name': 'rerank-english-v2.0'
|
||||||
|
},
|
||||||
|
{
|
||||||
|
'id': 'rerank-multilingual-v2.0',
|
||||||
|
'name': 'rerank-multilingual-v2.0'
|
||||||
|
}
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
return []
|
||||||
|
|
||||||
|
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
|
||||||
|
"""
|
||||||
|
Returns the model class.
|
||||||
|
|
||||||
|
:param model_type:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
if model_type == ModelType.RERANKING:
|
||||||
|
model_class = CohereReranking
|
||||||
|
else:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
return model_class
|
||||||
|
|
||||||
|
def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules:
|
||||||
|
"""
|
||||||
|
get model parameter rules.
|
||||||
|
|
||||||
|
:param model_name:
|
||||||
|
:param model_type:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
return ModelKwargsRules(
|
||||||
|
temperature=KwargRule[float](min=0, max=1, default=0.3, precision=2),
|
||||||
|
top_p=KwargRule[float](min=0, max=0.99, default=0.85, precision=2),
|
||||||
|
presence_penalty=KwargRule[float](enabled=False),
|
||||||
|
frequency_penalty=KwargRule[float](enabled=False),
|
||||||
|
max_tokens=KwargRule[int](enabled=False),
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def is_provider_credentials_valid_or_raise(cls, credentials: dict):
|
||||||
|
"""
|
||||||
|
Validates the given credentials.
|
||||||
|
"""
|
||||||
|
if 'api_key' not in credentials:
|
||||||
|
raise CredentialsValidateFailedError('Cohere api_key must be provided.')
|
||||||
|
|
||||||
|
try:
|
||||||
|
credential_kwargs = {
|
||||||
|
'api_key': credentials['api_key'],
|
||||||
|
}
|
||||||
|
# todo validate
|
||||||
|
except Exception as ex:
|
||||||
|
raise CredentialsValidateFailedError(str(ex))
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict:
|
||||||
|
credentials['api_key'] = encrypter.encrypt_token(tenant_id, credentials['api_key'])
|
||||||
|
return credentials
|
||||||
|
|
||||||
|
def get_provider_credentials(self, obfuscated: bool = False) -> dict:
|
||||||
|
if self.provider.provider_type == ProviderType.CUSTOM.value:
|
||||||
|
try:
|
||||||
|
credentials = json.loads(self.provider.encrypted_config)
|
||||||
|
except JSONDecodeError:
|
||||||
|
credentials = {
|
||||||
|
'api_key': None,
|
||||||
|
}
|
||||||
|
|
||||||
|
if credentials['api_key']:
|
||||||
|
credentials['api_key'] = encrypter.decrypt_token(
|
||||||
|
self.provider.tenant_id,
|
||||||
|
credentials['api_key']
|
||||||
|
)
|
||||||
|
|
||||||
|
if obfuscated:
|
||||||
|
credentials['api_key'] = encrypter.obfuscated_token(credentials['api_key'])
|
||||||
|
|
||||||
|
return credentials
|
||||||
|
else:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
def should_deduct_quota(self):
|
||||||
|
return True
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
|
||||||
|
"""
|
||||||
|
check model credentials valid.
|
||||||
|
|
||||||
|
:param model_name:
|
||||||
|
:param model_type:
|
||||||
|
:param credentials:
|
||||||
|
"""
|
||||||
|
return
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType,
|
||||||
|
credentials: dict) -> dict:
|
||||||
|
"""
|
||||||
|
encrypt model credentials for save.
|
||||||
|
|
||||||
|
:param tenant_id:
|
||||||
|
:param model_name:
|
||||||
|
:param model_type:
|
||||||
|
:param credentials:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
return {}
|
||||||
|
|
||||||
|
def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
|
||||||
|
"""
|
||||||
|
get credentials for llm use.
|
||||||
|
|
||||||
|
:param model_name:
|
||||||
|
:param model_type:
|
||||||
|
:param obfuscated:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
return self.get_provider_credentials(obfuscated)
|
||||||
@ -0,0 +1,7 @@
|
|||||||
|
{
|
||||||
|
"support_provider_types": [
|
||||||
|
"custom"
|
||||||
|
],
|
||||||
|
"system_config": null,
|
||||||
|
"model_flexibility": "fixed"
|
||||||
|
}
|
||||||
@ -0,0 +1,227 @@
|
|||||||
|
import json
|
||||||
|
import threading
|
||||||
|
from typing import Type, Optional, List
|
||||||
|
|
||||||
|
from flask import current_app, Flask
|
||||||
|
from langchain.tools import BaseTool
|
||||||
|
from pydantic import Field, BaseModel
|
||||||
|
|
||||||
|
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
|
||||||
|
from core.conversation_message_task import ConversationMessageTask
|
||||||
|
from core.embedding.cached_embedding import CacheEmbedding
|
||||||
|
from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig
|
||||||
|
from core.model_providers.error import LLMBadRequestError, ProviderTokenNotInitError
|
||||||
|
from core.model_providers.model_factory import ModelFactory
|
||||||
|
from extensions.ext_database import db
|
||||||
|
from models.dataset import Dataset, DocumentSegment, Document
|
||||||
|
from services.retrieval_service import RetrievalService
|
||||||
|
|
||||||
|
default_retrieval_model = {
|
||||||
|
'search_method': 'semantic_search',
|
||||||
|
'reranking_enable': False,
|
||||||
|
'reranking_model': {
|
||||||
|
'reranking_provider_name': '',
|
||||||
|
'reranking_model_name': ''
|
||||||
|
},
|
||||||
|
'top_k': 2,
|
||||||
|
'score_threshold_enable': False
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class DatasetMultiRetrieverToolInput(BaseModel):
|
||||||
|
query: str = Field(..., description="dataset multi retriever and rerank")
|
||||||
|
|
||||||
|
|
||||||
|
class DatasetMultiRetrieverTool(BaseTool):
|
||||||
|
"""Tool for querying multi dataset."""
|
||||||
|
name: str = "dataset-"
|
||||||
|
args_schema: Type[BaseModel] = DatasetMultiRetrieverToolInput
|
||||||
|
description: str = "dataset multi retriever and rerank. "
|
||||||
|
tenant_id: str
|
||||||
|
dataset_ids: List[str]
|
||||||
|
top_k: int = 2
|
||||||
|
score_threshold: Optional[float] = None
|
||||||
|
reranking_provider_name: str
|
||||||
|
reranking_model_name: str
|
||||||
|
conversation_message_task: ConversationMessageTask
|
||||||
|
return_resource: bool
|
||||||
|
retriever_from: str
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dataset(cls, dataset_ids: List[str], tenant_id: str, **kwargs):
|
||||||
|
return cls(
|
||||||
|
name=f'dataset-{tenant_id}',
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
dataset_ids=dataset_ids,
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
def _run(self, query: str) -> str:
|
||||||
|
threads = []
|
||||||
|
all_documents = []
|
||||||
|
for dataset_id in self.dataset_ids:
|
||||||
|
retrieval_thread = threading.Thread(target=self._retriever, kwargs={
|
||||||
|
'flask_app': current_app._get_current_object(),
|
||||||
|
'dataset_id': dataset_id,
|
||||||
|
'query': query,
|
||||||
|
'all_documents': all_documents
|
||||||
|
})
|
||||||
|
threads.append(retrieval_thread)
|
||||||
|
retrieval_thread.start()
|
||||||
|
for thread in threads:
|
||||||
|
thread.join()
|
||||||
|
# do rerank for searched documents
|
||||||
|
rerank = ModelFactory.get_reranking_model(
|
||||||
|
tenant_id=self.tenant_id,
|
||||||
|
model_provider_name=self.reranking_provider_name,
|
||||||
|
model_name=self.reranking_model_name
|
||||||
|
)
|
||||||
|
all_documents = rerank.rerank(query, all_documents, self.score_threshold, self.top_k)
|
||||||
|
|
||||||
|
hit_callback = DatasetIndexToolCallbackHandler(self.conversation_message_task)
|
||||||
|
hit_callback.on_tool_end(all_documents)
|
||||||
|
|
||||||
|
document_context_list = []
|
||||||
|
index_node_ids = [document.metadata['doc_id'] for document in all_documents]
|
||||||
|
segments = DocumentSegment.query.filter(
|
||||||
|
DocumentSegment.completed_at.isnot(None),
|
||||||
|
DocumentSegment.status == 'completed',
|
||||||
|
DocumentSegment.enabled == True,
|
||||||
|
DocumentSegment.index_node_id.in_(index_node_ids)
|
||||||
|
).all()
|
||||||
|
|
||||||
|
if segments:
|
||||||
|
index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)}
|
||||||
|
sorted_segments = sorted(segments,
|
||||||
|
key=lambda segment: index_node_id_to_position.get(segment.index_node_id,
|
||||||
|
float('inf')))
|
||||||
|
for segment in sorted_segments:
|
||||||
|
if segment.answer:
|
||||||
|
document_context_list.append(f'question:{segment.content} answer:{segment.answer}')
|
||||||
|
else:
|
||||||
|
document_context_list.append(segment.content)
|
||||||
|
if self.return_resource:
|
||||||
|
context_list = []
|
||||||
|
resource_number = 1
|
||||||
|
for segment in sorted_segments:
|
||||||
|
dataset = Dataset.query.filter_by(
|
||||||
|
id=segment.dataset_id
|
||||||
|
).first()
|
||||||
|
document = Document.query.filter(Document.id == segment.document_id,
|
||||||
|
Document.enabled == True,
|
||||||
|
Document.archived == False,
|
||||||
|
).first()
|
||||||
|
if dataset and document:
|
||||||
|
source = {
|
||||||
|
'position': resource_number,
|
||||||
|
'dataset_id': dataset.id,
|
||||||
|
'dataset_name': dataset.name,
|
||||||
|
'document_id': document.id,
|
||||||
|
'document_name': document.name,
|
||||||
|
'data_source_type': document.data_source_type,
|
||||||
|
'segment_id': segment.id,
|
||||||
|
'retriever_from': self.retriever_from
|
||||||
|
}
|
||||||
|
if self.retriever_from == 'dev':
|
||||||
|
source['hit_count'] = segment.hit_count
|
||||||
|
source['word_count'] = segment.word_count
|
||||||
|
source['segment_position'] = segment.position
|
||||||
|
source['index_node_hash'] = segment.index_node_hash
|
||||||
|
if segment.answer:
|
||||||
|
source['content'] = f'question:{segment.content} \nanswer:{segment.answer}'
|
||||||
|
else:
|
||||||
|
source['content'] = segment.content
|
||||||
|
context_list.append(source)
|
||||||
|
resource_number += 1
|
||||||
|
hit_callback.return_retriever_resource_info(context_list)
|
||||||
|
|
||||||
|
return str("\n".join(document_context_list))
|
||||||
|
|
||||||
|
async def _arun(self, tool_input: str) -> str:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def _retriever(self, flask_app: Flask, dataset_id: str, query: str, all_documents: List):
|
||||||
|
with flask_app.app_context():
|
||||||
|
dataset = db.session.query(Dataset).filter(
|
||||||
|
Dataset.tenant_id == self.tenant_id,
|
||||||
|
Dataset.id == dataset_id
|
||||||
|
).first()
|
||||||
|
|
||||||
|
if not dataset:
|
||||||
|
return []
|
||||||
|
# get retrieval model , if the model is not setting , using default
|
||||||
|
retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model
|
||||||
|
|
||||||
|
if dataset.indexing_technique == "economy":
|
||||||
|
# use keyword table query
|
||||||
|
kw_table_index = KeywordTableIndex(
|
||||||
|
dataset=dataset,
|
||||||
|
config=KeywordTableConfig(
|
||||||
|
max_keywords_per_chunk=5
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
documents = kw_table_index.search(query, search_kwargs={'k': self.top_k})
|
||||||
|
if documents:
|
||||||
|
all_documents.extend(documents)
|
||||||
|
else:
|
||||||
|
|
||||||
|
try:
|
||||||
|
embedding_model = ModelFactory.get_embedding_model(
|
||||||
|
tenant_id=dataset.tenant_id,
|
||||||
|
model_provider_name=dataset.embedding_model_provider,
|
||||||
|
model_name=dataset.embedding_model
|
||||||
|
)
|
||||||
|
except LLMBadRequestError:
|
||||||
|
return []
|
||||||
|
except ProviderTokenNotInitError:
|
||||||
|
return []
|
||||||
|
|
||||||
|
embeddings = CacheEmbedding(embedding_model)
|
||||||
|
|
||||||
|
documents = []
|
||||||
|
threads = []
|
||||||
|
if self.top_k > 0:
|
||||||
|
# retrieval_model source with semantic
|
||||||
|
if retrieval_model['search_method'] == 'semantic_search' or retrieval_model[
|
||||||
|
'search_method'] == 'hybrid_search':
|
||||||
|
embedding_thread = threading.Thread(target=RetrievalService.embedding_search, kwargs={
|
||||||
|
'flask_app': current_app._get_current_object(),
|
||||||
|
'dataset': dataset,
|
||||||
|
'query': query,
|
||||||
|
'top_k': self.top_k,
|
||||||
|
'score_threshold': self.score_threshold,
|
||||||
|
'reranking_model': None,
|
||||||
|
'all_documents': documents,
|
||||||
|
'search_method': 'hybrid_search',
|
||||||
|
'embeddings': embeddings
|
||||||
|
})
|
||||||
|
threads.append(embedding_thread)
|
||||||
|
embedding_thread.start()
|
||||||
|
|
||||||
|
# retrieval_model source with full text
|
||||||
|
if retrieval_model['search_method'] == 'full_text_search' or retrieval_model[
|
||||||
|
'search_method'] == 'hybrid_search':
|
||||||
|
full_text_index_thread = threading.Thread(target=RetrievalService.full_text_index_search,
|
||||||
|
kwargs={
|
||||||
|
'flask_app': current_app._get_current_object(),
|
||||||
|
'dataset': dataset,
|
||||||
|
'query': query,
|
||||||
|
'search_method': 'hybrid_search',
|
||||||
|
'embeddings': embeddings,
|
||||||
|
'score_threshold': retrieval_model[
|
||||||
|
'score_threshold'] if retrieval_model[
|
||||||
|
'score_threshold_enable'] else None,
|
||||||
|
'top_k': self.top_k,
|
||||||
|
'reranking_model': retrieval_model[
|
||||||
|
'reranking_model'] if retrieval_model[
|
||||||
|
'reranking_enable'] else None,
|
||||||
|
'all_documents': documents
|
||||||
|
})
|
||||||
|
threads.append(full_text_index_thread)
|
||||||
|
full_text_index_thread.start()
|
||||||
|
|
||||||
|
for thread in threads:
|
||||||
|
thread.join()
|
||||||
|
|
||||||
|
all_documents.extend(documents)
|
||||||
@ -0,0 +1,505 @@
|
|||||||
|
"""Wrapper around weaviate vector database."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import datetime
|
||||||
|
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from langchain.docstore.document import Document
|
||||||
|
from langchain.embeddings.base import Embeddings
|
||||||
|
from langchain.utils import get_from_dict_or_env
|
||||||
|
from langchain.vectorstores.base import VectorStore
|
||||||
|
from langchain.vectorstores.utils import maximal_marginal_relevance
|
||||||
|
|
||||||
|
|
||||||
|
def _default_schema(index_name: str) -> Dict:
|
||||||
|
return {
|
||||||
|
"class": index_name,
|
||||||
|
"properties": [
|
||||||
|
{
|
||||||
|
"name": "text",
|
||||||
|
"dataType": ["text"],
|
||||||
|
}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _create_weaviate_client(**kwargs: Any) -> Any:
|
||||||
|
client = kwargs.get("client")
|
||||||
|
if client is not None:
|
||||||
|
return client
|
||||||
|
|
||||||
|
weaviate_url = get_from_dict_or_env(kwargs, "weaviate_url", "WEAVIATE_URL")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# the weaviate api key param should not be mandatory
|
||||||
|
weaviate_api_key = get_from_dict_or_env(
|
||||||
|
kwargs, "weaviate_api_key", "WEAVIATE_API_KEY", None
|
||||||
|
)
|
||||||
|
except ValueError:
|
||||||
|
weaviate_api_key = None
|
||||||
|
|
||||||
|
try:
|
||||||
|
import weaviate
|
||||||
|
except ImportError:
|
||||||
|
raise ValueError(
|
||||||
|
"Could not import weaviate python package. "
|
||||||
|
"Please install it with `pip install weaviate-client`"
|
||||||
|
)
|
||||||
|
|
||||||
|
auth = (
|
||||||
|
weaviate.auth.AuthApiKey(api_key=weaviate_api_key)
|
||||||
|
if weaviate_api_key is not None
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
client = weaviate.Client(weaviate_url, auth_client_secret=auth)
|
||||||
|
|
||||||
|
return client
|
||||||
|
|
||||||
|
|
||||||
|
def _default_score_normalizer(val: float) -> float:
|
||||||
|
return 1 - 1 / (1 + np.exp(val))
|
||||||
|
|
||||||
|
|
||||||
|
def _json_serializable(value: Any) -> Any:
|
||||||
|
if isinstance(value, datetime.datetime):
|
||||||
|
return value.isoformat()
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
class Weaviate(VectorStore):
|
||||||
|
"""Wrapper around Weaviate vector database.
|
||||||
|
|
||||||
|
To use, you should have the ``weaviate-client`` python package installed.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
import weaviate
|
||||||
|
from langchain.vectorstores import Weaviate
|
||||||
|
client = weaviate.Client(url=os.environ["WEAVIATE_URL"], ...)
|
||||||
|
weaviate = Weaviate(client, index_name, text_key)
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
client: Any,
|
||||||
|
index_name: str,
|
||||||
|
text_key: str,
|
||||||
|
embedding: Optional[Embeddings] = None,
|
||||||
|
attributes: Optional[List[str]] = None,
|
||||||
|
relevance_score_fn: Optional[
|
||||||
|
Callable[[float], float]
|
||||||
|
] = _default_score_normalizer,
|
||||||
|
by_text: bool = True,
|
||||||
|
):
|
||||||
|
"""Initialize with Weaviate client."""
|
||||||
|
try:
|
||||||
|
import weaviate
|
||||||
|
except ImportError:
|
||||||
|
raise ValueError(
|
||||||
|
"Could not import weaviate python package. "
|
||||||
|
"Please install it with `pip install weaviate-client`."
|
||||||
|
)
|
||||||
|
if not isinstance(client, weaviate.Client):
|
||||||
|
raise ValueError(
|
||||||
|
f"client should be an instance of weaviate.Client, got {type(client)}"
|
||||||
|
)
|
||||||
|
self._client = client
|
||||||
|
self._index_name = index_name
|
||||||
|
self._embedding = embedding
|
||||||
|
self._text_key = text_key
|
||||||
|
self._query_attrs = [self._text_key]
|
||||||
|
self.relevance_score_fn = relevance_score_fn
|
||||||
|
self._by_text = by_text
|
||||||
|
if attributes is not None:
|
||||||
|
self._query_attrs.extend(attributes)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def embeddings(self) -> Optional[Embeddings]:
|
||||||
|
return self._embedding
|
||||||
|
|
||||||
|
def _select_relevance_score_fn(self) -> Callable[[float], float]:
|
||||||
|
return (
|
||||||
|
self.relevance_score_fn
|
||||||
|
if self.relevance_score_fn
|
||||||
|
else _default_score_normalizer
|
||||||
|
)
|
||||||
|
|
||||||
|
def add_texts(
|
||||||
|
self,
|
||||||
|
texts: Iterable[str],
|
||||||
|
metadatas: Optional[List[dict]] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> List[str]:
|
||||||
|
"""Upload texts with metadata (properties) to Weaviate."""
|
||||||
|
from weaviate.util import get_valid_uuid
|
||||||
|
|
||||||
|
ids = []
|
||||||
|
embeddings: Optional[List[List[float]]] = None
|
||||||
|
if self._embedding:
|
||||||
|
if not isinstance(texts, list):
|
||||||
|
texts = list(texts)
|
||||||
|
embeddings = self._embedding.embed_documents(texts)
|
||||||
|
|
||||||
|
with self._client.batch as batch:
|
||||||
|
for i, text in enumerate(texts):
|
||||||
|
data_properties = {self._text_key: text}
|
||||||
|
if metadatas is not None:
|
||||||
|
for key, val in metadatas[i].items():
|
||||||
|
data_properties[key] = _json_serializable(val)
|
||||||
|
|
||||||
|
# Allow for ids (consistent w/ other methods)
|
||||||
|
# # Or uuids (backwards compatble w/ existing arg)
|
||||||
|
# If the UUID of one of the objects already exists
|
||||||
|
# then the existing object will be replaced by the new object.
|
||||||
|
_id = get_valid_uuid(uuid4())
|
||||||
|
if "uuids" in kwargs:
|
||||||
|
_id = kwargs["uuids"][i]
|
||||||
|
elif "ids" in kwargs:
|
||||||
|
_id = kwargs["ids"][i]
|
||||||
|
|
||||||
|
batch.add_data_object(
|
||||||
|
data_object=data_properties,
|
||||||
|
class_name=self._index_name,
|
||||||
|
uuid=_id,
|
||||||
|
vector=embeddings[i] if embeddings else None,
|
||||||
|
)
|
||||||
|
ids.append(_id)
|
||||||
|
return ids
|
||||||
|
|
||||||
|
def similarity_search(
|
||||||
|
self, query: str, k: int = 4, **kwargs: Any
|
||||||
|
) -> List[Document]:
|
||||||
|
"""Return docs most similar to query.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: Text to look up documents similar to.
|
||||||
|
k: Number of Documents to return. Defaults to 4.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of Documents most similar to the query.
|
||||||
|
"""
|
||||||
|
if self._by_text:
|
||||||
|
return self.similarity_search_by_text(query, k, **kwargs)
|
||||||
|
else:
|
||||||
|
if self._embedding is None:
|
||||||
|
raise ValueError(
|
||||||
|
"_embedding cannot be None for similarity_search when "
|
||||||
|
"_by_text=False"
|
||||||
|
)
|
||||||
|
embedding = self._embedding.embed_query(query)
|
||||||
|
return self.similarity_search_by_vector(embedding, k, **kwargs)
|
||||||
|
|
||||||
|
def similarity_search_by_text(
|
||||||
|
self, query: str, k: int = 4, **kwargs: Any
|
||||||
|
) -> List[Document]:
|
||||||
|
"""Return docs most similar to query.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: Text to look up documents similar to.
|
||||||
|
k: Number of Documents to return. Defaults to 4.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of Documents most similar to the query.
|
||||||
|
"""
|
||||||
|
content: Dict[str, Any] = {"concepts": [query]}
|
||||||
|
if kwargs.get("search_distance"):
|
||||||
|
content["certainty"] = kwargs.get("search_distance")
|
||||||
|
query_obj = self._client.query.get(self._index_name, self._query_attrs)
|
||||||
|
if kwargs.get("where_filter"):
|
||||||
|
query_obj = query_obj.with_where(kwargs.get("where_filter"))
|
||||||
|
if kwargs.get("additional"):
|
||||||
|
query_obj = query_obj.with_additional(kwargs.get("additional"))
|
||||||
|
result = query_obj.with_near_text(content).with_limit(k).do()
|
||||||
|
if "errors" in result:
|
||||||
|
raise ValueError(f"Error during query: {result['errors']}")
|
||||||
|
docs = []
|
||||||
|
for res in result["data"]["Get"][self._index_name]:
|
||||||
|
text = res.pop(self._text_key)
|
||||||
|
docs.append(Document(page_content=text, metadata=res))
|
||||||
|
return docs
|
||||||
|
|
||||||
|
def similarity_search_by_bm25(
|
||||||
|
self, query: str, k: int = 4, **kwargs: Any
|
||||||
|
) -> List[Document]:
|
||||||
|
"""Return docs using BM25F.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: Text to look up documents similar to.
|
||||||
|
k: Number of Documents to return. Defaults to 4.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of Documents most similar to the query.
|
||||||
|
"""
|
||||||
|
content: Dict[str, Any] = {"concepts": [query]}
|
||||||
|
if kwargs.get("search_distance"):
|
||||||
|
content["certainty"] = kwargs.get("search_distance")
|
||||||
|
query_obj = self._client.query.get(self._index_name, self._query_attrs)
|
||||||
|
if kwargs.get("where_filter"):
|
||||||
|
query_obj = query_obj.with_where(kwargs.get("where_filter"))
|
||||||
|
if kwargs.get("additional"):
|
||||||
|
query_obj = query_obj.with_additional(kwargs.get("additional"))
|
||||||
|
result = query_obj.with_bm25(query=content).with_limit(k).do()
|
||||||
|
if "errors" in result:
|
||||||
|
raise ValueError(f"Error during query: {result['errors']}")
|
||||||
|
docs = []
|
||||||
|
for res in result["data"]["Get"][self._index_name]:
|
||||||
|
text = res.pop(self._text_key)
|
||||||
|
docs.append(Document(page_content=text, metadata=res))
|
||||||
|
return docs
|
||||||
|
|
||||||
|
def similarity_search_by_vector(
|
||||||
|
self, embedding: List[float], k: int = 4, **kwargs: Any
|
||||||
|
) -> List[Document]:
|
||||||
|
"""Look up similar documents by embedding vector in Weaviate."""
|
||||||
|
vector = {"vector": embedding}
|
||||||
|
query_obj = self._client.query.get(self._index_name, self._query_attrs)
|
||||||
|
if kwargs.get("where_filter"):
|
||||||
|
query_obj = query_obj.with_where(kwargs.get("where_filter"))
|
||||||
|
if kwargs.get("additional"):
|
||||||
|
query_obj = query_obj.with_additional(kwargs.get("additional"))
|
||||||
|
result = query_obj.with_near_vector(vector).with_limit(k).do()
|
||||||
|
if "errors" in result:
|
||||||
|
raise ValueError(f"Error during query: {result['errors']}")
|
||||||
|
docs = []
|
||||||
|
for res in result["data"]["Get"][self._index_name]:
|
||||||
|
text = res.pop(self._text_key)
|
||||||
|
docs.append(Document(page_content=text, metadata=res))
|
||||||
|
return docs
|
||||||
|
|
||||||
|
def max_marginal_relevance_search(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
k: int = 4,
|
||||||
|
fetch_k: int = 20,
|
||||||
|
lambda_mult: float = 0.5,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> List[Document]:
|
||||||
|
"""Return docs selected using the maximal marginal relevance.
|
||||||
|
|
||||||
|
Maximal marginal relevance optimizes for similarity to query AND diversity
|
||||||
|
among selected documents.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: Text to look up documents similar to.
|
||||||
|
k: Number of Documents to return. Defaults to 4.
|
||||||
|
fetch_k: Number of Documents to fetch to pass to MMR algorithm.
|
||||||
|
lambda_mult: Number between 0 and 1 that determines the degree
|
||||||
|
of diversity among the results with 0 corresponding
|
||||||
|
to maximum diversity and 1 to minimum diversity.
|
||||||
|
Defaults to 0.5.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of Documents selected by maximal marginal relevance.
|
||||||
|
"""
|
||||||
|
if self._embedding is not None:
|
||||||
|
embedding = self._embedding.embed_query(query)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"max_marginal_relevance_search requires a suitable Embeddings object"
|
||||||
|
)
|
||||||
|
|
||||||
|
return self.max_marginal_relevance_search_by_vector(
|
||||||
|
embedding, k=k, fetch_k=fetch_k, lambda_mult=lambda_mult, **kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
def max_marginal_relevance_search_by_vector(
|
||||||
|
self,
|
||||||
|
embedding: List[float],
|
||||||
|
k: int = 4,
|
||||||
|
fetch_k: int = 20,
|
||||||
|
lambda_mult: float = 0.5,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> List[Document]:
|
||||||
|
"""Return docs selected using the maximal marginal relevance.
|
||||||
|
|
||||||
|
Maximal marginal relevance optimizes for similarity to query AND diversity
|
||||||
|
among selected documents.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
embedding: Embedding to look up documents similar to.
|
||||||
|
k: Number of Documents to return. Defaults to 4.
|
||||||
|
fetch_k: Number of Documents to fetch to pass to MMR algorithm.
|
||||||
|
lambda_mult: Number between 0 and 1 that determines the degree
|
||||||
|
of diversity among the results with 0 corresponding
|
||||||
|
to maximum diversity and 1 to minimum diversity.
|
||||||
|
Defaults to 0.5.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of Documents selected by maximal marginal relevance.
|
||||||
|
"""
|
||||||
|
vector = {"vector": embedding}
|
||||||
|
query_obj = self._client.query.get(self._index_name, self._query_attrs)
|
||||||
|
if kwargs.get("where_filter"):
|
||||||
|
query_obj = query_obj.with_where(kwargs.get("where_filter"))
|
||||||
|
results = (
|
||||||
|
query_obj.with_additional("vector")
|
||||||
|
.with_near_vector(vector)
|
||||||
|
.with_limit(fetch_k)
|
||||||
|
.do()
|
||||||
|
)
|
||||||
|
|
||||||
|
payload = results["data"]["Get"][self._index_name]
|
||||||
|
embeddings = [result["_additional"]["vector"] for result in payload]
|
||||||
|
mmr_selected = maximal_marginal_relevance(
|
||||||
|
np.array(embedding), embeddings, k=k, lambda_mult=lambda_mult
|
||||||
|
)
|
||||||
|
|
||||||
|
docs = []
|
||||||
|
for idx in mmr_selected:
|
||||||
|
text = payload[idx].pop(self._text_key)
|
||||||
|
payload[idx].pop("_additional")
|
||||||
|
meta = payload[idx]
|
||||||
|
docs.append(Document(page_content=text, metadata=meta))
|
||||||
|
return docs
|
||||||
|
|
||||||
|
def similarity_search_with_score(
|
||||||
|
self, query: str, k: int = 4, **kwargs: Any
|
||||||
|
) -> List[Tuple[Document, float]]:
|
||||||
|
"""
|
||||||
|
Return list of documents most similar to the query
|
||||||
|
text and cosine distance in float for each.
|
||||||
|
Lower score represents more similarity.
|
||||||
|
"""
|
||||||
|
if self._embedding is None:
|
||||||
|
raise ValueError(
|
||||||
|
"_embedding cannot be None for similarity_search_with_score"
|
||||||
|
)
|
||||||
|
content: Dict[str, Any] = {"concepts": [query]}
|
||||||
|
if kwargs.get("search_distance"):
|
||||||
|
content["certainty"] = kwargs.get("search_distance")
|
||||||
|
query_obj = self._client.query.get(self._index_name, self._query_attrs)
|
||||||
|
|
||||||
|
embedded_query = self._embedding.embed_query(query)
|
||||||
|
if not self._by_text:
|
||||||
|
vector = {"vector": embedded_query}
|
||||||
|
result = (
|
||||||
|
query_obj.with_near_vector(vector)
|
||||||
|
.with_limit(k)
|
||||||
|
.with_additional("vector")
|
||||||
|
.do()
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
result = (
|
||||||
|
query_obj.with_near_text(content)
|
||||||
|
.with_limit(k)
|
||||||
|
.with_additional("vector")
|
||||||
|
.do()
|
||||||
|
)
|
||||||
|
|
||||||
|
if "errors" in result:
|
||||||
|
raise ValueError(f"Error during query: {result['errors']}")
|
||||||
|
|
||||||
|
docs_and_scores = []
|
||||||
|
for res in result["data"]["Get"][self._index_name]:
|
||||||
|
text = res.pop(self._text_key)
|
||||||
|
score = np.dot(res["_additional"]["vector"], embedded_query)
|
||||||
|
docs_and_scores.append((Document(page_content=text, metadata=res), score))
|
||||||
|
return docs_and_scores
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_texts(
|
||||||
|
cls: Type[Weaviate],
|
||||||
|
texts: List[str],
|
||||||
|
embedding: Embeddings,
|
||||||
|
metadatas: Optional[List[dict]] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> Weaviate:
|
||||||
|
"""Construct Weaviate wrapper from raw documents.
|
||||||
|
|
||||||
|
This is a user-friendly interface that:
|
||||||
|
1. Embeds documents.
|
||||||
|
2. Creates a new index for the embeddings in the Weaviate instance.
|
||||||
|
3. Adds the documents to the newly created Weaviate index.
|
||||||
|
|
||||||
|
This is intended to be a quick way to get started.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
from langchain.vectorstores.weaviate import Weaviate
|
||||||
|
from langchain.embeddings import OpenAIEmbeddings
|
||||||
|
embeddings = OpenAIEmbeddings()
|
||||||
|
weaviate = Weaviate.from_texts(
|
||||||
|
texts,
|
||||||
|
embeddings,
|
||||||
|
weaviate_url="http://localhost:8080"
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
|
||||||
|
client = _create_weaviate_client(**kwargs)
|
||||||
|
|
||||||
|
from weaviate.util import get_valid_uuid
|
||||||
|
|
||||||
|
index_name = kwargs.get("index_name", f"LangChain_{uuid4().hex}")
|
||||||
|
embeddings = embedding.embed_documents(texts) if embedding else None
|
||||||
|
text_key = "text"
|
||||||
|
schema = _default_schema(index_name)
|
||||||
|
attributes = list(metadatas[0].keys()) if metadatas else None
|
||||||
|
|
||||||
|
# check whether the index already exists
|
||||||
|
if not client.schema.contains(schema):
|
||||||
|
client.schema.create_class(schema)
|
||||||
|
|
||||||
|
with client.batch as batch:
|
||||||
|
for i, text in enumerate(texts):
|
||||||
|
data_properties = {
|
||||||
|
text_key: text,
|
||||||
|
}
|
||||||
|
if metadatas is not None:
|
||||||
|
for key in metadatas[i].keys():
|
||||||
|
data_properties[key] = metadatas[i][key]
|
||||||
|
|
||||||
|
# If the UUID of one of the objects already exists
|
||||||
|
# then the existing objectwill be replaced by the new object.
|
||||||
|
if "uuids" in kwargs:
|
||||||
|
_id = kwargs["uuids"][i]
|
||||||
|
else:
|
||||||
|
_id = get_valid_uuid(uuid4())
|
||||||
|
|
||||||
|
# if an embedding strategy is not provided, we let
|
||||||
|
# weaviate create the embedding. Note that this will only
|
||||||
|
# work if weaviate has been installed with a vectorizer module
|
||||||
|
# like text2vec-contextionary for example
|
||||||
|
params = {
|
||||||
|
"uuid": _id,
|
||||||
|
"data_object": data_properties,
|
||||||
|
"class_name": index_name,
|
||||||
|
}
|
||||||
|
if embeddings is not None:
|
||||||
|
params["vector"] = embeddings[i]
|
||||||
|
|
||||||
|
batch.add_data_object(**params)
|
||||||
|
|
||||||
|
batch.flush()
|
||||||
|
|
||||||
|
relevance_score_fn = kwargs.get("relevance_score_fn")
|
||||||
|
by_text: bool = kwargs.get("by_text", False)
|
||||||
|
|
||||||
|
return cls(
|
||||||
|
client,
|
||||||
|
index_name,
|
||||||
|
text_key,
|
||||||
|
embedding=embedding,
|
||||||
|
attributes=attributes,
|
||||||
|
relevance_score_fn=relevance_score_fn,
|
||||||
|
by_text=by_text,
|
||||||
|
)
|
||||||
|
|
||||||
|
def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> None:
|
||||||
|
"""Delete by vector IDs.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ids: List of ids to delete.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if ids is None:
|
||||||
|
raise ValueError("No ids provided to delete.")
|
||||||
|
|
||||||
|
# TODO: Check if this can be done in bulk
|
||||||
|
for id in ids:
|
||||||
|
self._client.data_object.delete(uuid=id)
|
||||||
@ -0,0 +1,43 @@
|
|||||||
|
"""add-dataset-retrival-model
|
||||||
|
|
||||||
|
Revision ID: fca025d3b60f
|
||||||
|
Revises: b3a09c049e8e
|
||||||
|
Create Date: 2023-11-03 13:08:23.246396
|
||||||
|
|
||||||
|
"""
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from sqlalchemy.dialects import postgresql
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision = 'fca025d3b60f'
|
||||||
|
down_revision = '8fe468ba0ca5'
|
||||||
|
branch_labels = None
|
||||||
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade():
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.drop_table('sessions')
|
||||||
|
with op.batch_alter_table('datasets', schema=None) as batch_op:
|
||||||
|
batch_op.add_column(sa.Column('retrieval_model', postgresql.JSONB(astext_type=sa.Text()), nullable=True))
|
||||||
|
batch_op.create_index('retrieval_model_idx', ['retrieval_model'], unique=False, postgresql_using='gin')
|
||||||
|
|
||||||
|
# ### end Alembic commands ###
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade():
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
with op.batch_alter_table('datasets', schema=None) as batch_op:
|
||||||
|
batch_op.drop_index('retrieval_model_idx', postgresql_using='gin')
|
||||||
|
batch_op.drop_column('retrieval_model')
|
||||||
|
|
||||||
|
op.create_table('sessions',
|
||||||
|
sa.Column('id', sa.INTEGER(), autoincrement=True, nullable=False),
|
||||||
|
sa.Column('session_id', sa.VARCHAR(length=255), autoincrement=False, nullable=True),
|
||||||
|
sa.Column('data', postgresql.BYTEA(), autoincrement=False, nullable=True),
|
||||||
|
sa.Column('expiry', postgresql.TIMESTAMP(), autoincrement=False, nullable=True),
|
||||||
|
sa.PrimaryKeyConstraint('id', name='sessions_pkey'),
|
||||||
|
sa.UniqueConstraint('session_id', name='sessions_session_id_key')
|
||||||
|
)
|
||||||
|
# ### end Alembic commands ###
|
||||||
@ -0,0 +1,88 @@
|
|||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
from flask import current_app, Flask
|
||||||
|
from langchain.embeddings.base import Embeddings
|
||||||
|
from core.index.vector_index.vector_index import VectorIndex
|
||||||
|
from core.model_providers.model_factory import ModelFactory
|
||||||
|
from models.dataset import Dataset
|
||||||
|
|
||||||
|
default_retrieval_model = {
|
||||||
|
'search_method': 'semantic_search',
|
||||||
|
'reranking_enable': False,
|
||||||
|
'reranking_model': {
|
||||||
|
'reranking_provider_name': '',
|
||||||
|
'reranking_model_name': ''
|
||||||
|
},
|
||||||
|
'top_k': 2,
|
||||||
|
'score_threshold_enable': False
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class RetrievalService:
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def embedding_search(cls, flask_app: Flask, dataset: Dataset, query: str,
|
||||||
|
top_k: int, score_threshold: Optional[float], reranking_model: Optional[dict],
|
||||||
|
all_documents: list, search_method: str, embeddings: Embeddings):
|
||||||
|
with flask_app.app_context():
|
||||||
|
|
||||||
|
vector_index = VectorIndex(
|
||||||
|
dataset=dataset,
|
||||||
|
config=current_app.config,
|
||||||
|
embeddings=embeddings
|
||||||
|
)
|
||||||
|
|
||||||
|
documents = vector_index.search(
|
||||||
|
query,
|
||||||
|
search_type='similarity_score_threshold',
|
||||||
|
search_kwargs={
|
||||||
|
'k': top_k,
|
||||||
|
'score_threshold': score_threshold,
|
||||||
|
'filter': {
|
||||||
|
'group_id': [dataset.id]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
if documents:
|
||||||
|
if reranking_model and search_method == 'semantic_search':
|
||||||
|
rerank = ModelFactory.get_reranking_model(
|
||||||
|
tenant_id=dataset.tenant_id,
|
||||||
|
model_provider_name=reranking_model['reranking_provider_name'],
|
||||||
|
model_name=reranking_model['reranking_model_name']
|
||||||
|
)
|
||||||
|
all_documents.extend(rerank.rerank(query, documents, score_threshold, len(documents)))
|
||||||
|
else:
|
||||||
|
all_documents.extend(documents)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def full_text_index_search(cls, flask_app: Flask, dataset: Dataset, query: str,
|
||||||
|
top_k: int, score_threshold: Optional[float], reranking_model: Optional[dict],
|
||||||
|
all_documents: list, search_method: str, embeddings: Embeddings):
|
||||||
|
with flask_app.app_context():
|
||||||
|
|
||||||
|
vector_index = VectorIndex(
|
||||||
|
dataset=dataset,
|
||||||
|
config=current_app.config,
|
||||||
|
embeddings=embeddings
|
||||||
|
)
|
||||||
|
|
||||||
|
documents = vector_index.search_by_full_text_index(
|
||||||
|
query,
|
||||||
|
search_type='similarity_score_threshold',
|
||||||
|
top_k=top_k
|
||||||
|
)
|
||||||
|
if documents:
|
||||||
|
if reranking_model and search_method == 'full_text_search':
|
||||||
|
rerank = ModelFactory.get_reranking_model(
|
||||||
|
tenant_id=dataset.tenant_id,
|
||||||
|
model_provider_name=reranking_model['reranking_provider_name'],
|
||||||
|
model_name=reranking_model['reranking_model_name']
|
||||||
|
)
|
||||||
|
all_documents.extend(rerank.rerank(query, documents, score_threshold, len(documents)))
|
||||||
|
else:
|
||||||
|
all_documents.extend(documents)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Loading…
Reference in New Issue