@ -7,7 +7,7 @@ from core.rag.data_post_processor.data_post_processor import DataPostProcessor
from core . rag . datasource . keyword . keyword_factory import Keyword
from core . rag . datasource . vdb . vector_factory import Vector
from core . rag . rerank . constants . rerank_mode import RerankMode
from core . rag . retrieval . retri val_methods import RetrievalMethod
from core . rag . retrieval . retri e val_methods import RetrievalMethod
from extensions . ext_database import db
from models . dataset import Dataset
@ -26,7 +26,7 @@ default_retrieval_model = {
class RetrievalService :
@classmethod
def retrieve ( cls , retri val_method: str , dataset_id : str , query : str ,
def retrieve ( cls , retri e val_method: str , dataset_id : str , query : str ,
top_k : int , score_threshold : Optional [ float ] = .0 ,
reranking_model : Optional [ dict ] = None , reranking_mode : Optional [ str ] = ' reranking_model ' ,
weights : Optional [ dict ] = None ) :
@ -39,7 +39,7 @@ class RetrievalService:
threads = [ ]
exceptions = [ ]
# retrieval_model source with keyword
if retri val_method == ' keyword_search ' :
if retri e val_method == ' keyword_search ' :
keyword_thread = threading . Thread ( target = RetrievalService . keyword_search , kwargs = {
' flask_app ' : current_app . _get_current_object ( ) ,
' dataset_id ' : dataset_id ,
@ -51,7 +51,7 @@ class RetrievalService:
threads . append ( keyword_thread )
keyword_thread . start ( )
# retrieval_model source with semantic
if RetrievalMethod . is_support_semantic_search ( retri val_method) :
if RetrievalMethod . is_support_semantic_search ( retri e val_method) :
embedding_thread = threading . Thread ( target = RetrievalService . embedding_search , kwargs = {
' flask_app ' : current_app . _get_current_object ( ) ,
' dataset_id ' : dataset_id ,
@ -60,19 +60,19 @@ class RetrievalService:
' score_threshold ' : score_threshold ,
' reranking_model ' : reranking_model ,
' all_documents ' : all_documents ,
' retri val_method' : retri val_method,
' retri e val_method' : retri e val_method,
' exceptions ' : exceptions ,
} )
threads . append ( embedding_thread )
embedding_thread . start ( )
# retrieval source with full text
if RetrievalMethod . is_support_fulltext_search ( retri val_method) :
if RetrievalMethod . is_support_fulltext_search ( retri e val_method) :
full_text_index_thread = threading . Thread ( target = RetrievalService . full_text_index_search , kwargs = {
' flask_app ' : current_app . _get_current_object ( ) ,
' dataset_id ' : dataset_id ,
' query ' : query ,
' retri val_method' : retri val_method,
' retri e val_method' : retri e val_method,
' score_threshold ' : score_threshold ,
' top_k ' : top_k ,
' reranking_model ' : reranking_model ,
@ -89,7 +89,7 @@ class RetrievalService:
exception_message = ' ; \n ' . join ( exceptions )
raise Exception ( exception_message )
if retri val_method == RetrievalMethod . HYBRID_SEARCH . value :
if retri e val_method == RetrievalMethod . HYBRID_SEARCH . value :
data_post_processor = DataPostProcessor ( str ( dataset . tenant_id ) , reranking_mode ,
reranking_model , weights , False )
all_documents = data_post_processor . invoke (
@ -124,7 +124,7 @@ class RetrievalService:
@classmethod
def embedding_search ( cls , flask_app : Flask , dataset_id : str , query : str ,
top_k : int , score_threshold : Optional [ float ] , reranking_model : Optional [ dict ] ,
all_documents : list , retri val_method: str , exceptions : list ) :
all_documents : list , retri e val_method: str , exceptions : list ) :
with flask_app . app_context ( ) :
try :
dataset = db . session . query ( Dataset ) . filter (
@ -146,7 +146,7 @@ class RetrievalService:
)
if documents :
if reranking_model and reranking_model . get ( ' reranking_model_name ' ) and reranking_model . get ( ' reranking_provider_name ' ) and retri val_method == RetrievalMethod . SEMANTIC_SEARCH . value :
if reranking_model and reranking_model . get ( ' reranking_model_name ' ) and reranking_model . get ( ' reranking_provider_name ' ) and retri e val_method == RetrievalMethod . SEMANTIC_SEARCH . value :
data_post_processor = DataPostProcessor ( str ( dataset . tenant_id ) ,
RerankMode . RERANKING_MODEL . value ,
reranking_model , None , False )
@ -164,7 +164,7 @@ class RetrievalService:
@classmethod
def full_text_index_search ( cls , flask_app : Flask , dataset_id : str , query : str ,
top_k : int , score_threshold : Optional [ float ] , reranking_model : Optional [ dict ] ,
all_documents : list , retri val_method: str , exceptions : list ) :
all_documents : list , retri e val_method: str , exceptions : list ) :
with flask_app . app_context ( ) :
try :
dataset = db . session . query ( Dataset ) . filter (
@ -180,7 +180,7 @@ class RetrievalService:
top_k = top_k
)
if documents :
if reranking_model and reranking_model . get ( ' reranking_model_name ' ) and reranking_model . get ( ' reranking_provider_name ' ) and retri val_method == RetrievalMethod . FULL_TEXT_SEARCH . value :
if reranking_model and reranking_model . get ( ' reranking_model_name ' ) and reranking_model . get ( ' reranking_provider_name ' ) and retri e val_method == RetrievalMethod . FULL_TEXT_SEARCH . value :
data_post_processor = DataPostProcessor ( str ( dataset . tenant_id ) ,
RerankMode . RERANKING_MODEL . value ,
reranking_model , None , False )