Source code for components.model_client.cohere_client
"""Cohere ModelClient integration."""
import os
from typing import Dict, Optional, Any
import backoff
from adalflow.utils.lazy_import import safe_import, OptionalPackages
cohere = safe_import(OptionalPackages.COHERE.value[0], OptionalPackages.COHERE.value[1])
from cohere import (
BadRequestError,
InternalServerError,
)
from adalflow.core.model_client import ModelClient
from adalflow.core.types import ModelType
[docs]
class CohereAPIClient(ModelClient):
    __doc__ = r"""A component wrapper for the Cohere API.

    Visit https://docs.cohere.com/ for more api details.

    References:
    - Cohere reranker: https://docs.cohere.com/reference/rerank

    Tested Cohere models: 6/16/2024
    - rerank-english-v3.0, rerank-multilingual-v3.0, rerank-english-v2.0, rerank-multilingual-v2.0

    .. note::
        For all ModelClient integration, such as CohereAPIClient, if you want to subclass CohereAPIClient, you need to import it from the module directly.

        ``from adalflow.components.model_client.cohere_client import CohereAPIClient``

        instead of using the lazy import with:

        ``from adalflow.components.model_client import CohereAPIClient``
    """

    def __init__(self, api_key: Optional[str] = None):
        r"""It is recommended to set the COHERE_API_KEY environment variable instead of passing it as an argument.

        The synchronous client is created immediately; the asynchronous client
        is created on the first call to :meth:`acall`.

        Args:
            api_key (Optional[str], optional): Cohere API key. Defaults to None,
                in which case the COHERE_API_KEY environment variable is used.

        Raises:
            ValueError: If no API key is passed and COHERE_API_KEY is not set.
        """
        super().__init__()
        self._api_key = api_key
        self.init_sync_client()
        self.async_client = None  # only initialize if the async call is called

    def _resolve_api_key(self) -> str:
        r"""Return the explicit API key, falling back to the COHERE_API_KEY environment variable.

        Raises:
            ValueError: If neither source provides a non-empty key.
        """
        api_key = self._api_key or os.getenv("COHERE_API_KEY")
        if not api_key:
            raise ValueError("Environment variable COHERE_API_KEY must be set")
        return api_key

    @staticmethod
    def _parse_rerank_response(response):
        r"""Convert a rerank response into ``(top_k_indices, top_k_scores)``.

        Both lists follow the order of ``response.results``.
        """
        top_k_scores = [result.relevance_score for result in response.results]
        top_k_indices = [result.index for result in response.results]
        return top_k_indices, top_k_scores

    def init_sync_client(self):
        r"""Create the synchronous Cohere client and store it on ``self.sync_client``.

        Raises:
            ValueError: If no API key is available.
        """
        self.sync_client = cohere.Client(api_key=self._resolve_api_key())

    def init_async_client(self):
        r"""Create the asynchronous Cohere client and store it on ``self.async_client``.

        Raises:
            ValueError: If no API key is available.
        """
        self.async_client = cohere.AsyncClient(api_key=self._resolve_api_key())

    # Retry with exponential backoff, giving up after 5 seconds in total.
    # NOTE(review): a BadRequestError is usually not transient; confirm that
    # retrying it is intended.
    @backoff.on_exception(
        backoff.expo,
        (
            BadRequestError,
            InternalServerError,
        ),
        max_time=5,
    )
    def call(
        self,
        api_kwargs: Optional[Dict] = None,
        model_type: ModelType = ModelType.UNDEFINED,
    ):
        r"""Run a synchronous Cohere API call.

        Args:
            api_kwargs (Optional[Dict]): Keyword arguments forwarded to the
                Cohere client; must contain ``"model"``. Defaults to None,
                which is treated as an empty dict.
            model_type (ModelType): Only ``ModelType.RERANKER`` is supported.

        Returns:
            tuple: ``(top_k_indices, top_k_scores)`` taken from the rerank results.

        Raises:
            ValueError: If ``"model"`` is missing from ``api_kwargs`` or
                ``model_type`` is not supported.
        """
        # A fresh dict per call avoids sharing one mutable default between calls.
        api_kwargs = {} if api_kwargs is None else api_kwargs
        # Explicit raise instead of ``assert`` so the check survives ``python -O``
        # and matches the exception type used by ``acall``.
        if "model" not in api_kwargs:
            raise ValueError(f"model must be specified in api_kwargs: {api_kwargs}")
        if model_type != ModelType.RERANKER:
            raise ValueError(f"model_type {model_type} is not supported")
        # query -> scores for top_k documents, index for the top_k documents
        response = self.sync_client.rerank(**api_kwargs)
        return self._parse_rerank_response(response)

    # Same retry policy as ``call``.
    @backoff.on_exception(
        backoff.expo,
        (
            BadRequestError,
            InternalServerError,
        ),
        max_time=5,
    )
    async def acall(
        self,
        api_kwargs: Optional[Dict] = None,
        model_type: ModelType = ModelType.UNDEFINED,
    ):
        r"""Run an asynchronous Cohere API call.

        The asynchronous client is created on first use.

        Args:
            api_kwargs (Optional[Dict]): Keyword arguments forwarded to the
                Cohere client; must contain ``"model"``. Defaults to None,
                which is treated as an empty dict.
            model_type (ModelType): Only ``ModelType.RERANKER`` is supported.

        Returns:
            tuple: ``(top_k_indices, top_k_scores)`` taken from the rerank results.

        Raises:
            ValueError: If no API key is available, ``"model"`` is missing from
                ``api_kwargs``, or ``model_type`` is not supported.
        """
        if self.async_client is None:
            self.init_async_client()
        # A fresh dict per call avoids sharing one mutable default between calls.
        api_kwargs = {} if api_kwargs is None else api_kwargs
        if "model" not in api_kwargs:
            raise ValueError("model must be specified")
        if model_type != ModelType.RERANKER:
            raise ValueError(f"model_type {model_type} is not supported")
        response = await self.async_client.rerank(**api_kwargs)
        return self._parse_rerank_response(response)