class CohereAPIClient(ModelClient):
    __doc__ = r"""A component wrapper for the Cohere API.

    Visit https://docs.cohere.com/ for more API details.

    References:

    - Cohere reranker: https://docs.cohere.com/reference/rerank

    Tested Cohere models: 6/16/2024

    - rerank-english-v3.0, rerank-multilingual-v3.0, rerank-english-v2.0, rerank-multilingual-v2.0

    Only ``ModelType.RERANKER`` is supported by this client; any other
    ``model_type`` raises ``ValueError``.

    .. note::
        For all ModelClient integration, such as CohereAPIClient, if you want to subclass CohereAPIClient, you need to import it from the module directly.

        ``from adalflow.components.model_client.cohere_client import CohereAPIClient``

        instead of using the lazy import with:

        ``from adalflow.components.model_client import CohereAPIClient``
    """

    def __init__(self, api_key: Optional[str] = None):
        r"""Create the client and eagerly initialize the synchronous Cohere client.

        It is recommended to set the COHERE_API_KEY environment variable instead
        of passing the key as an argument.

        Args:
            api_key (Optional[str], optional): Cohere API key. Defaults to None,
                in which case ``COHERE_API_KEY`` is read from the environment.

        Raises:
            ValueError: If no key is passed and ``COHERE_API_KEY`` is not set.
        """
        super().__init__()
        self._api_key = api_key
        self.init_sync_client()
        # The async client is created lazily, on the first ``acall``.
        self.async_client = None

    def _resolve_api_key(self) -> str:
        """Return the explicit API key, falling back to ``COHERE_API_KEY``.

        Raises:
            ValueError: If neither source provides a non-empty key.
        """
        api_key = self._api_key or os.getenv("COHERE_API_KEY")
        if not api_key:
            raise ValueError("Environment variable COHERE_API_KEY must be set")
        return api_key

    def init_sync_client(self):
        """Create ``self.sync_client`` as a ``cohere.Client``.

        Raises:
            ValueError: If no API key is available.
        """
        self.sync_client = cohere.Client(api_key=self._resolve_api_key())

    def init_async_client(self):
        """Create ``self.async_client`` as a ``cohere.AsyncClient``.

        Raises:
            ValueError: If no API key is available.
        """
        self.async_client = cohere.AsyncClient(api_key=self._resolve_api_key())

    def convert_inputs_to_api_kwargs(
        self,
        input: Optional[Any] = None,  # for the reranker, this is the query string
        model_kwargs: Optional[Dict] = None,
        model_type: ModelType = ModelType.UNDEFINED,
    ) -> Dict:
        r"""Build the keyword arguments for a Cohere API call.

        For ``ModelType.RERANKER``, ``input`` becomes the ``query`` and
        ``model_kwargs`` must contain:

        - ``model`` (str): the rerank model name.
        - ``documents`` (List[str]): the candidate documents to rerank.
        - ``top_k`` (int): number of results to return. It is renamed to the
          Cohere-specific ``top_n`` in the returned dict.

        Args:
            input (Optional[Any]): The rerank query.
            model_kwargs (Optional[Dict]): Model settings as described above.
                The dict is copied, never modified. Defaults to None (no settings).
            model_type (ModelType): Must be ``ModelType.RERANKER``.

        Returns:
            Dict: A new dict with ``query``, ``model``, ``documents`` and
            ``top_n``, plus any other keys passed in ``model_kwargs``.

        Raises:
            ValueError: If ``model_type`` is unsupported, ``input`` is None, or
                a required key is missing from ``model_kwargs``.
        """
        if model_type != ModelType.RERANKER:
            raise ValueError(f"model_type {model_type} is not supported")
        if input is None:
            # Fail fast locally instead of sending ``query=None`` to the API.
            raise ValueError("input (the rerank query) must be specified")

        final_model_kwargs = {} if model_kwargs is None else model_kwargs.copy()
        final_model_kwargs["query"] = input
        for required_key in ("model", "documents", "top_k"):
            if required_key not in final_model_kwargs:
                raise ValueError(f"{required_key} must be specified")

        # Convert the framework-generic ``top_k`` to Cohere's ``top_n``.
        final_model_kwargs["top_n"] = final_model_kwargs.pop("top_k")
        return final_model_kwargs

    @staticmethod
    def _parse_rerank_response(response) -> tuple:
        """Split a rerank response into ``(indices, relevance_scores)``.

        Both lists follow the order of ``response.results`` as returned by the API.
        """
        results = response.results
        top_k_indices = [result.index for result in results]
        top_k_scores = [result.relevance_score for result in results]
        return top_k_indices, top_k_scores

    # NOTE(review): BadRequestError (HTTP 400) is usually not transient, so
    # retrying it may only delay the failure -- confirm this is intended.
    @backoff.on_exception(
        backoff.expo,
        (
            BadRequestError,
            InternalServerError,
        ),
        max_time=5,
    )
    def call(
        self,
        api_kwargs: Optional[Dict] = None,
        model_type: ModelType = ModelType.UNDEFINED,
    ):
        r"""Synchronously call the Cohere rerank endpoint.

        Args:
            api_kwargs (Optional[Dict]): Keyword arguments for
                ``sync_client.rerank``, typically the output of
                ``convert_inputs_to_api_kwargs``. Must contain ``model``.
            model_type (ModelType): Must be ``ModelType.RERANKER``.

        Returns:
            tuple: ``(top_k_indices, top_k_scores)``, the document indices and
            relevance scores in the order the API returned them.

        Raises:
            ValueError: If ``model`` is missing or ``model_type`` is unsupported.
        """
        api_kwargs = {} if api_kwargs is None else api_kwargs
        # Raise rather than assert so the check survives ``python -O``.
        if "model" not in api_kwargs:
            raise ValueError(f"model must be specified in api_kwargs: {api_kwargs}")
        if model_type != ModelType.RERANKER:
            raise ValueError(f"model_type {model_type} is not supported")

        response = self.sync_client.rerank(**api_kwargs)
        return self._parse_rerank_response(response)

    # NOTE(review): see ``call`` regarding retries on BadRequestError.
    @backoff.on_exception(
        backoff.expo,
        (
            BadRequestError,
            InternalServerError,
        ),
        max_time=5,
    )
    async def acall(
        self,
        api_kwargs: Optional[Dict] = None,
        model_type: ModelType = ModelType.UNDEFINED,
    ):
        r"""Asynchronously call the Cohere rerank endpoint.

        The async client is created on first use.

        Args:
            api_kwargs (Optional[Dict]): Keyword arguments for
                ``async_client.rerank``. Must contain ``model``.
            model_type (ModelType): Must be ``ModelType.RERANKER``.

        Returns:
            tuple: ``(top_k_indices, top_k_scores)``, in API response order.

        Raises:
            ValueError: If ``model`` is missing, ``model_type`` is unsupported,
                or no API key is available when creating the async client.
        """
        api_kwargs = {} if api_kwargs is None else api_kwargs
        if "model" not in api_kwargs:
            raise ValueError("model must be specified")
        if model_type != ModelType.RERANKER:
            raise ValueError(f"model_type {model_type} is not supported")

        if self.async_client is None:
            self.init_async_client()
        response = await self.async_client.rerank(**api_kwargs)
        return self._parse_rerank_response(response)