"""Anthropic ModelClient integration."""
import os
from typing import Dict, Optional, Any, Callable
import backoff
import logging
from adalflow.core.model_client import ModelClient
from adalflow.core.types import ModelType, CompletionUsage, GeneratorOutput
# optional import
from adalflow.utils.lazy_import import safe_import, OptionalPackages
anthropic = safe_import(
OptionalPackages.ANTHROPIC.value[0], OptionalPackages.ANTHROPIC.value[1]
)
import anthropic
from anthropic import (
RateLimitError,
APITimeoutError,
InternalServerError,
UnprocessableEntityError,
BadRequestError,
)
from anthropic.types import Message, Usage
# Logger shared by everything in this module, named after the module itself.
log: logging.Logger = logging.getLogger(__name__)
def get_first_message_content(completion: Message) -> str:
    r"""Return the text of the first text content block in ``completion``.

    This is the default ``chat_completion_parser`` of :class:`AnthropicAPIClient`.

    Content blocks without a string ``text`` attribute (for example tool-use
    blocks) are skipped, so a response that begins with a non-text block still
    yields its text.

    Args:
        completion: An Anthropic ``Message`` whose ``content`` is a sequence of
            content blocks.

    Returns:
        The ``text`` of the first block that carries text.

    Raises:
        ValueError: If ``completion.content`` contains no text block.
    """
    for block in completion.content:
        text = getattr(block, "text", None)
        if isinstance(text, str):
            return text
    raise ValueError("Completion contains no text content block")
