transformers_client

Huggingface transformers ModelClient integration.

Functions

average_pool(last_hidden_states, attention_mask)

clean_device_cache()

get_device()

Classes

TransformerEmbedder([model_name])

Local model SDK for transformers.

TransformerLLM([model_name])

Local model SDK for transformers LLM.

TransformerReranker([model_name])

Local model SDK for a reranker model using transformers.

TransformersClient([model_name])

LightRAG API client for transformers.

average_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor
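The page gives only the signature of average_pool. Judging by its parameters and by the thenlper/gte-base model card, it most likely computes a masked mean: token vectors at padding positions (mask 0) are ignored and the remaining vectors are averaged per sequence. The sketch below shows that computation on nested Python lists so it runs without torch; masked_mean_pool is a hypothetical name, the torch form in the docstring is along the lines of the model card example, and the zero-count guard is an addition of this sketch.

```python
from typing import List

Matrix = List[List[float]]


def masked_mean_pool(last_hidden_states: List[Matrix], attention_mask: List[List[int]]) -> Matrix:
    """Masked mean over the sequence axis, written with plain lists.

    last_hidden_states: [batch][seq_len][hidden]; attention_mask: [batch][seq_len],
    1 for real tokens and 0 for padding. Returns [batch][hidden].

    A torch form along the lines of the gte-base model card example:
        masked = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
        return masked.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
    """
    pooled: Matrix = []
    for tokens, mask in zip(last_hidden_states, attention_mask):
        totals = [0.0] * len(tokens[0])
        count = 0
        for vector, keep in zip(tokens, mask):
            if keep:  # skip padding positions
                for i, value in enumerate(vector):
                    totals[i] += value
                count += 1
        # max(count, 1) avoids division by zero for an all-padding row
        # (an addition of this sketch, not part of the torch form).
        pooled.append([total / max(count, 1) for total in totals])
    return pooled
```

The padding row [100.0, 100.0] in a test batch would not affect the result, which is the point of passing attention_mask instead of averaging every position.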
class TransformerEmbedder(model_name: str | None = 'thenlper/gte-base')

Bases: object

Local model SDK for transformers.

There are two ways to run transformers: (1) load a model and run model inference on it directly, or (2) create a Pipeline and run pipeline inference.

This file demonstrates how to (1) create a torch model inference component, TransformerEmbedder, which is the equivalent of OpenAI(), the SyncAPIClient, and (2) convert this model inference component into a LightRAG API client, TransformersClient.

This is currently just an exemplary component that initializes a specific model from transformers and runs inference on it. It has not been tested on all transformer models yet, so it may be necessary to write one for each model.

References:
- transformers: https://huggingface.co/docs/transformers/en/index
- thenlper/gte-base model: https://huggingface.co/thenlper/gte-base

models: Dict[str, type] = {}
init_model(model_name: str)
infer_gte_base_embedding(input=typing.Union[str, typing.List[str]], tolist: bool = True)
get_device()
clean_device_cache()
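The page lists get_device() and clean_device_cache() without describing them. The sketch below shows one common device-selection order (CUDA, then Apple MPS, then CPU) and the matching cache release; that order is an assumption, not taken from the source, and all three function names are hypothetical. It falls back to "cpu" when torch is not installed.

```python
def choose_device_name(cuda_available: bool, mps_available: bool) -> str:
    """Return a torch device string, preferring CUDA, then Apple MPS, then CPU (assumed order)."""
    if cuda_available:
        return "cuda"
    if mps_available:
        return "mps"
    return "cpu"


def detect_device_name() -> str:
    """Probe the local torch install; falls back to "cpu" if torch is missing."""
    try:
        import torch
    except ImportError:
        return "cpu"
    mps_backend = getattr(torch.backends, "mps", None)
    mps_ok = bool(mps_backend is not None and mps_backend.is_available())
    return choose_device_name(torch.cuda.is_available(), mps_ok)


def clear_device_cache(device_name: str) -> None:
    """Release cached allocator memory (hypothetical counterpart of clean_device_cache)."""
    try:
        import torch
    except ImportError:
        return
    if device_name == "cuda":
        torch.cuda.empty_cache()
    elif device_name == "mps" and hasattr(torch, "mps"):
        torch.mps.empty_cache()
```

Keeping the decision in a pure function (choose_device_name) makes the preference order testable without any GPU present.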
class TransformerReranker(model_name: str | None = 'BAAI/bge-reranker-base')

Bases: object

Local model SDK for a reranker model using transformers.

References:
- model: https://huggingface.co/BAAI/bge-reranker-base
- paper: https://arxiv.org/abs/2309.07597

Note: if you are using a MacBook with an M1-series chip, make sure torch.device("mps") is set.

models: Dict[str, type] = {}
init_model(model_name: str)
infer_bge_reranker_base(query: str, documents: List[str]) -> List[float]
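infer_bge_reranker_base returns a List[float]. Assuming it holds one relevance score per document, in input order, with higher meaning more relevant (the usual reading of cross-encoder rerankers such as bge-reranker-base), the scores can be turned into a ranking as sketched below. rank_documents is a hypothetical helper, not part of this module.

```python
from typing import List, Optional, Tuple


def rank_documents(
    documents: List[str], scores: List[float], top_k: Optional[int] = None
) -> List[Tuple[int, str, float]]:
    """Sort documents by reranker score, highest first.

    Returns (original_index, document, score) triples so callers can map back
    to their own document ids. Assumes scores[i] belongs to documents[i].
    """
    if len(documents) != len(scores):
        raise ValueError("expected exactly one score per document")
    ranked = sorted(
        zip(range(len(documents)), documents, scores),
        key=lambda item: item[2],
        reverse=True,
    )
    return ranked if top_k is None else ranked[:top_k]
```

Returning the original index alongside each document keeps the ranking useful when the same text appears more than once.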
class TransformerLLM(model_name: str | None = None)

Bases: object

Local model SDK for transformers LLM.

Note

This inference component is specific to the HuggingFaceH4/zephyr-7b-beta model.

The example raw output:

    <|system|>
    You are a friendly chatbot who always responds in the style of a pirate.</s>
    <|user|>
    How many helicopters can a human eat in one sitting?</s>
    <|assistant|>
    Ah, me hearty matey! But yer question be a puzzler! A human cannot eat a helicopter in one sitting, as helicopters are not edible. They be made of metal, plastic, and other materials, not food!
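The raw output contains the whole chat template, not just the reply. A minimal sketch of pulling out the assistant's turn, assuming the zephyr tags shown in the example (<|assistant|> and the </s> end-of-turn marker); extract_assistant_reply is hypothetical and is not necessarily how parse_chat_completion works.

```python
ASSISTANT_TAG = "<|assistant|>"
END_OF_TURN = "</s>"


def extract_assistant_reply(raw: str) -> str:
    """Return the text after the last assistant tag, without the end-of-turn marker."""
    # rpartition picks the last assistant tag, i.e. the newly generated turn.
    _, found, reply = raw.rpartition(ASSISTANT_TAG)
    if not found:
        # No template tags: assume the text is already just the reply.
        return raw.strip()
    reply = reply.strip()
    if reply.endswith(END_OF_TURN):
        reply = reply[: -len(END_OF_TURN)].rstrip()
    return reply
```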

References:
- model: https://huggingface.co/HuggingFaceH4/zephyr-7b-beta
- https://huggingface.co/google/gemma-2b
- https://huggingface.co/google/gemma-2-2b

models: Dict[str, type] = {}
tokenizer: Dict[str, type] = {}
model_to_init_func = {'HuggingFaceH4/zephyr-7b-beta': 'use_pipeline', 'google/gemma-2-2b': 'use_pipeline'}
init_model(model_name: str)
parse_chat_completion(completion: Any) -> str
infer_llm(*, model: str, messages: Sequence[Dict[str, str]], max_tokens: int | None = None, **kwargs)
class TransformersClient(model_name: str | None = None)

Bases: ModelClient

LightRAG API client for transformers.

Run ``ls ~/.cache/huggingface/hub`` to see the cached models.

Some models are gated: you need to visit the model's page to request access and use an access token. See https://huggingface.co/docs/hub/security-tokens for how to create tokens. Once you have a token and have been granted access, put the token in the environment variable HF_TOKEN.
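huggingface_hub, which transformers uses for downloads, reads the token from HF_TOKEN, so a fail-fast check before loading a gated model gives a clearer error than a failed download. require_hf_token is a hypothetical helper for illustration only.

```python
import os


def require_hf_token(env_var: str = "HF_TOKEN") -> str:
    """Return the Hugging Face token from the environment or fail fast."""
    token = os.environ.get(env_var, "").strip()
    if not token:
        raise RuntimeError(
            f"{env_var} is not set. Gated models need an access token: "
            "see https://huggingface.co/docs/hub/security-tokens"
        )
    return token
```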

support_models = {'BAAI/bge-reranker-base': {'type': ModelType.RERANKER}, 'HuggingFaceH4/zephyr-7b-beta': {'type': ModelType.LLM}, 'google/gemma-2-2b': {'type': ModelType.LLM}, 'thenlper/gte-base': {'type': ModelType.EMBEDDER}}
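support_models maps each model name to a ModelType. One way a client could use such a table is to validate the requested model and pick the matching initializer. The sketch below copies the table, defines a local stand-in for ModelType (the real enum comes from the library), and pairs each type with one of the init methods documented here; that pairing and resolve_model itself are hypothetical.

```python
from enum import Enum, auto
from typing import Dict, Tuple


class ModelType(Enum):  # local stand-in for the library's ModelType
    EMBEDDER = auto()
    LLM = auto()
    RERANKER = auto()
    UNDEFINED = auto()


SUPPORT_MODELS: Dict[str, Dict[str, ModelType]] = {
    "BAAI/bge-reranker-base": {"type": ModelType.RERANKER},
    "HuggingFaceH4/zephyr-7b-beta": {"type": ModelType.LLM},
    "google/gemma-2-2b": {"type": ModelType.LLM},
    "thenlper/gte-base": {"type": ModelType.EMBEDDER},
}

# Hypothetical pairing of model type with the init method names listed on this page.
INIT_METHOD: Dict[ModelType, str] = {
    ModelType.EMBEDDER: "init_sync_client",
    ModelType.RERANKER: "init_reranker_client",
    ModelType.LLM: "init_llm_client",
}


def resolve_model(model_name: str) -> Tuple[ModelType, str]:
    """Return (model_type, init_method_name) or raise for unsupported models."""
    try:
        model_type = SUPPORT_MODELS[model_name]["type"]
    except KeyError:
        supported = ", ".join(sorted(SUPPORT_MODELS))
        raise ValueError(f"unsupported model {model_name!r}; supported: {supported}") from None
    return model_type, INIT_METHOD[model_type]
```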
init_sync_client()
init_reranker_client()
init_llm_client()
set_llm_client(llm_client: object)

Allows the user to pass a custom LLM client. Here is an example of a custom LLM client:

Make sure the client defines parse_chat_completion and __call__ methods; they are applied to the api_kwargs specified in transformer_client.call().

from typing import Any, Dict, List, Sequence

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


class CustomizeLLM:

    def __init__(self, model_name: str = "deepseek-ai/deepseek-coder-1.3b-instruct") -> None:
        # Load the tokenizer and weights once, not on every call.
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
        # Choose an available device: CUDA, then Apple MPS, then CPU.
        if torch.cuda.is_available():
            device = "cuda"
        elif torch.backends.mps.is_available():
            device = "mps"
        else:
            device = "cpu"
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            trust_remote_code=True,
            torch_dtype=torch.bfloat16,
        ).to(device)

    def parse_chat_completion(self, completion: Any) -> str:
        # __call__ returns a list of decoded strings; keep the first one.
        return completion[0] if isinstance(completion, list) else str(completion)

    def __call__(self, messages: Sequence[Dict[str, str]], model: str, **kwargs) -> List[str]:
        # `model` is part of the expected call signature; the weights loaded
        # in __init__ are used regardless of its value.
        inputs = self.tokenizer.apply_chat_template(
            list(messages), add_generation_prompt=True, return_tensors="pt"
        ).to(self.model.device)
        # tokenizer.eos_token_id is the id of the <|EOT|> token
        outputs = self.model.generate(
            inputs,
            max_new_tokens=kwargs.get("max_new_tokens", 512),
            do_sample=False,  # greedy decoding, so top_k/top_p would be ignored
            num_return_sequences=1,
            eos_token_id=self.tokenizer.eos_token_id,
        )
        # Decode only the newly generated tokens, skipping the prompt.
        prompt_len = inputs.shape[-1]
        return [
            self.tokenizer.decode(output[prompt_len:], skip_special_tokens=True)
            for output in outputs
        ]


llm_client = CustomizeLLM()
# transformer_client is a TransformersClient instance (the class documented on this page)
transformer_client = TransformersClient()
transformer_client.set_llm_client(llm_client)
# use in the generator
generator = Generator(
    model_client=transformer_client,
    model_kwargs=model_kwargs,
    prompt_kwargs=prompt_kwargs,
    ...)
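For wiring tests it can help to plug in a client that satisfies the same contract (parse_chat_completion plus __call__) without downloading any weights. EchoLLM is a hypothetical stand-in that just echoes the last user message; an instance would be passed to set_llm_client the same way llm_client is in the example above.

```python
from typing import Any, Dict, List, Sequence


class EchoLLM:
    """Hypothetical test double: returns the last user message instead of generating."""

    def parse_chat_completion(self, completion: Any) -> str:
        # __call__ returns a list of strings; keep the first one.
        if isinstance(completion, list):
            return completion[0] if completion else ""
        return str(completion)

    def __call__(self, messages: Sequence[Dict[str, str]], model: str, **kwargs: Any) -> List[str]:
        user_turns = [m["content"] for m in messages if m.get("role") == "user"]
        last = user_turns[-1] if user_turns else ""
        return [f"[{model}] {last}"]
```

Because it returns the same shape as CustomizeLLM (a list of strings), swapping one for the other should not require changes elsewhere in the pipeline.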
parse_embedding_response(response: Any) -> EmbedderOutput

Parse the embedding response into a structure that LightRAG components can understand.

parse_chat_completion(completion: Any) -> GeneratorOutput

Parse the chat completion to str.

call(api_kwargs: Dict = {}, model_type: ModelType = ModelType.UNDEFINED)

Subclasses use this to call the API with the sync client. model_type decides which API is called, such as chat.completions or embeddings for OpenAI. api_kwargs contains all the arguments that the API call needs. Subclasses should implement this method.

Additionally, a subclass can implement error handling and retry logic here. See OpenAIClient for an example.
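The page leaves error handling and retry logic to subclasses. The sketch below is a generic exponential-backoff wrapper, not OpenAIClient's actual implementation; call_with_retry is hypothetical, and retrying on RuntimeError by default is an assumption you would narrow to the transient errors your model actually raises. Inside call(), the model invocation could be wrapped in a zero-argument lambda and passed as fn.

```python
import time
from typing import Callable, Tuple, Type, TypeVar

T = TypeVar("T")


def call_with_retry(
    fn: Callable[[], T],
    max_attempts: int = 3,
    base_delay: float = 0.5,
    retry_on: Tuple[Type[BaseException], ...] = (RuntimeError,),
    sleep: Callable[[float], None] = time.sleep,
) -> T:
    """Call fn, retrying on the given exceptions with exponential backoff."""
    if max_attempts < 1:
        raise ValueError("max_attempts must be >= 1")
    for attempt in range(1, max_attempts + 1):
        try:
            return fn()
        except retry_on:
            if attempt == max_attempts:
                raise  # out of attempts: surface the last error
            # Wait base_delay, 2 * base_delay, 4 * base_delay, ... between attempts.
            sleep(base_delay * 2 ** (attempt - 1))
    raise AssertionError("unreachable")
```

Injecting sleep keeps the backoff schedule testable without real waiting.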

convert_inputs_to_api_kwargs(input: Any, model_kwargs: dict = {}, model_type: ModelType = ModelType.UNDEFINED) -> dict

Bridge the Component’s standard input and model_kwargs into the API-specific format, i.e. the api_kwargs that will be used in the _call and _acall methods.

All types of models supported by this particular provider should be handled here.

Parameters:
- input (Optional[Any], optional): input to the model. Defaults to None.
- model_kwargs (Dict): model kwargs
- model_type (ModelType): model type
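A minimal sketch of what "handle every model type here" can look like. The key names (input, messages, query, documents), the user role for LLM input, and expecting reranker documents in model_kwargs are all assumptions loosely modeled on the infer_* signatures above; build_api_kwargs and the ModelType stand-in are hypothetical, not the method's actual body.

```python
from enum import Enum, auto
from typing import Any, Dict, Optional


class ModelType(Enum):  # local stand-in for the library's ModelType
    EMBEDDER = auto()
    LLM = auto()
    RERANKER = auto()
    UNDEFINED = auto()


def build_api_kwargs(
    input: Any,
    model_kwargs: Optional[Dict[str, Any]] = None,
    model_type: ModelType = ModelType.UNDEFINED,
) -> Dict[str, Any]:
    # Copy so the caller's dict is never mutated.
    api_kwargs = dict(model_kwargs or {})
    if model_type is ModelType.EMBEDDER:
        # Normalize a single string to a batch of one.
        api_kwargs["input"] = [input] if isinstance(input, str) else list(input)
    elif model_type is ModelType.LLM:
        api_kwargs["messages"] = [{"role": "user", "content": str(input)}]
    elif model_type is ModelType.RERANKER:
        if "documents" not in api_kwargs:
            raise ValueError("reranker calls need 'documents' in model_kwargs")
        api_kwargs["query"] = input
    else:
        raise ValueError(f"unsupported model_type: {model_type}")
    return api_kwargs
```

Using None as the default and copying the dict sidesteps the shared-mutable-default pitfall that a `model_kwargs: dict = {}` signature invites.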