"""AWS Bedrock ModelClient integration."""
import os
from typing import Dict, Optional, Any, Callable
import backoff
import logging
from adalflow.core.model_client import ModelClient
from adalflow.core.types import ModelType, CompletionUsage, GeneratorOutput
import boto3
from botocore.config import Config
log = logging.getLogger(name=__name__)

# A bedrock-runtime client is built at import time only so that its modeled
# exception classes can be reached (the retry decorator on
# ``BedrockAPIClient.call`` lists several of them). The region defaults to
# us-east-1 when AWS_REGION_NAME is unset.
bedrock_runtime_exceptions = (
    boto3.client(
        region_name=os.getenv("AWS_REGION_NAME", "us-east-1"),
        service_name="bedrock-runtime",
    )
).exceptions
def get_first_message_content(completion: Dict) -> str:
    r"""Return the text of the first text-bearing content block of a response.

    It is the default parser for chat completion. A Bedrock ``converse``
    response holds its output as a list of content blocks under
    ``completion['output']['message']['content']``. Not every block carries a
    ``'text'`` key (e.g. ``toolUse`` blocks), so the first block that does is
    used instead of blindly taking index 0.

    Args:
        completion: Raw response dict returned by ``bedrock-runtime`` ``converse``.

    Returns:
        The ``'text'`` value of the first content block that has one.

    Raises:
        KeyError: if the expected keys are missing, or no block carries text.
        IndexError: if the content list is empty.
    """
    blocks = completion["output"]["message"]["content"]
    for block in blocks:
        if "text" in block:
            return block["text"]
    # No text block found: fall back to the original lookup so callers see
    # the same error types as before (IndexError if empty, KeyError otherwise).
    return blocks[0]["text"]