# Public API exported by ``from <module> import *``.
__all__ = [
    "AnthropicAPIClient",
    "get_first_message_content",
]
# NOTE: accepting a custom parser callable may make new_component more complex, since it then has to handle a callable.
class AnthropicAPIClient(ModelClient):
    r"""A component wrapper for the Anthropic API client.

    Visit https://docs.anthropic.com/en/docs/intro-to-claude for more API details.

    Ensure ``max_tokens`` is set in ``api_kwargs``; the Anthropic Messages API
    requires it.

    References (as of 2024-08-01):

    - https://docs.anthropic.com/en/docs/about-claude/models
    """

    # Exceptions on which ``call``/``acall`` retry with exponential backoff,
    # for at most ``max_time=5`` seconds in total.
    # NOTE(review): BadRequestError (400) and UnprocessableEntityError (422) are
    # client-side errors that will most likely fail identically on every retry;
    # they are kept for backward compatibility, but consider dropping them.
    _RETRYABLE_ERRORS = (
        APITimeoutError,
        InternalServerError,
        RateLimitError,
        UnprocessableEntityError,
        BadRequestError,
    )

    def __init__(
        self,
        api_key: Optional[str] = None,
        chat_completion_parser: Optional[Callable[[Message], Any]] = None,
    ):
        r"""Create the wrapper and eagerly build the synchronous SDK client.

        It is recommended to set the ANTHROPIC_API_KEY environment variable
        instead of passing it as an argument.

        Args:
            api_key: Explicit API key. When falsy, ANTHROPIC_API_KEY is used.
            chat_completion_parser: Callable that extracts the response text
                from a ``Message``. Defaults to ``get_first_message_content``.

        Raises:
            ValueError: If no API key is available.
        """
        super().__init__()
        self._api_key = api_key
        self.sync_client = self.init_sync_client()
        self.async_client = None  # only initialized when acall is first used
        self.tested_llm_models = ["claude-3-opus-20240229"]
        self.chat_completion_parser = (
            chat_completion_parser or get_first_message_content
        )

    def _resolve_api_key(self) -> str:
        r"""Return the explicit API key, falling back to ANTHROPIC_API_KEY.

        Raises:
            ValueError: If neither source provides a non-empty key.
        """
        api_key = self._api_key or os.getenv("ANTHROPIC_API_KEY")
        if not api_key:
            raise ValueError("Environment variable ANTHROPIC_API_KEY must be set")
        return api_key

    def init_sync_client(self):
        r"""Build a synchronous ``anthropic.Anthropic`` client.

        Raises:
            ValueError: If no API key is available.
        """
        return anthropic.Anthropic(api_key=self._resolve_api_key())

    def init_async_client(self):
        r"""Build an asynchronous ``anthropic.AsyncAnthropic`` client.

        Raises:
            ValueError: If no API key is available.
        """
        return anthropic.AsyncAnthropic(api_key=self._resolve_api_key())

    def parse_chat_completion(self, completion: Message) -> GeneratorOutput:
        r"""Wrap ``completion`` into a :class:`GeneratorOutput`.

        The response text is extracted with ``self.chat_completion_parser``
        and stored in ``raw_response``, together with token usage. ``data`` is
        left as ``None``; presumably downstream output processors populate it
        from ``raw_response`` (confirm against the Generator).

        Any exception raised while parsing is logged. It is then returned as
        ``GeneratorOutput.error``, with ``str(completion)`` as ``raw_response``,
        instead of being raised.
        """
        log.debug("completion: %s", completion)
        try:
            text = self.chat_completion_parser(completion)
            usage = self.track_completion_usage(completion)
            return GeneratorOutput(data=None, usage=usage, raw_response=text)
        except Exception as e:
            log.error(f"Error parsing completion: {e}")
            return GeneratorOutput(
                data=None, error=str(e), raw_response=str(completion)
            )

    def track_completion_usage(self, completion: Message) -> CompletionUsage:
        r"""Convert Anthropic ``Usage`` into adalflow's :class:`CompletionUsage`.

        ``total_tokens`` is the sum of input and output tokens.
        """
        usage: Usage = completion.usage
        return CompletionUsage(
            completion_tokens=usage.output_tokens,
            prompt_tokens=usage.input_tokens,
            total_tokens=usage.output_tokens + usage.input_tokens,
        )

    # TODO: consider using <SYS></SYS> tags to separate the system and user
    # messages. This requires callers to follow that convention; when the tags
    # are absent, only a user message would be sent.

    @backoff.on_exception(backoff.expo, _RETRYABLE_ERRORS, max_time=5)
    def call(
        self,
        api_kwargs: Optional[Dict] = None,
        model_type: ModelType = ModelType.UNDEFINED,
    ):
        r"""Synchronously send a request to the Anthropic Messages API.

        Args:
            api_kwargs: The combined input and model kwargs, forwarded verbatim
                to ``messages.create``. ``None`` is treated as an empty dict.
            model_type: Only ``ModelType.LLM`` is supported.

        Returns:
            Whatever ``self.sync_client.messages.create`` returns.

        Raises:
            ValueError: If ``model_type`` is not ``ModelType.LLM``.
        """
        # Avoid a shared mutable default argument.
        api_kwargs = api_kwargs or {}
        if model_type == ModelType.EMBEDDER:
            raise ValueError(f"Model type {model_type} not supported")
        if model_type != ModelType.LLM:
            raise ValueError(f"model_type {model_type} is not supported")
        return self.sync_client.messages.create(**api_kwargs)

    @backoff.on_exception(backoff.expo, _RETRYABLE_ERRORS, max_time=5)
    async def acall(
        self,
        api_kwargs: Optional[Dict] = None,
        model_type: ModelType = ModelType.UNDEFINED,
    ):
        r"""Asynchronous counterpart of :meth:`call`.

        The async SDK client is created lazily on the first invocation.

        Args:
            api_kwargs: The combined input and model kwargs, forwarded verbatim
                to ``messages.create``. ``None`` is treated as an empty dict.
            model_type: Only ``ModelType.LLM`` is supported.

        Returns:
            The awaited result of ``self.async_client.messages.create``.

        Raises:
            ValueError: If ``model_type`` is not ``ModelType.LLM``, or if no
                API key is available when the async client is first created.
        """
        if self.async_client is None:
            self.async_client = self.init_async_client()
        # Avoid a shared mutable default argument.
        api_kwargs = api_kwargs or {}
        if model_type == ModelType.EMBEDDER:
            raise ValueError(f"Model type {model_type} not supported")
        if model_type != ModelType.LLM:
            raise ValueError(f"model_type {model_type} is not supported")
        return await self.async_client.messages.create(**api_kwargs)