# Source code for core.embedder

r"""The component that orchestrates model client (Embedding models in particular) and output processors."""

from typing import Optional, Any, Dict, List
import logging
from tqdm import tqdm

from adalflow.core.types import ModelType, EmbedderOutput
from adalflow.core.model_client import ModelClient
from adalflow.core.types import (
    EmbedderOutputType,
    EmbedderInputType,
    BatchEmbedderInputType,
    BatchEmbedderOutputType,
)
from adalflow.core.component import Component
import adalflow.core.functional as F

# Public API of this module.
__all__ = ["Embedder", "BatchEmbedder"]

# Module-level logger; handler/level configuration is left to the host application.
log = logging.getLogger(__name__)


class Embedder(Component):
    r"""A user-facing component that orchestrates an embedder model via the model client and output processors.

    Args:
        model_client (ModelClient): The model client to use for the embedder.
        model_kwargs (Dict[str, Any], optional): The model kwargs to pass to the model client.
            Defaults to None (treated as ``{}``).
        output_processors (Optional[Component], optional): The output processors after model call.
            Defaults to None. If you want to add further processing, it should operate on the
            ``EmbedderOutput`` data type.

    input: a single str or a list of str. When a list is used, the list is processed as a batch
    of inputs in the model client.

    Note:
        - The ``output_processors`` will be applied only on the data field of ``EmbedderOutput``,
          which is a list of ``Embedding``.
        - Use ``BatchEmbedder`` for automatically batching input of large size, larger than 100.
    """

    model_type: ModelType = ModelType.EMBEDDER
    model_client: ModelClient
    output_processors: Optional[Component]

    def __init__(
        self,
        *,
        model_client: ModelClient,
        model_kwargs: Optional[Dict[str, Any]] = None,
        output_processors: Optional[Component] = None,
    ) -> None:
        # Default changed from a shared mutable ``{}`` to ``None`` to avoid the
        # mutable-default-argument pitfall; behavior is unchanged for callers.
        model_kwargs = {} if model_kwargs is None else model_kwargs
        super().__init__(model_kwargs=model_kwargs)
        if not isinstance(model_kwargs, Dict):
            # Report the actual offending type instead of assuming it was a string.
            raise TypeError(
                f"{type(self).__name__} requires a dictionary for model_kwargs, "
                f"got {type(model_kwargs).__name__}"
            )
        # Copy so later mutations by the caller do not leak into this component.
        self.model_kwargs = model_kwargs.copy()
        if not isinstance(model_client, ModelClient):
            raise TypeError(
                f"{type(self).__name__} requires a ModelClient instance for model_client, please pass it as OpenAIClient() or GroqAPIClient() for example."
            )
        self.model_client = model_client
        self.output_processors = output_processors

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "Embedder":
        """Create an Embedder from a configuration dictionary.

        Example:

        .. code-block:: python

            embedder_config = {
                "model_client": {
                    "component_name": "OpenAIClient",
                    "component_config": {}
                },
                "model_kwargs": {
                    "model": "text-embedding-3-small",
                    "dimensions": 256,
                    "encoding_format": "float"
                }
            }

            embedder = Embedder.from_config(embedder_config)

        Raises:
            ValueError: If ``model_client`` is missing from the config.
        """
        if "model_client" not in config:
            raise ValueError("model_client is required in the config")
        return super().from_config(config)

    def _compose_model_kwargs(self, **model_kwargs) -> Dict[str, object]:
        r"""Add new arguments or overwrite existing arguments in the model_kwargs."""
        return F.compose_model_kwargs(self.model_kwargs, model_kwargs)

    def _pre_call(
        self, input: EmbedderInputType, model_kwargs: Optional[Dict] = None
    ) -> Dict:
        r"""Combine per-call kwargs with the defaults and convert the input into API kwargs."""
        # step 1: combine the model_kwargs with the default model_kwargs
        # (``or {}`` also covers the previous ``{}`` default without sharing a dict)
        composed_model_kwargs = self._compose_model_kwargs(**(model_kwargs or {}))
        # step 2: convert the input to the api_kwargs
        api_kwargs = self.model_client.convert_inputs_to_api_kwargs(
            input=input,
            model_kwargs=composed_model_kwargs,
            model_type=self.model_type,
        )
        log.debug(f"api_kwargs: {api_kwargs}")
        return api_kwargs

    def _post_call(self, response: Any) -> EmbedderOutputType:
        r"""Get float list response and process it with output_processor."""
        try:
            embedding_output: EmbedderOutputType = (
                self.model_client.parse_embedding_response(response)
            )
        except Exception as e:
            log.error(f"Error parsing the embedding {response}: {e}")
            return EmbedderOutput(raw_response=str(response), error=str(e))

        output: EmbedderOutputType = EmbedderOutputType(raw_response=embedding_output)
        if self.output_processors:
            try:
                # The processors run on the parsed EmbedderOutput; their result
                # becomes the final ``data`` field.
                processed = self.output_processors(embedding_output)
                output.data = processed
            except Exception as e:
                log.error(f"Error processing the output: {e}")
                output.error = str(e)
        else:
            output.data = embedding_output.data
        return output

    def call(
        self,
        input: EmbedderInputType,
        model_kwargs: Optional[Dict] = None,
    ) -> EmbedderOutputType:
        r"""Embed ``input`` synchronously via the model client.

        Args:
            input: A single string or a list of strings to embed.
            model_kwargs: Per-call overrides merged over the default model kwargs.

        Returns:
            EmbedderOutputType: Parsed (and optionally post-processed) embeddings;
            on failure, an ``EmbedderOutput`` carrying the ``error`` message.
        """
        log.debug(f"Calling {self.__class__.__name__} with input: {input}")
        api_kwargs = self._pre_call(input=input, model_kwargs=model_kwargs)
        output: EmbedderOutputType = None
        response = None
        try:
            response = self.model_client.call(
                api_kwargs=api_kwargs, model_type=self.model_type
            )
        except Exception as e:
            log.error(f"Error calling the model: {e}")
            output = EmbedderOutput(error=str(e))

        if response:
            try:
                output = self._post_call(response)
            except Exception as e:
                log.error(f"Error processing output: {e}")
                output = EmbedderOutput(raw_response=str(response), error=str(e))

        # Guard: a falsy-but-non-raising response previously left ``output`` as None,
        # crashing with AttributeError on the assignment below.
        if output is None:
            output = EmbedderOutput(
                raw_response=str(response), error="Empty response from the model client"
            )

        # add back the input
        output.input = [input] if isinstance(input, str) else input
        log.debug(f"Output from {self.__class__.__name__}: {output}")
        return output

    async def acall(
        self,
        input: EmbedderInputType,
        model_kwargs: Optional[Dict] = None,
    ) -> EmbedderOutputType:
        r"""Embed ``input`` asynchronously via the model client.

        Mirrors :meth:`call` exactly, but awaits ``model_client.acall``.
        """
        log.debug(f"Calling {self.__class__.__name__} with input: {input}")
        api_kwargs = self._pre_call(input=input, model_kwargs=model_kwargs)
        output: EmbedderOutputType = None
        response = None
        try:
            response = await self.model_client.acall(
                api_kwargs=api_kwargs, model_type=self.model_type
            )
        except Exception as e:
            log.error(f"Error calling the model: {e}")
            output = EmbedderOutput(error=str(e))

        if response:
            try:
                output = self._post_call(response)
            except Exception as e:
                log.error(f"Error processing output: {e}")
                output = EmbedderOutput(raw_response=str(response), error=str(e))

        # Same guard as in ``call``: never leave ``output`` as None.
        if output is None:
            output = EmbedderOutput(
                raw_response=str(response), error="Empty response from the model client"
            )

        # add back the input
        output.input = [input] if isinstance(input, str) else input
        log.debug(f"Output from {self.__class__.__name__}: {output}")
        return output

    def _extra_repr(self) -> str:
        s = f"model_kwargs={self.model_kwargs}, "
        return s