# Public API of this module (what ``from <module> import *`` exposes).
__all__ = [
    "BedrockAPIClient",
    "get_first_message_content",
    "bedrock_runtime_exceptions",
]
class BedrockAPIClient(ModelClient):
    __doc__ = r"""A component wrapper for the AWS Bedrock runtime API client.

    Only synchronous chat calls (``ModelType.LLM``) through the Bedrock
    ``converse`` API are supported; async calls raise ``NotImplementedError``.

    Visit https://docs.aws.amazon.com/bedrock/latest/APIReference/welcome.html for more api details.

    Args:
        aws_profile_name: AWS profile name; falls back to the ``AWS_PROFILE_NAME`` env var.
        aws_region_name: AWS region; falls back to the ``AWS_REGION_NAME`` env var.
        aws_access_key_id: Access key id; falls back to ``AWS_ACCESS_KEY_ID``.
        aws_secret_access_key: Secret access key; falls back to ``AWS_SECRET_ACCESS_KEY``.
        aws_session_token: Session token; falls back to ``AWS_SESSION_TOKEN``.
        aws_connection_timeout: Connect timeout in seconds. When ``None``,
            botocore's own default is kept.
        aws_read_timeout: Read timeout in seconds. When ``None``, botocore's
            own default is kept.
        chat_completion_parser: Callable that extracts the output text from a
            raw ``converse`` response. Defaults to ``get_first_message_content``.
    """

    def __init__(
        self,
        aws_profile_name: Optional[str] = None,
        aws_region_name: Optional[str] = None,
        aws_access_key_id: Optional[str] = None,
        aws_secret_access_key: Optional[str] = None,
        aws_session_token: Optional[str] = None,
        aws_connection_timeout: Optional[float] = None,
        aws_read_timeout: Optional[float] = None,
        chat_completion_parser: Optional[Callable[[Dict], str]] = None,
    ):
        super().__init__()
        self._aws_profile_name = aws_profile_name
        self._aws_region_name = aws_region_name
        self._aws_access_key_id = aws_access_key_id
        self._aws_secret_access_key = aws_secret_access_key
        self._aws_session_token = aws_session_token
        self._aws_connection_timeout = aws_connection_timeout
        self._aws_read_timeout = aws_read_timeout
        # Populated by init_sync_client() with the boto3 Session it creates.
        self.session = None
        self.sync_client = self.init_sync_client()
        self.chat_completion_parser = (
            chat_completion_parser or get_first_message_content
        )

    def init_sync_client(self):
        """Create the synchronous ``bedrock-runtime`` client.

        There is no need to pass both a profile and an access/secret key pair;
        supply one of them. If the compute environment assumes a role that has
        access to Bedrock, nothing needs to be passed at all.

        Each explicit constructor argument takes precedence over its
        environment variable. The created ``boto3.Session`` is stored on
        ``self.session``.

        Returns:
            The boto3 ``bedrock-runtime`` client.
        """
        aws_profile_name = self._aws_profile_name or os.getenv("AWS_PROFILE_NAME")
        aws_region_name = self._aws_region_name or os.getenv("AWS_REGION_NAME")
        aws_access_key_id = self._aws_access_key_id or os.getenv("AWS_ACCESS_KEY_ID")
        aws_secret_access_key = self._aws_secret_access_key or os.getenv("AWS_SECRET_ACCESS_KEY")
        aws_session_token = self._aws_session_token or os.getenv("AWS_SESSION_TOKEN")

        # Forward only the timeouts that were actually given. Passing ``None``
        # explicitly to botocore's Config overrides its default rather than
        # keeping it, so e.g. a lone read timeout must not clear the connect
        # timeout.
        timeouts = {
            option: seconds
            for option, seconds in (
                ("connect_timeout", self._aws_connection_timeout),
                ("read_timeout", self._aws_read_timeout),
            )
            if seconds is not None
        }
        config = Config(**timeouts) if timeouts else None

        session = boto3.Session(
            profile_name=aws_profile_name,
            region_name=aws_region_name,
            aws_access_key_id=aws_access_key_id,
            aws_secret_access_key=aws_secret_access_key,
            aws_session_token=aws_session_token,
        )
        self.session = session
        return session.client(service_name="bedrock-runtime", config=config)

    def init_async_client(self):
        """Async clients are not supported by this integration."""
        raise NotImplementedError("Async call not implemented yet.")

    def parse_chat_completion(self, completion: Dict) -> GeneratorOutput:
        """Wrap a raw ``converse`` response in a ``GeneratorOutput``.

        The output text is extracted with ``self.chat_completion_parser`` and
        placed in ``raw_response``; ``data`` is left as ``None``. Token usage
        comes from ``track_completion_usage``.

        Any failure while parsing is logged and reported through the
        ``error`` field (with the stringified completion as ``raw_response``)
        instead of being raised.
        """
        log.debug("completion: %s", completion)
        try:
            data = self.chat_completion_parser(completion)
            usage = self.track_completion_usage(completion)
            return GeneratorOutput(data=None, usage=usage, raw_response=data)
        except Exception as e:
            # Deliberate boundary: surface the problem in the output object.
            log.error("Error parsing completion: %s", e)
            return GeneratorOutput(
                data=None, error=str(e), raw_response=str(completion)
            )

    def track_completion_usage(self, completion: Dict) -> CompletionUsage:
        r"""Build a ``CompletionUsage`` from the ``usage`` section of a response.

        Reads ``outputTokens``, ``inputTokens`` and ``totalTokens`` from
        ``completion['usage']``; raises ``KeyError`` if any is missing.
        """
        usage = completion["usage"]
        return CompletionUsage(
            completion_tokens=usage["outputTokens"],
            prompt_tokens=usage["inputTokens"],
            total_tokens=usage["totalTokens"],
        )

    # Retry (exponential backoff, up to 5 s total) only on errors that may be
    # transient. ValidationException is excluded: a malformed request fails
    # the same way on every attempt.
    @backoff.on_exception(
        backoff.expo,
        (
            bedrock_runtime_exceptions.ThrottlingException,
            bedrock_runtime_exceptions.ModelTimeoutException,
            bedrock_runtime_exceptions.InternalServerException,
            bedrock_runtime_exceptions.ModelErrorException,
        ),
        max_time=5,
    )
    def call(
        self,
        api_kwargs: Optional[Dict] = None,
        model_type: ModelType = ModelType.UNDEFINED,
    ):
        """Send a synchronous request to Bedrock.

        Args:
            api_kwargs: The combined input and model_kwargs, forwarded verbatim
                as keyword arguments to ``converse``. ``None`` means no arguments.
            model_type: Must be ``ModelType.LLM``.

        Returns:
            The raw ``converse`` response.

        Raises:
            ValueError: if ``model_type`` is not ``ModelType.LLM``.
        """
        if model_type != ModelType.LLM:
            raise ValueError(f"model_type {model_type} is not supported")
        return self.sync_client.converse(**(api_kwargs or {}))

    async def acall(
        self,
        api_kwargs: Optional[Dict] = None,
        model_type: ModelType = ModelType.UNDEFINED,
    ):
        """Async calls are not supported; always raises ``NotImplementedError``.

        The optional parameters mirror ``call`` so callers using the common
        signature get ``NotImplementedError`` rather than a ``TypeError``.
        """
        raise NotImplementedError("Async call not implemented yet.")