Source code for components.retriever.reranker_retriever

"""Reranking model using modelclient as a retriever."""

from typing import List, Optional, Callable, Any, Dict
import logging

from adalflow.core.retriever import (
    Retriever,
)
from adalflow.core.types import (
    RetrieverStrQueriesType,
    RetrieverOutputType,
    RetrieverDocumentsType,
    RetrieverOutput,
    ModelType,
)
from adalflow.core.model_client import ModelClient

log = logging.getLogger(__name__)


[docs] class RerankerRetriever(Retriever[str, RetrieverStrQueriesType]): __doc__ = r""" A retriever that uses a reranker model to rank the documents and retrieve the top-k documents. Args: top_k (int, optional): The number of top documents to retrieve. Defaults to 5. model_client (ModelClient): The model client that has a reranker model, such as ``CohereAPIClient`` or ``TransformersClient``. model_kwargs (Dict): The model kwargs to pass to the model client. documents (Optional[RetrieverDocumentsType], optional): The documents to build the index from. Defaults to None. document_map_func (Optional[Callable[[Any], str]], optional): The function to map the document of Any type to the specific type ``RetrieverDocumentType`` that the retriever expects. Defaults to None. Examples: """ def __init__( self, model_client: ModelClient, # make sure you initialize the model client first model_kwargs: Dict = {}, top_k: int = 5, documents: Optional[RetrieverDocumentsType] = None, document_map_func: Optional[Callable[[Any], str]] = None, ): super().__init__() self.top_k = top_k self._model_kwargs = model_kwargs or {} assert "model" in self._model_kwargs, "model must be specified in model_kwargs" if not isinstance(self._model_kwargs, Dict): raise TypeError( f"{type(self).__name__} requires a dictionary for model_kwargs, not a string" ) if not isinstance(model_client, ModelClient): raise TypeError( f"{type(self).__name__} requires a ModelClient instance for model_client" ) self.model_client = model_client self.reset_index() if documents: self.build_index_from_documents(documents, document_map_func)
[docs] def reset_index(self): self.indexed = False self.documents = [] self.total_documents: int = 0
[docs] def build_index_from_documents( self, documents: RetrieverDocumentsType, document_map_func: Optional[Callable[[Any], str]] = None, ): if document_map_func: documents = [document_map_func(doc) for doc in documents] else: documents = documents self.total_documents = len(documents) self._model_kwargs["documents"] = documents self.indexed = True
[docs] def call( self, input: RetrieverStrQueriesType, top_k: Optional[int] = None ) -> RetrieverOutputType: top_k = top_k or self.top_k queries = input if isinstance(input, List) else [input] retrieved_outputs: RetrieverOutputType = [] model_kwargs = self._model_kwargs.copy() model_kwargs["top_k"] = top_k for query in queries: api_kwargs = self.model_client.convert_inputs_to_api_kwargs( input=query, model_kwargs=model_kwargs, model_type=ModelType.RERANKER, ) log.info(f"api_kwargs: {api_kwargs}") top_k_indices, top_k_scores = self.model_client.call( api_kwargs=api_kwargs, model_type=ModelType.RERANKER ) retrieved_outputs.append( RetrieverOutput( doc_indices=top_k_indices, doc_scores=top_k_scores, query=query, ) ) return retrieved_outputs
def _extra_repr(self) -> str: s = f"top_k={self.top_k}, model_kwargs={self._model_kwargs}, model_client={self.model_client}, total_documents={self.total_documents}" return s