class BatchEmbedder(Component):
    __doc__ = r"""Adds batching to the embedder component.

    Args:
        embedder (Embedder): The embedder to use for batching.
        batch_size (int, optional): The batch size to use for batching. Defaults to 100.
    """

    def __init__(self, embedder: Embedder, batch_size: int = 100) -> None:
        super().__init__(batch_size=batch_size)
        self.embedder = embedder
        self.batch_size = batch_size

    def call(
        self, input: BatchEmbedderInputType, model_kwargs: Optional[Dict] = None
    ) -> BatchEmbedderOutputType:
        r"""Call the embedder with batching.

        Args:
            input (BatchEmbedderInputType): The input to the embedder. Use this when you have a
                large input that needs to be batched. Also ensure the output can fit into memory.
            model_kwargs (Optional[Dict], optional): The model kwargs to pass to the embedder.
                Defaults to None (treated as ``{}``).

        Returns:
            BatchEmbedderOutputType: The output from the embedder, one ``EmbedderOutput`` per batch.
        """
        # Mutable default ``{}`` replaced with None to avoid shared state across calls.
        model_kwargs = {} if model_kwargs is None else model_kwargs
        if isinstance(input, str):
            input = [input]
        n = len(input)
        embeddings: List[EmbedderOutputType] = []
        for i in tqdm(
            range(0, n, self.batch_size),
            desc="Batch embedding documents",
        ):
            batch_input = input[i : i + self.batch_size]
            batch_output = self.embedder.call(
                input=batch_input, model_kwargs=model_kwargs
            )
            embeddings.append(batch_output)
        return embeddings