refactor: move the embedding to the rag module and abstract the rerank runner for extension (#9423)
parent
e7aecb89dd
commit
b90ad587c2
@ -0,0 +1,26 @@
|
|||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from core.rag.models.document import Document
|
||||||
|
|
||||||
|
|
||||||
|
class BaseRerankRunner(ABC):
|
||||||
|
@abstractmethod
|
||||||
|
def run(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
documents: list[Document],
|
||||||
|
score_threshold: Optional[float] = None,
|
||||||
|
top_n: Optional[int] = None,
|
||||||
|
user: Optional[str] = None,
|
||||||
|
) -> list[Document]:
|
||||||
|
"""
|
||||||
|
Run rerank model
|
||||||
|
:param query: search query
|
||||||
|
:param documents: documents for reranking
|
||||||
|
:param score_threshold: score threshold
|
||||||
|
:param top_n: top n
|
||||||
|
:param user: unique user id if needed
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
@ -0,0 +1,16 @@
|
|||||||
|
from core.rag.rerank.rerank_base import BaseRerankRunner
|
||||||
|
from core.rag.rerank.rerank_model import RerankModelRunner
|
||||||
|
from core.rag.rerank.rerank_type import RerankMode
|
||||||
|
from core.rag.rerank.weight_rerank import WeightRerankRunner
|
||||||
|
|
||||||
|
|
||||||
|
class RerankRunnerFactory:
|
||||||
|
@staticmethod
|
||||||
|
def create_rerank_runner(runner_type: str, *args, **kwargs) -> BaseRerankRunner:
|
||||||
|
match runner_type:
|
||||||
|
case RerankMode.RERANKING_MODEL.value:
|
||||||
|
return RerankModelRunner(*args, **kwargs)
|
||||||
|
case RerankMode.WEIGHTED_SCORE.value:
|
||||||
|
return WeightRerankRunner(*args, **kwargs)
|
||||||
|
case _:
|
||||||
|
raise ValueError(f"Unknown runner type: {runner_type}")
|
||||||
Loading…
Reference in New Issue