Source code for components.model_client.openai_client

"""OpenAI ModelClient integration."""

import os
import base64
from typing import (
    Dict,
    Sequence,
    Optional,
    List,
    Any,
    TypeVar,
    Callable,
    Generator,
    Union,
    Literal,
)
import re

import logging
import backoff

# optional import
from adalflow.utils.lazy_import import safe_import, OptionalPackages


openai = safe_import(OptionalPackages.OPENAI.value[0], OptionalPackages.OPENAI.value[1])

from openai import OpenAI, AsyncOpenAI, Stream
from openai import (
    APITimeoutError,
    InternalServerError,
    RateLimitError,
    UnprocessableEntityError,
    BadRequestError,
)
from openai.types import (
    Completion,
    CreateEmbeddingResponse,
    Image,
)
from openai.types.chat import ChatCompletionChunk, ChatCompletion

from adalflow.core.model_client import ModelClient
from adalflow.core.types import (
    ModelType,
    EmbedderOutput,
    TokenLogProb,
    CompletionUsage,
    GeneratorOutput,
)
from adalflow.components.model_client.utils import parse_embedding_response

log = logging.getLogger(__name__)
T = TypeVar("T")


# Completion parsing functions; they can be combined into a single chat completion parser.
def get_first_message_content(completion: ChatCompletion) -> str:
    r"""When we only need the content of the first message. It is the default parser for chat completion."""
    log.debug(f"raw completion: {completion}")
    return completion.choices[0].message.content


# def _get_chat_completion_usage(completion: ChatCompletion) -> OpenAICompletionUsage:
#     return completion.usage


# A simple heuristic to estimate the number of tokens in a streaming response
def estimate_token_count(text: str) -> int:
    """Estimate the token count of a given text.

    Args:
        text (str): The text to estimate token count for.

    Returns:
        int: Estimated token count.
    """
    # Split the text into tokens using spaces as a simple heuristic
    tokens = text.split()

    # Return the number of tokens
    return len(tokens)


def parse_stream_response(completion: ChatCompletionChunk) -> str:
    r"""Parse the response of the stream API."""
    return completion.choices[0].delta.content


def handle_streaming_response(generator: Stream[ChatCompletionChunk]):
    r"""Handle the streaming response."""
    for completion in generator:
        log.debug(f"Raw chunk completion: {completion}")
        parsed_content = parse_stream_response(completion)
        yield parsed_content


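# A hypothetical streaming sketch (not part of the original module): when ``stream`` is
# True in ``model_kwargs``, ``OpenAIClient.call`` below swaps its parser to
# ``handle_streaming_response``, so the parsed result is a generator of text chunks held
# in ``GeneratorOutput.raw_response``. The model name and prompt are placeholders, and the
# final chunk's content may be None.
#
#   gen = Generator(
#       model_client=OpenAIClient(),
#       model_kwargs={"model": "gpt-3.5-turbo", "stream": True},
#   )
#   output = gen({"input_str": "Tell me a short story."})
#   for chunk in output.raw_response:
#       if chunk is not None:
#           print(chunk, end="", flush=True)

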
def get_all_messages_content(completion: ChatCompletion) -> List[str]:
    r"""When n > 1, get the content of all messages."""
    return [c.message.content for c in completion.choices]


def get_probabilities(completion: ChatCompletion) -> List[List[TokenLogProb]]:
    r"""Get the probabilities of each token in the completion."""
    log_probs = []
    for c in completion.choices:
        content = c.logprobs.content
        log.debug(content)
        log_probs_for_choice = []
        for openai_token_logprob in content:
            token = openai_token_logprob.token
            logprob = openai_token_logprob.logprob
            log_probs_for_choice.append(TokenLogProb(token=token, logprob=logprob))
        log_probs.append(log_probs_for_choice)
    return log_probs


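# A minimal sketch (not part of the original module) of swapping in a different parser:
# any of the functions above can be passed as ``chat_completion_parser`` to the
# ``OpenAIClient`` defined below, replacing the default ``get_first_message_content``.
# The model name and ``n`` value are illustrative.
#
#   client = OpenAIClient(chat_completion_parser=get_all_messages_content)
#   gen = Generator(
#       model_client=client,
#       model_kwargs={"model": "gpt-3.5-turbo", "n": 2},
#   )
#   # The parser then returns a list with one string per returned choice.

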
class OpenAIClient(ModelClient):
    __doc__ = r"""A component wrapper for the OpenAI API client.

    Supports both embedding and chat completion APIs, including multimodal capabilities.

    Users can:
    1. Simplify the use of ``Embedder`` and ``Generator`` components by passing `OpenAIClient()` as the `model_client`.
    2. Use this as a reference to create their own API client, or extend this class by copying and modifying the code.

    Note:
        We recommend avoiding `response_format` to enforce output data type, or `tools` and `tool_choice` in `model_kwargs` when calling the API.
        OpenAI's internal formatting and added prompts are unknown. Instead:
        - Use :ref:`OutputParser<components-output_parsers>` for response parsing and formatting.

    For multimodal inputs, provide images in `model_kwargs["images"]` as a path, URL, or list of them.
    The model must support vision capabilities (e.g., `gpt-4o`, `gpt-4o-mini`, `o1`, `o1-mini`).
    A usage sketch is included after this class definition.

    For image generation, use `model_type=ModelType.IMAGE_GENERATION` and provide:

    - model: `"dall-e-3"` or `"dall-e-2"`
    - prompt: Text description of the image to generate
    - size: `"1024x1024"`, `"1024x1792"`, or `"1792x1024"` for DALL-E 3; `"256x256"`, `"512x512"`, or `"1024x1024"` for DALL-E 2
    - quality: `"standard"` or `"hd"` (DALL-E 3 only)
    - n: Number of images to generate (1 for DALL-E 3, 1-10 for DALL-E 2)
    - response_format: `"url"` or `"b64_json"`

    Args:
        api_key (Optional[str], optional): OpenAI API key. Defaults to `None`.
        chat_completion_parser (Callable[[Completion], Any], optional): A function to parse the chat completion into a `str`. Defaults to `None`.
            The default parser is `get_first_message_content`.
        input_type (Literal["text", "messages"]): The format in which prompts are passed to the chat API. Defaults to `"text"`.
        base_url (str): The API base URL to use when initializing the client.
            Defaults to `"https://api.openai.com/v1/"`, but can be customized for third-party API providers or self-hosted models.
        env_api_key_name (str): The environment variable name for the API key. Defaults to `"OPENAI_API_KEY"`.

    References:
        - OpenAI API Overview: https://platform.openai.com/docs/introduction
        - Embeddings Guide: https://platform.openai.com/docs/guides/embeddings
        - Chat Completion Models: https://platform.openai.com/docs/guides/text-generation
        - Vision Models: https://platform.openai.com/docs/guides/vision
        - Image Generation: https://platform.openai.com/docs/guides/images
    """

    def __init__(
        self,
        api_key: Optional[str] = None,
        chat_completion_parser: Callable[[Completion], Any] = None,
        input_type: Literal["text", "messages"] = "text",
        base_url: str = "https://api.openai.com/v1/",
        env_api_key_name: str = "OPENAI_API_KEY",
    ):
        r"""It is recommended to set the OPENAI_API_KEY environment variable instead of passing it as an argument.

        Args:
            api_key (Optional[str], optional): OpenAI API key. Defaults to None.
            base_url (str): The API base URL to use when initializing the client.
            env_api_key_name (str): The environment variable name for the API key. Defaults to `"OPENAI_API_KEY"`.
        """
        super().__init__()
        self._api_key = api_key
        self._env_api_key_name = env_api_key_name
        self.base_url = base_url
        self.sync_client = self.init_sync_client()
        self.async_client = None  # only initialize if the async call is called
        self.chat_completion_parser = (
            chat_completion_parser or get_first_message_content
        )
        self._input_type = input_type
        self._api_kwargs = {}  # add api kwargs when the OpenAI Client is called

    def init_sync_client(self):
        api_key = self._api_key or os.getenv(self._env_api_key_name)
        if not api_key:
            raise ValueError(
                f"Environment variable {self._env_api_key_name} must be set"
            )
        return OpenAI(api_key=api_key, base_url=self.base_url)

    def init_async_client(self):
        api_key = self._api_key or os.getenv(self._env_api_key_name)
        if not api_key:
            raise ValueError(
                f"Environment variable {self._env_api_key_name} must be set"
            )
        return AsyncOpenAI(api_key=api_key, base_url=self.base_url)

    # def _parse_chat_completion(self, completion: ChatCompletion) -> "GeneratorOutput":
    #     # TODO: raw output - it is better to save the whole completion as a source of truth instead of just the message
    #     try:
    #         data = self.chat_completion_parser(completion)
    #         usage = self.track_completion_usage(completion)
    #         return GeneratorOutput(
    #             data=data, error=None, raw_response=str(data), usage=usage
    #         )
    #     except Exception as e:
    #         log.error(f"Error parsing the completion: {e}")
    #         return GeneratorOutput(data=None, error=str(e), raw_response=completion)

    def parse_chat_completion(
        self,
        completion: Union[ChatCompletion, Generator[ChatCompletionChunk, None, None]],
    ) -> "GeneratorOutput":
        """Parse the completion, and put it into the raw_response."""
        log.debug(f"completion: {completion}, parser: {self.chat_completion_parser}")
        try:
            data = self.chat_completion_parser(completion)
        except Exception as e:
            log.error(f"Error parsing the completion: {e}")
            return GeneratorOutput(data=None, error=str(e), raw_response=completion)

        try:
            usage = self.track_completion_usage(completion)
            return GeneratorOutput(
                data=None, error=None, raw_response=data, usage=usage
            )
        except Exception as e:
            log.error(f"Error tracking the completion usage: {e}")
            return GeneratorOutput(data=None, error=str(e), raw_response=data)

    def track_completion_usage(
        self,
        completion: Union[ChatCompletion, Generator[ChatCompletionChunk, None, None]],
    ) -> CompletionUsage:
        try:
            usage: CompletionUsage = CompletionUsage(
                completion_tokens=completion.usage.completion_tokens,
                prompt_tokens=completion.usage.prompt_tokens,
                total_tokens=completion.usage.total_tokens,
            )
            return usage
        except Exception as e:
            log.error(f"Error tracking the completion usage: {e}")
            return CompletionUsage(
                completion_tokens=None, prompt_tokens=None, total_tokens=None
            )

    def parse_embedding_response(
        self, response: CreateEmbeddingResponse
    ) -> EmbedderOutput:
        r"""Parse the embedding response to a structure Adalflow components can understand.

        Should be called in ``Embedder``.
        """
        try:
            return parse_embedding_response(response)
        except Exception as e:
            log.error(f"Error parsing the embedding response: {e}")
            return EmbedderOutput(data=[], error=str(e), raw_response=response)

    def convert_inputs_to_api_kwargs(
        self,
        input: Optional[Any] = None,
        model_kwargs: Dict = {},
        model_type: ModelType = ModelType.UNDEFINED,
    ) -> Dict:
        r"""
        Specify the API input type and output api_kwargs that will be used in _call and _acall methods.
        Convert the Component's standard input, and system_input (chat model) and model_kwargs into API-specific format.
        For multimodal inputs, images can be provided in model_kwargs["images"] as a string path, URL, or list of them.
        The model specified in model_kwargs["model"] must support multimodal capabilities when using images.

        Args:
            input: The input text or messages to process
            model_kwargs: Additional parameters including:
                - images: Optional image source(s) as path, URL, or list of them
                - detail: Image detail level ('auto', 'low', or 'high'), defaults to 'auto'
                - model: The model to use (must support multimodal inputs if images are provided)
            model_type: The type of model (EMBEDDER or LLM)

        Returns:
            Dict: API-specific kwargs for the model call
        """
        final_model_kwargs = model_kwargs.copy()
        if model_type == ModelType.EMBEDDER:
            if isinstance(input, str):
                input = [input]  # convert a single string into a list of inputs
            if not isinstance(input, Sequence):
                raise TypeError("input must be a sequence of text")
            final_model_kwargs["input"] = input
        elif model_type == ModelType.LLM:
            # convert input to messages
            messages: List[Dict[str, str]] = []
            images = final_model_kwargs.pop("images", None)
            detail = final_model_kwargs.pop("detail", "auto")

            if self._input_type == "messages":
                system_start_tag = "<START_OF_SYSTEM_PROMPT>"
                system_end_tag = "<END_OF_SYSTEM_PROMPT>"
                user_start_tag = "<START_OF_USER_PROMPT>"
                user_end_tag = "<END_OF_USER_PROMPT>"

                # regex pattern that tolerates special characters such as \n around the tags
                pattern = (
                    rf"{system_start_tag}\s*(.*?)\s*{system_end_tag}\s*"
                    rf"{user_start_tag}\s*(.*?)\s*{user_end_tag}"
                )

                # Compile the regular expression; re.DOTALL allows '.' to match newlines,
                # so (.*?) is not restricted to a single line.
                regex = re.compile(pattern, re.DOTALL)
                # Match the pattern
                match = regex.match(input)
                system_prompt, input_str = None, None

                if match:
                    system_prompt = match.group(1)
                    input_str = match.group(2)
                else:
                    log.debug("No match found.")
                if system_prompt and input_str:
                    messages.append({"role": "system", "content": system_prompt})
                    if images:
                        content = [{"type": "text", "text": input_str}]
                        if isinstance(images, (str, dict)):
                            images = [images]
                        for img in images:
                            content.append(self._prepare_image_content(img, detail))
                        messages.append({"role": "user", "content": content})
                    else:
                        messages.append({"role": "user", "content": input_str})
            if len(messages) == 0:
                if images:
                    content = [{"type": "text", "text": input}]
                    if isinstance(images, (str, dict)):
                        images = [images]
                    for img in images:
                        content.append(self._prepare_image_content(img, detail))
                    messages.append({"role": "user", "content": content})
                else:
                    messages.append({"role": "system", "content": input})
            final_model_kwargs["messages"] = messages
        elif model_type == ModelType.IMAGE_GENERATION:
            # For image generation, input is the prompt
            final_model_kwargs["prompt"] = input
            # Ensure model is specified
            if "model" not in final_model_kwargs:
                raise ValueError("model must be specified for image generation")
            # Set defaults for DALL-E 3 if not specified
            final_model_kwargs["size"] = final_model_kwargs.get("size", "1024x1024")
            final_model_kwargs["quality"] = final_model_kwargs.get(
                "quality", "standard"
            )
            final_model_kwargs["n"] = final_model_kwargs.get("n", 1)
            final_model_kwargs["response_format"] = final_model_kwargs.get(
                "response_format", "url"
            )

            # Handle image edits and variations
            image = final_model_kwargs.get("image")
            if isinstance(image, str) and os.path.isfile(image):
                final_model_kwargs["image"] = self._encode_image(image)

            mask = final_model_kwargs.get("mask")
            if isinstance(mask, str) and os.path.isfile(mask):
                final_model_kwargs["mask"] = self._encode_image(mask)
        else:
            raise ValueError(f"model_type {model_type} is not supported")
        return final_model_kwargs

    def parse_image_generation_response(self, response: List[Image]) -> GeneratorOutput:
        """Parse the image generation response into a GeneratorOutput."""
        try:
            # Extract URLs or base64 data from the response
            data = [img.url or img.b64_json for img in response]
            # For single image responses, unwrap from list
            if len(data) == 1:
                data = data[0]
            return GeneratorOutput(
                data=data,
                raw_response=str(response),
            )
        except Exception as e:
            log.error(f"Error parsing image generation response: {e}")
            return GeneratorOutput(data=None, error=str(e), raw_response=str(response))

    @backoff.on_exception(
        backoff.expo,
        (
            APITimeoutError,
            InternalServerError,
            RateLimitError,
            UnprocessableEntityError,
            BadRequestError,
        ),
        max_time=5,
    )
    def call(self, api_kwargs: Dict = {}, model_type: ModelType = ModelType.UNDEFINED):
        """
        api_kwargs is the combined input and model_kwargs. Supports streaming calls.
        """
        log.info(f"api_kwargs: {api_kwargs}")
        self._api_kwargs = api_kwargs
        if model_type == ModelType.EMBEDDER:
            return self.sync_client.embeddings.create(**api_kwargs)
        elif model_type == ModelType.LLM:
            if "stream" in api_kwargs and api_kwargs.get("stream", False):
                log.debug("streaming call")
                self.chat_completion_parser = handle_streaming_response
                return self.sync_client.chat.completions.create(**api_kwargs)
            return self.sync_client.chat.completions.create(**api_kwargs)
        elif model_type == ModelType.IMAGE_GENERATION:
            # Determine which image API to call based on the presence of image/mask
            if "image" in api_kwargs:
                if "mask" in api_kwargs:
                    # Image edit
                    response = self.sync_client.images.edit(**api_kwargs)
                else:
                    # Image variation
                    response = self.sync_client.images.create_variation(**api_kwargs)
            else:
                # Image generation
                response = self.sync_client.images.generate(**api_kwargs)
            return response.data
        else:
            raise ValueError(f"model_type {model_type} is not supported")

    @backoff.on_exception(
        backoff.expo,
        (
            APITimeoutError,
            InternalServerError,
            RateLimitError,
            UnprocessableEntityError,
            BadRequestError,
        ),
        max_time=5,
    )
    async def acall(
        self, api_kwargs: Dict = {}, model_type: ModelType = ModelType.UNDEFINED
    ):
        """
        api_kwargs is the combined input and model_kwargs.
        """
        # store the api kwargs in the client
        self._api_kwargs = api_kwargs
        if self.async_client is None:
            self.async_client = self.init_async_client()
        if model_type == ModelType.EMBEDDER:
            return await self.async_client.embeddings.create(**api_kwargs)
        elif model_type == ModelType.LLM:
            return await self.async_client.chat.completions.create(**api_kwargs)
        elif model_type == ModelType.IMAGE_GENERATION:
            # Determine which image API to call based on the presence of image/mask
            if "image" in api_kwargs:
                if "mask" in api_kwargs:
                    # Image edit
                    response = await self.async_client.images.edit(**api_kwargs)
                else:
                    # Image variation
                    response = await self.async_client.images.create_variation(
                        **api_kwargs
                    )
            else:
                # Image generation
                response = await self.async_client.images.generate(**api_kwargs)
            return response.data
        else:
            raise ValueError(f"model_type {model_type} is not supported")

    @classmethod
    def from_dict(cls: type[T], data: Dict[str, Any]) -> T:
        obj = super().from_dict(data)
        # recreate the existing clients
        obj.sync_client = obj.init_sync_client()
        obj.async_client = obj.init_async_client()
        return obj

    def to_dict(self) -> Dict[str, Any]:
        r"""Convert the component to a dictionary."""
        # TODO: not exclude but save yes or no for recreating the clients
        exclude = [
            "sync_client",
            "async_client",
        ]  # unserializable object
        output = super().to_dict(exclude=exclude)
        return output

    def _encode_image(self, image_path: str) -> str:
        """Encode image to base64 string.

        Args:
            image_path: Path to image file.

        Returns:
            Base64 encoded image string.

        Raises:
            ValueError: If the file cannot be read or doesn't exist.
        """
        try:
            with open(image_path, "rb") as image_file:
                return base64.b64encode(image_file.read()).decode("utf-8")
        except FileNotFoundError:
            raise ValueError(f"Image file not found: {image_path}")
        except PermissionError:
            raise ValueError(f"Permission denied when reading image file: {image_path}")
        except Exception as e:
            raise ValueError(f"Error encoding image {image_path}: {str(e)}")

    def _prepare_image_content(
        self, image_source: Union[str, Dict[str, Any]], detail: str = "auto"
    ) -> Dict[str, Any]:
        """Prepare image content for API request.

        Args:
            image_source: Either a path to local image or a URL.
            detail: Image detail level ('auto', 'low', or 'high').

        Returns:
            Formatted image content for API request.
        """
        if isinstance(image_source, str):
            if image_source.startswith(("http://", "https://")):
                return {
                    "type": "image_url",
                    "image_url": {"url": image_source, "detail": detail},
                }
            else:
                base64_image = self._encode_image(image_source)
                return {
                    "type": "image_url",
                    "image_url": {
                        "url": f"data:image/jpeg;base64,{base64_image}",
                        "detail": detail,
                    },
                }
        return image_source


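# A hypothetical multimodal usage sketch (not part of the original module), assuming a
# vision-capable model and a local image; the model name, image path, and prompt are
# placeholders. Images passed via ``model_kwargs["images"]`` are converted into
# multimodal messages by ``convert_inputs_to_api_kwargs``.
#
#   from adalflow.core import Generator
#
#   multimodal_gen = Generator(
#       model_client=OpenAIClient(),
#       model_kwargs={
#           "model": "gpt-4o-mini",               # must support vision
#           "images": "path/to/local_image.jpg",  # path, URL, or a list of them
#           "detail": "auto",
#       },
#   )
#   print(multimodal_gen(prompt_kwargs={"input_str": "Describe this image."}))

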
# Example usage:
if __name__ == "__main__":
    from adalflow.core import Generator
    from adalflow.utils import setup_env

    # log = get_logger(level="DEBUG")

    setup_env()
    prompt_kwargs = {"input_str": "What is the meaning of life?"}

    gen = Generator(
        model_client=OpenAIClient(),
        model_kwargs={"model": "gpt-3.5-turbo", "stream": False},
    )
    gen_response = gen(prompt_kwargs)
    print(f"gen_response: {gen_response}")

    # for genout in gen_response.data:
    #     print(f"genout: {genout}")

    # test that to_dict and from_dict work
    # model_client = OpenAIClient()
    # model_client_dict = model_client.to_dict()
    # from_dict_model_client = OpenAIClient.from_dict(model_client_dict)
    # assert model_client_dict == from_dict_model_client.to_dict()


if __name__ == "__main__":
    import adalflow as adal

    # set up the environment or pass the api_key directly
    from adalflow.utils import setup_env

    setup_env()

    openai_llm = adal.Generator(
        model_client=adal.OpenAIClient(), model_kwargs={"model": "gpt-3.5-turbo"}
    )
    response = openai_llm(prompt_kwargs={"input_str": "What is LLM?"})
    print(response)


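# A hypothetical image generation sketch (not part of the original module), assuming
# OPENAI_API_KEY is set; the prompt and size are illustrative. It calls the client
# directly with ``ModelType.IMAGE_GENERATION`` rather than going through ``Generator``.
#
# if __name__ == "__main__":
#     client = OpenAIClient()
#     api_kwargs = client.convert_inputs_to_api_kwargs(
#         input="A watercolor painting of a lighthouse at dusk",
#         model_kwargs={"model": "dall-e-3", "size": "1024x1024", "n": 1},
#         model_type=ModelType.IMAGE_GENERATION,
#     )
#     images = client.call(api_kwargs=api_kwargs, model_type=ModelType.IMAGE_GENERATION)
#     print(client.parse_image_generation_response(images))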