r"""ModelClient is the protocol and base class for all models(either via APIs or local models) to communicate with components."""fromtypingimportAny,Dict,Optionalfromadalflow.core.componentimportDataComponentfromadalflow.core.typesimport(ModelType,EmbedderOutput,GeneratorOutput,CompletionUsage,)# TODO: global model registry for all available models in users' project.
class ModelClient(DataComponent):
    __doc__ = r"""The protocol and abstract class for all models (either via APIs or local models) to communicate with components.

    ModelClient is to separate the model API calls from the rest of the system,
    making it a plug-and-play component that can be used in functional components
    like Generator and Embedder.

    For a particular API provider, such as OpenAI, we will have a class that
    inherits from ModelClient. It does four things:

    (1) Initialize the client, including both sync and async.

    (2) Convert the standard AdalFlow components inputs to the API-specific format.

    (3) Call the API and parse the response.

    (4) Handle API specific exceptions and errors to retry the call.

    This interface is designed to bridge the gap between AdalFlow components
    inputs and model APIs. Check the subclasses in the `components/model_client/`
    directory for the functional API clients we have.
    """

    def __init__(self, *args, **kwargs) -> None:
        r"""Initialize the base component and reset both client handles to None.

        Subclasses are expected to populate ``self.sync_client`` (typically via
        :meth:`init_sync_client`) and, when needed, ``self.async_client`` (via
        :meth:`init_async_client`).

        Args:
            *args: accepted for subclass convenience; not forwarded to the parent.
            **kwargs: accepted for subclass convenience; not forwarded to the parent.
        """
        super().__init__()
        self.sync_client = None
        self.async_client = None

    def init_sync_client(self):
        r"""Create the synchronous API client. Subclasses must override this.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement init_sync_client method"
        )

    def init_async_client(self):
        r"""Create the asynchronous API client. Subclasses must override this.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement init_async_client method"
        )

    # NOTE(review): the ``{}`` defaults below are shared mutable objects. The base
    # implementations never touch them; overriding implementations should treat
    # them as read-only (or copy them) rather than mutating them in place.
    def call(
        self, api_kwargs: Dict = {}, model_type: ModelType = ModelType.UNDEFINED
    ):
        r"""Subclass use this to call the API with the sync client.

        Additionally in subclass you can implement the error handling and retry
        logic here. See OpenAIClient for example.

        Args:
            api_kwargs (Dict): all the arguments that the API call needs.
            model_type (ModelType): this decides which API, such as
                chat.completions or embeddings for OpenAI.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement call method"
        )

    async def acall(
        self, api_kwargs: Dict = {}, model_type: ModelType = ModelType.UNDEFINED
    ):
        r"""Subclass use this to call the API with the async client.

        Unlike :meth:`call`, the base implementation does not raise: it is a
        no-op, so awaiting it yields ``None`` unless a subclass overrides it.

        Args:
            api_kwargs (Dict): all the arguments that the API call needs.
            model_type (ModelType): this decides which API to call.
        """
        pass

    def convert_inputs_to_api_kwargs(
        self,
        input: Optional[Any] = None,
        model_kwargs: Dict = {},
        model_type: ModelType = ModelType.UNDEFINED,
    ) -> Dict:
        r"""Bridge the Component's standard input and model_kwargs into API-specific format.

        The result is the ``api_kwargs`` that will be used in the :meth:`call`
        and :meth:`acall` methods. All types of models supported by this
        particular provider should be handled here.

        Args:
            input (Optional[Any], optional): input to the model. Defaults to None.
            model_kwargs (Dict): model kwargs
            model_type (ModelType): model type

        Returns:
            Dict: the API-specific keyword arguments (in subclasses).

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement convert_inputs_to_api_kwargs method"
        )

    def parse_chat_completion(self, completion: Any) -> "GeneratorOutput":
        r"""Parse the chat completion into a GeneratorOutput.

        Args:
            completion (Any): the raw completion object returned by the API.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement parse_chat_completion method"
        )

    def track_completion_usage(self, *args, **kwargs) -> "CompletionUsage":
        r"""Track the chat completion usage. Use OpenAI standard API for tracking.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement track_completion_usage method"
        )

    def parse_embedding_response(self, response: Any) -> "EmbedderOutput":
        r"""Parse the embedding response to a structure AdalFlow components can understand.

        Args:
            response (Any): the raw embedding response returned by the API.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement parse_embedding_response method"
        )

    @staticmethod
    def _process_text(text: str) -> str:
        """Replace every newline in ``text`` with a single space.

        This is specific to OpenAI API, as removing new lines could have better
        performance in the embedder.

        Args:
            text (str): the text to normalize.

        Returns:
            str: ``text`` with each ``"\\n"`` replaced by ``" "``.
        """
        text = text.replace("\n", " ")
        return text

    def _track_usage(self, **kwargs):
        """Usage-tracking hook; the base implementation does nothing."""
        pass

    def __call__(self, *args, **kwargs):
        """Forward the call, unchanged, to the parent class's ``__call__``."""
        return super().__call__(*args, **kwargs)

    def list_models(self):
        """List all available models from this provider.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement list_models method"
        )