
FBGEMM FP8



With the FBGEMM FP8 quantization method, you can quantize your model to FP8 (W8A8):

  • the weights are quantized to 8-bit (FP8) per channel
  • the activations are quantized to 8-bit (FP8) per token
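The two scaling granularities above can be illustrated with a small NumPy emulation. This is a sketch, not the FBGEMM kernel: it assumes the FP8 format is E4M3 (3 mantissa bits, largest finite value 448), assumes each scale is the row's absolute maximum divided by 448, and emulates FP8 rounding in float64. The helper names `round_to_e4m3` and `quantize_rows_fp8` are hypothetical. It quantizes each output channel of a weight matrix and each token of an activation matrix with its own scale, multiplies the FP8 values, and then undoes both scales.

```python
import numpy as np

# Assumption: FP8 here means E4M3 (4 exponent bits, 3 mantissa bits),
# whose largest finite value is 448.
FP8_E4M3_MAX = 448.0


def round_to_e4m3(x):
    """Emulate round-to-nearest-even into FP8 E4M3, returning float64 values."""
    x = np.clip(x, -FP8_E4M3_MAX, FP8_E4M3_MAX)
    # Binade exponent of each value; subnormals share the smallest normal exponent (-6)
    exponent = np.floor(np.log2(np.maximum(np.abs(x), 2.0**-6)))
    step = 2.0 ** (exponent - 3)  # spacing between neighbours with 3 mantissa bits
    return np.round(x / step) * step  # np.round rounds halves to even


def quantize_rows_fp8(x):
    """Quantize each row of a 2-D array with its own scale (row amax / 448)."""
    amax = np.max(np.abs(x), axis=1, keepdims=True)
    scale = np.maximum(amax, 1e-12) / FP8_E4M3_MAX
    return round_to_e4m3(x / scale), scale


rng = np.random.default_rng(0)
weight = rng.standard_normal((3, 64))       # (out_features, in_features)
activations = rng.standard_normal((4, 64))  # (tokens, in_features)

# Weights: one scale per output channel (row of the weight matrix)
w_q, w_scale = quantize_rows_fp8(weight)
# Activations: one scale per token (row of the activation matrix)
x_q, x_scale = quantize_rows_fp8(activations)

# FP8 x FP8 matmul, then rescale: per-token scale on rows, per-channel scale on columns
y_fp8 = (x_q @ w_q.T) * x_scale * w_scale.T
y_ref = activations @ weight.T

rel_err = np.linalg.norm(y_fp8 - y_ref) / np.linalg.norm(y_ref)
print(f"relative error of the W8A8 output: {rel_err:.4f}")
```

Because the scales factor out of the dot product, a real kernel can run the matrix multiplication entirely on FP8 inputs and apply the two scale vectors once per output element.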

It relies on the FBGEMM library, which provides efficient low-precision general matrix multiplication for small batch sizes and supports techniques that minimize accuracy loss, such as row-wise quantization and outlier-aware quantization.
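Why row-wise quantization reduces accuracy loss can be seen in a small NumPy emulation that compares one scale for the whole tensor against one scale per row. This is a sketch under stated assumptions: the FP8 format is taken to be E4M3 (largest finite value 448), rounding is emulated in float64 rather than run through FBGEMM, the weight values are made up for illustration, and the helpers `round_to_e4m3` and `fake_quant_fp8` are hypothetical. It does not model outlier-aware quantization.

```python
import numpy as np

FP8_E4M3_MAX = 448.0  # assumption: E4M3 format, largest finite value 448


def round_to_e4m3(x):
    """Emulate round-to-nearest-even into FP8 E4M3 (including subnormals)."""
    x = np.clip(x, -FP8_E4M3_MAX, FP8_E4M3_MAX)
    exponent = np.floor(np.log2(np.maximum(np.abs(x), 2.0**-6)))
    step = 2.0 ** (exponent - 3)
    return np.round(x / step) * step


def fake_quant_fp8(x, axis):
    """Quantize to FP8 and dequantize; axis=None uses a single per-tensor scale."""
    amax = np.max(np.abs(x), axis=axis, keepdims=True)
    scale = np.maximum(amax, 1e-12) / FP8_E4M3_MAX
    return round_to_e4m3(x / scale) * scale


# Hypothetical weights: one row of large values, one row of tiny values
weight = np.array([
    [1000.0, -750.0, 500.0, 250.0],
    [0.001, -0.002, 0.0015, 0.0005],
])


def row_rel_error(approx):
    """Relative L2 error of each row."""
    return np.linalg.norm(approx - weight, axis=1) / np.linalg.norm(weight, axis=1)


err_per_tensor = row_rel_error(fake_quant_fp8(weight, axis=None))
err_row_wise = row_rel_error(fake_quant_fp8(weight, axis=1))
print("per-tensor:", err_per_tensor)  # the tiny row underflows to zero
print("row-wise:  ", err_row_wise)
```

With a single scale, the large row dictates the scale and every value in the tiny row rounds to zero. With a scale per row, each row uses the full FP8 range, so the tiny row keeps a few percent of relative error at most.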

You need a GPU with compute capability >= 9.0 (e.g. H100).
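A quick pre-flight check can catch an unsupported GPU before a large checkpoint is downloaded. The sketch below uses PyTorch's `torch.cuda.is_available()` and `torch.cuda.get_device_capability()`, which returns a `(major, minor)` tuple. The helper name `supports_fbgemm_fp8` is hypothetical and not part of Transformers. It also accepts an explicit capability tuple so the comparison can be checked without a GPU.

```python
def supports_fbgemm_fp8(capability=None):
    """Return True if the (major, minor) compute capability is at least 9.0.

    If no capability is given, query the current CUDA device with PyTorch.
    """
    if capability is None:
        import torch  # imported lazily so the helper can be used without torch

        if not torch.cuda.is_available():
            return False
        capability = torch.cuda.get_device_capability()
    # Tuples compare element by element: (9, 0) >= (9, 0), (8, 9) < (9, 0)
    return tuple(capability) >= (9, 0)
```

For example, calling `supports_fbgemm_fp8()` on an H100, whose compute capability is 9.0, should return `True`, while an A100 (8.0) should return `False`.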

Before you begin, make sure the following libraries are installed in their latest versions:

pip install --upgrade accelerate fbgemm-gpu torch

If you run into issues with the fbgemm-gpu and torch libraries, you may need to install the nightly releases. You can follow the instructions here.

By default, the weights are loaded in full precision (torch.float32), regardless of the data type they are actually stored in, such as torch.float16. Set torch_dtype="auto" to load the weights in the data type defined in the model’s config.json file, so the most memory-optimal data type is used automatically.

from transformers import FbgemmFp8Config, AutoModelForCausalLM, AutoTokenizer

model_name = "meta-llama/Meta-Llama-3-8B"
# Quantize the weights to FP8 on the fly while the checkpoint is loaded
quantization_config = FbgemmFp8Config()
quantized_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype="auto",
    device_map="auto",
    quantization_config=quantization_config,
)

tokenizer = AutoTokenizer.from_pretrained(model_name)
input_text = "What are we having for dinner?"
# The tokenizer returns a dict-like BatchEncoding (input_ids, attention_mask)
inputs = tokenizer(input_text, return_tensors="pt").to(quantized_model.device)

output = quantized_model.generate(**inputs, max_new_tokens=10)
print(tokenizer.decode(output[0], skip_special_tokens=True))

A quantized model can be saved with save_pretrained and reloaded later with from_pretrained.

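quant_path = "/path/to/save/quantized/model"
# Save the FP8 weights together with the quantization config
quantized_model.save_pretrained(quant_path)
# The quantization config is read back from the saved config.json
model = AutoModelForCausalLM.from_pretrained(quant_path, device_map="auto")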