Source code for components.retriever.reranker_retriever
"""Reranking model using modelclient as a retriever."""fromtypingimportList,Optional,Callable,Any,Dictimportloggingfromadalflow.core.retrieverimport(Retriever,)fromadalflow.core.typesimport(RetrieverStrQueriesType,RetrieverOutputType,RetrieverDocumentsType,RetrieverOutput,ModelType,)fromadalflow.core.model_clientimportModelClientlog=logging.getLogger(__name__)
class RerankerRetriever(Retriever[str, RetrieverStrQueriesType]):
    __doc__ = r"""A retriever that uses a reranker model to rank the documents and retrieve the top-k documents.

    Args:
        top_k (int, optional): The number of top documents to retrieve. Defaults to 5.
        model_client (ModelClient): The model client that has a reranker model,
            such as ``CohereAPIClient`` or ``TransformersClient``.
        model_kwargs (Dict): The model kwargs to pass to the model client.
            Must contain a ``"model"`` entry.
        documents (Optional[RetrieverDocumentsType], optional): The documents to build
            the index from. Defaults to None.
        document_map_func (Optional[Callable[[Any], str]], optional): The function to map
            the document of Any type to the specific type ``RetrieverDocumentType`` that
            the retriever expects. Defaults to None.

    Raises:
        TypeError: If ``model_kwargs`` is not a dictionary, or ``model_client`` is not a
            ``ModelClient`` instance.
        AssertionError: If ``model_kwargs`` does not contain the ``"model"`` key.
    """

    def __init__(
        self,
        model_client: ModelClient,  # make sure you initialize the model client first
        model_kwargs: Dict = {},
        top_k: int = 5,
        documents: Optional[RetrieverDocumentsType] = None,
        document_map_func: Optional[Callable[[Any], str]] = None,
    ):
        super().__init__()
        self.top_k = top_k

        # A falsy value (None, or the shared ``{}`` default) is swapped for a fresh
        # dict here, so the mutable default object itself is never stored.
        self._model_kwargs = model_kwargs or {}

        # Check the type BEFORE the membership test: ``"model" in <non-dict>`` would
        # otherwise raise a misleading AssertionError (e.g. for a list or a string
        # without "model") or an unrelated "not iterable" TypeError (e.g. for an int)
        # instead of the intended error below.
        if not isinstance(self._model_kwargs, Dict):
            raise TypeError(
                f"{type(self).__name__} requires a dictionary for model_kwargs, "
                f"not {type(self._model_kwargs).__name__}"
            )
        # Kept as an assert so the exception type seen by existing callers does not
        # change. NOTE(review): asserts are stripped under ``python -O``.
        assert "model" in self._model_kwargs, "model must be specified in model_kwargs"

        if not isinstance(model_client, ModelClient):
            raise TypeError(
                f"{type(self).__name__} requires a ModelClient instance for model_client"
            )
        self.model_client = model_client

        # Start from an empty index; only build one when documents were supplied.
        self.reset_index()
        if documents:
            self.build_index_from_documents(documents, document_map_func)