"""Groq ModelClient integration."""
import os
from typing import Dict, Sequence, Optional, Any, TypeVar
import backoff
import logging
from adalflow.core.model_client import ModelClient
from adalflow.core.types import ModelType, CompletionUsage, GeneratorOutput
from adalflow.utils.lazy_import import safe_import, OptionalPackages
# optional import
groq = safe_import(OptionalPackages.GROQ.value[0], OptionalPackages.GROQ.value[1])
from groq import Groq, AsyncGroq
from groq import (
APITimeoutError,
InternalServerError,
RateLimitError,
UnprocessableEntityError,
)
from groq.types.chat import ChatCompletion as GroqChatCompletion
from groq.types import CompletionUsage as GroqCompletionUsage
T = TypeVar("T")
log = logging.getLogger(__name__)
[docs]
class GroqAPIClient(ModelClient):
__doc__ = r"""A component wrapper for the Groq API client.
Visit https://console.groq.com/docs/ for more api details.
Check https://console.groq.com/docs/models for the available models.
Tested Groq models: 4/22/2024
- llama3-8b-8192
- llama3-70b-8192
- mixtral-8x7b-32768
- gemma-7b-it
"""
def __init__(self, api_key: Optional[str] = None):
r"""It is recommended to set the GROQ_API_KEY environment variable instead of passing it as an argument.
Args:
api_key (Optional[str], optional): Groq API key. Defaults to None.
"""
super().__init__()
self._api_key = api_key
self.sync_client = self.init_sync_client()
self.async_client = None # only initialize if the async call is called
[docs]
def init_sync_client(self):
api_key = self._api_key or os.getenv("GROQ_API_KEY")
if not api_key:
raise ValueError("Environment variable GROQ_API_KEY must be set")
return Groq(api_key=api_key)
[docs]
def init_async_client(self):
api_key = self._api_key or os.getenv("GROQ_API_KEY")
if not api_key:
raise ValueError("Environment variable GROQ_API_KEY must be set")
return AsyncGroq(api_key=api_key)
[docs]
def parse_chat_completion(
self, completion: "GroqChatCompletion"
) -> "GeneratorOutput":
"""
Parse the completion to a string output.
"""
log.debug(f"completion: {completion}")
try:
data = completion.choices[0].message.content
usage = self.track_completion_usage(completion)
return GeneratorOutput(data=None, usage=usage, raw_response=data)
except Exception as e:
log.error(f"Error parsing completion: {e}")
return GeneratorOutput(data=str(e), usage=None, raw_response=completion)
[docs]
def track_completion_usage(self, completion: Any) -> CompletionUsage:
usage: GroqCompletionUsage = completion.usage
return CompletionUsage(
completion_tokens=usage.completion_tokens,
prompt_tokens=usage.prompt_tokens,
total_tokens=usage.total_tokens,
)
[docs]
@backoff.on_exception(
backoff.expo,
(
APITimeoutError,
InternalServerError,
RateLimitError,
UnprocessableEntityError,
),
max_time=5,
)
def call(self, api_kwargs: Dict = {}, model_type: ModelType = ModelType.UNDEFINED):
assert (
"model" in api_kwargs
), f"model must be specified in api_kwargs: {api_kwargs}"
if model_type == ModelType.LLM:
completion = self.sync_client.chat.completions.create(**api_kwargs)
return completion
else:
raise ValueError(f"model_type {model_type} is not supported")
[docs]
@backoff.on_exception(
backoff.expo,
(
APITimeoutError,
InternalServerError,
RateLimitError,
UnprocessableEntityError,
),
max_time=5,
)
async def acall(
self, api_kwargs: Dict = {}, model_type: ModelType = ModelType.UNDEFINED
):
if self.async_client is None:
self.async_client = self.init_async_client()
assert "model" in api_kwargs, "model must be specified"
if model_type == ModelType.LLM:
completion = await self.async_client.chat.completions.create(**api_kwargs)
return completion
else:
raise ValueError(f"model_type {model_type} is not supported")
[docs]
@classmethod
def from_dict(cls: type[T], data: Dict[str, Any]) -> T:
obj = super().from_dict(data)
# recreate the existing clients
obj.sync_client = obj.init_sync_client()
obj.async_client = obj.init_async_client()
return obj
[docs]
def to_dict(self) -> Dict[str, Any]:
r"""Convert the component to a dictionary."""
# TODO: not exclude but save yes or no for recreating the clients
exclude = [
"sync_client",
"async_client",
] # unserializable object
output = super().to_dict(exclude=exclude)
return output