# Source code for components.retriever.qdrant_retriever

"""Leverage a Qdrant collection to retrieve documents."""

from typing import List, Optional, Any
from qdrant_client import QdrantClient, models

from adalflow.core.retriever import (
    Retriever,
)
from adalflow.core.embedder import Embedder

from adalflow.core.types import (
    RetrieverOutput,
    RetrieverStrQueryType,
    RetrieverStrQueriesType,
    Document,
)


class QdrantRetriever(Retriever[Any, RetrieverStrQueryType]):
    __doc__ = r"""Use a Qdrant collection to retrieve documents.

    Args:
        collection_name (str): the collection name in Qdrant.
        client (QdrantClient): An instance of qdrant_client.QdrantClient.
        embedder (Embedder): An instance of Embedder.
        top_k (Optional[int], optional): top k documents to fetch. Defaults to 10.
        vector_name (Optional[str], optional): the name of the vector in the collection.
            When None (or empty), the first named vector of the collection is used, or
            the default unnamed vector if the collection has no named vectors.
            Defaults to None.
        text_key (str, optional): the key in the payload that contains the text. Defaults to "text".
        metadata_key (str, optional): the key in the payload that contains the metadata. Defaults to "meta_data".
        filter (Optional[models.Filter], optional): the filter to apply to the query. Defaults to None.

    Note:
        When ``vector_name`` is not given, construction calls
        ``client.get_collection`` to inspect the collection's vector config.

    References:
        [1] Qdrant: https://qdrant.tech/
        [2] Documentation: https://qdrant.tech/documentation/
    """

    def __init__(
        self,
        collection_name: str,
        client: QdrantClient,
        embedder: Embedder,
        top_k: Optional[int] = 10,
        vector_name: Optional[str] = None,
        text_key: str = "text",
        metadata_key: str = "meta_data",
        filter: Optional[models.Filter] = None,
    ):
        super().__init__()
        self._top_k = top_k
        self._collection_name = collection_name
        self._client = client
        self._embedder = embedder
        self._text_key = text_key
        self._metadata_key = metadata_key
        self._filter = filter
        # Resolve the vector to search once, up front, so every query uses the same one.
        self._vector_name = vector_name or self._get_first_vector_name()

    def reset_index(self):
        """Delete the backing Qdrant collection if it exists."""
        if self._client.collection_exists(self._collection_name):
            self._client.delete_collection(self._collection_name)

    def call(
        self,
        input: RetrieverStrQueriesType,
        top_k: Optional[int] = None,
        **kwargs,
    ) -> List[RetrieverOutput]:
        """Retrieve the top-k documents for one or more string queries.

        Args:
            input: A single query string or a list of query strings.
            top_k: Number of documents to fetch per query. Falls back to the
                instance's ``top_k`` when not given (or falsy).
            **kwargs: Extra fields forwarded to every ``models.QueryRequest``.

        Returns:
            One ``RetrieverOutput`` per query, in the same order as the queries.
            An empty list when ``input`` is an empty list.
        """
        top_k = top_k or self._top_k
        queries: List[str] = input if isinstance(input, list) else [input]
        if not queries:
            # Nothing to embed or search; skip the embedder and Qdrant round trips.
            return []

        queries_embeddings = self._embedder(queries)

        query_requests: List[models.QueryRequest] = []
        for idx, _ in enumerate(queries):
            query_embedding = queries_embeddings.data[idx].embedding
            query_requests.append(
                models.QueryRequest(
                    query=query_embedding,
                    limit=top_k,
                    using=self._vector_name,
                    with_payload=True,
                    with_vector=True,
                    filter=self._filter,
                    **kwargs,
                )
            )

        results = self._client.query_batch_points(
            self._collection_name, requests=query_requests
        )

        # Batch responses come back in request order, so pair each one with
        # the query that produced it (not a single shared query string).
        return [
            self._points_to_output(
                result.points,
                query,
                self._text_key,
                self._metadata_key,
                self._vector_name,
            )
            for query, result in zip(queries, results)
        ]

    def _get_first_vector_name(self) -> Optional[str]:
        """Return the first named vector of the collection, or None for the default vector."""
        vectors = self._client.get_collection(
            self._collection_name
        ).config.params.vectors
        if not isinstance(vectors, dict):
            # The collection only has the default, unnamed vector.
            return None

        # Named vectors may include the falsy unnamed vector (empty string ""),
        # which is normalized to None. An empty mapping also yields None.
        first_vector_name = next(iter(vectors), None)
        return first_vector_name or None

    @classmethod
    def _points_to_output(
        cls,
        points: List[models.ScoredPoint],
        query: str,
        text_key: str,
        metadata_key: str,
        vector_name: Optional[str],
    ) -> RetrieverOutput:
        """Convert scored points from one query into a ``RetrieverOutput``."""
        doc_indices = [point.id for point in points]
        doc_scores = [point.score for point in points]
        documents = [
            cls._doc_from_point(point, text_key, metadata_key, vector_name)
            for point in points
        ]
        return RetrieverOutput(
            doc_indices=doc_indices,
            doc_scores=doc_scores,
            query=query,
            documents=documents,
        )

    @classmethod
    def _doc_from_point(
        cls,
        point: models.ScoredPoint,
        text_key: str,
        metadata_key: str,
        vector_name: Optional[str] = None,
    ) -> Document:
        """Build a ``Document`` from a scored point's id, payload, and vector.

        Missing payload keys fall back to ``""`` for text and ``{}`` for metadata.
        """
        vector = point.vector
        if isinstance(vector, dict):
            # Multiple vectors were returned as a name -> vector mapping. A None
            # name stands for the unnamed default vector, keyed by "" (the same
            # convention _get_first_vector_name normalizes to None).
            vector = vector[vector_name if vector_name is not None else ""]

        # Payload is only read, so no copy is needed; tolerate a missing payload.
        payload = point.payload or {}
        return Document(
            id=point.id,
            text=payload.get(text_key, ""),
            meta_data=payload.get(metadata_key, {}),
            vector=vector,
        )