Bitsandbytes

You are viewing the main version, which requires installation from source. If you'd like a regular pip install, check out the latest stable version (v4.57.0).

The bitsandbytes library provides quantization tools for LLMs through a lightweight Python wrapper around hardware accelerator functions. It enables working with large models using limited computational resources by reducing their memory footprint.

At its core, bitsandbytes provides:

  • Quantized Linear Layers: Linear8bitLt and Linear4bit layers that replace standard PyTorch linear layers with memory-efficient quantized alternatives
  • Optimized Optimizers: 8-bit versions of common optimizers through its optim module, enabling training of large models with reduced memory requirements
  • Matrix Multiplication: Optimized matrix multiplication operations that leverage the quantized format
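To give a rough idea of what a quantized linear layer stores, the sketch below applies absmax 8-bit quantization to a single row of weights in plain Python. Each value is scaled by 127 / max|w| and rounded to an integer in [-127, 127], and the scale is kept so the row can be dequantized later. This is a simplified model of the idea with made-up weights, not the bitsandbytes kernels, which operate on GPU tensors and use more elaborate schemes.

```python
def absmax_quantize(row):
    """Map floats to int8 codes in [-127, 127] using a single absmax scale."""
    absmax = max(abs(v) for v in row)
    scale = 127 / absmax if absmax else 1.0
    codes = [round(v * scale) for v in row]
    return codes, scale


def absmax_dequantize(codes, scale):
    """Recover approximate floats by undoing the scaling."""
    return [c / scale for c in codes]


weights = [0.12, -0.5, 0.33, 0.01, -0.07]  # made-up weight row
codes, scale = absmax_quantize(weights)
restored = absmax_dequantize(codes, scale)
print(codes)  # [30, -127, 84, 3, -18]
print(max(abs(w - r) for w, r in zip(weights, restored)))  # at most 0.5 / scale
```

Each weight now needs one byte instead of two (fp16) or four (fp32), at the cost of a rounding error bounded by half a quantization step.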

bitsandbytes offers two main quantization features:

  1. LLM.int8() - An 8-bit quantization method that makes inference more accessible without significant performance degradation. Unlike naive quantization, LLM.int8() dynamically preserves higher precision for critical computations, preventing information loss in sensitive parts of the model.

  2. QLoRA - A 4-bit quantization technique that compresses models even further while maintaining trainability by inserting a small set of trainable low-rank adaptation (LoRA) weights.
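The reason LoRA keeps training cheap is simple arithmetic. For a linear layer with a d_out x d_in weight, LoRA adds two matrices of shape (d_out, r) and (r, d_in), so it trains r * (d_in + d_out) values instead of d_in * d_out. The 4096 x 4096 layer and rank 8 below are hypothetical choices for illustration.

```python
def lora_trainable_params(d_in, d_out, rank):
    """Number of trainable values LoRA adds to one d_out x d_in linear layer."""
    # B has shape (d_out, rank) and A has shape (rank, d_in)
    return rank * (d_in + d_out)


d_in = d_out = 4096  # hypothetical hidden size
rank = 8             # hypothetical LoRA rank
full = d_in * d_out
lora = lora_trainable_params(d_in, d_out, rank)
print(full, lora, f"{lora / full:.2%}")  # 16777216 65536 0.39%
```

Only the small LoRA matrices receive gradients, while the large 4-bit base weight stays frozen.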

Note: For a user-friendly quantization experience, you can use the bitsandbytes community space.

Run the command below to install bitsandbytes.

pip install --upgrade transformers accelerate bitsandbytes
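After installing, you can confirm which versions are present with Python's standard importlib.metadata module. This is a small sketch; the helper name installed_versions is hypothetical, and it returns None for any package that isn't installed.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages):
    """Map each distribution name to its installed version, or None if missing."""
    result = {}
    for name in packages:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            result[name] = None
    return result


print(installed_versions(["transformers", "accelerate", "bitsandbytes"]))
```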

To compile from source, follow the instructions in the bitsandbytes installation guide.

Hardware Compatibility

bitsandbytes is supported on NVIDIA GPUs for CUDA versions 11.8 - 13.0, Intel XPU, Intel Gaudi (HPU), and CPU. There is an ongoing effort to support additional platforms. If you’re interested in providing feedback or testing, check out the bitsandbytes repository for more information.

NVIDIA GPUs (CUDA)

This backend is supported on Linux x86-64, Linux aarch64, and Windows platforms.

| Feature | Minimum Hardware Requirement |
|---|---|
| 8-bit optimizers | NVIDIA Pascal (GTX 10X0 series, P100) or newer GPUs * |
| LLM.int8() | NVIDIA Turing (RTX 20X0 series, T4) or newer GPUs |
| NF4/FP4 quantization | NVIDIA Pascal (GTX 10X0 series, P100) or newer GPUs * |
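The GPU generations in the table correspond to CUDA compute capabilities: Pascal is 6.x and Turing is 7.5. The sketch below maps a (major, minor) capability tuple to the features in the table. The helper and its thresholds are derived from the table above and are an assumption for illustration, not a bitsandbytes API. On a machine with PyTorch and a CUDA GPU, torch.cuda.get_device_capability() returns such a tuple.

```python
# Minimum compute capability per feature, derived from the table above
# (Pascal = 6.0, Turing = 7.5). Assumed mapping, for illustration only.
MIN_CAPABILITY = {
    "8-bit optimizers": (6, 0),
    "LLM.int8()": (7, 5),
    "NF4/FP4 quantization": (6, 0),
}


def supported_features(capability):
    """Return the features whose minimum capability is met by `capability`."""
    # Tuples compare element by element, so (7, 0) < (7, 5) < (8, 0).
    return [name for name, minimum in MIN_CAPABILITY.items() if capability >= minimum]


print(supported_features((6, 1)))  # e.g. a GTX 1080 (Pascal)
print(supported_features((7, 5)))  # e.g. a T4 (Turing)
```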

Intel GPUs (XPU)

This backend is supported on Linux x86-64 and Windows x86-64 platforms.

Intel Gaudi (HPU)

This backend is supported on Linux x86-64 for Gaudi2 and Gaudi3.

CPU

This backend is supported on Linux x86-64, Linux aarch64, and Windows x86-64 platforms.

Quantization Examples

Quantize a model by passing a BitsAndBytesConfig to from_pretrained(). This works for any model in any modality, as long as it supports Accelerate and contains torch.nn.Linear layers.

8-bit
Quantizing a model to 8-bit roughly halves its memory usage compared with 16-bit weights. For large models, set `device_map="auto"` to efficiently distribute the weights across all available GPUs.
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

quantization_config = BitsAndBytesConfig(load_in_8bit=True)

model_8bit = AutoModelForCausalLM.from_pretrained(
    "bigscience/bloom-1b7", 
    device_map="auto",
    quantization_config=quantization_config
)
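`device_map="auto"` relies on Accelerate to decide where each layer goes. The sketch below illustrates the general idea with a much simpler greedy strategy: fill each GPU in order until its memory budget is reached, then place the remaining layers on the CPU. The helper name, layer sizes, and budgets are all hypothetical, and Accelerate's real algorithm accounts for more (for example tied weights and modules that must not be split), so this is not its implementation.

```python
def greedy_device_map(layer_sizes, gpu_budgets):
    """Assign layers in order to GPUs 0..N-1, then to "cpu" once the GPUs are full.

    layer_sizes: list of (name, size_in_gb); gpu_budgets: GB available per GPU.
    """
    device_map = {}
    gpu, used = 0, 0.0
    for name, size in layer_sizes:
        # Move on to the next GPU while the current one can't hold this layer.
        while gpu < len(gpu_budgets) and used + size > gpu_budgets[gpu]:
            gpu, used = gpu + 1, 0.0
        if gpu < len(gpu_budgets):
            device_map[name] = gpu
            used += size
        else:
            device_map[name] = "cpu"
    return device_map


layers = [("embed", 0.5), ("block.0", 1.0), ("block.1", 1.0), ("block.2", 1.0), ("lm_head", 0.5)]
print(greedy_device_map(layers, gpu_budgets=[2.0, 1.5]))
```

Keeping layers in order means activations flow from GPU 0 to GPU 1 to the CPU without jumping back and forth.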

By default, all other (non-quantized) modules, such as torch.nn.LayerNorm, are set to the default torch dtype. You can change the data type of these modules with the dtype parameter. Setting dtype="auto" loads the model in the data type defined in a model’s config.json file.

from transformers import AutoModelForCausalLM, BitsAndBytesConfig

quantization_config = BitsAndBytesConfig(load_in_8bit=True)

model_8bit = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-350m",
    device_map="auto",
    quantization_config=quantization_config,
    dtype="auto"
)
# Inspect the dtype of a non-quantized module (a LayerNorm weight)
print(model_8bit.model.decoder.layers[-1].final_layer_norm.weight.dtype)

Once a model is quantized to 8-bit, you can only push the quantized weights to the Hub if you’re using the latest version of Transformers and bitsandbytes. With the latest versions, push the 8-bit model to the Hub with push_to_hub(). The config.json file, which includes the quantization config, is pushed first, followed by the quantized model weights.

from transformers import AutoModelForCausalLM, BitsAndBytesConfig

quantization_config = BitsAndBytesConfig(load_in_8bit=True)

model = AutoModelForCausalLM.from_pretrained(
    "bigscience/bloom-560m", 
    device_map="auto",
    quantization_config=quantization_config
)

model.push_to_hub("bloom-560m-8bit")

Training with 8-bit and 4-bit weights is only supported for training extra parameters (such as adapter weights); the quantized weights themselves stay frozen.

Check your model's memory footprint with get_memory_footprint().

print(model.get_memory_footprint())
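get_memory_footprint() returns the model's memory footprint in bytes. The arithmetic behind that number is roughly "number of elements x bytes per element", summed over the model's tensors, which is why 8-bit weights take half the space of 16-bit weights. The sketch below uses hypothetical layer shapes and ignores details such as the quantization statistics stored alongside the weights.

```python
from math import prod

# Bytes per element for a few storage formats (4-bit packs two values per byte).
BYTES_PER_ELEMENT = {"float32": 4, "float16": 2, "int8": 1, "4bit": 0.5}


def footprint_bytes(shapes, fmt):
    """Total bytes needed to store tensors with the given shapes in format `fmt`."""
    return sum(prod(shape) for shape in shapes) * BYTES_PER_ELEMENT[fmt]


# Hypothetical linear-layer weight shapes
shapes = [(1024, 1024), (4096, 1024), (1024, 4096)]
for fmt in ("float16", "int8", "4bit"):
    print(fmt, footprint_bytes(shapes, fmt) / 2**20, "MiB")  # 18.0, 9.0, 4.5
```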

Load a quantized model with from_pretrained() without passing a quantization_config; the quantization settings are read from the model's config.json.

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("{your_username}/bloom-560m-8bit", device_map="auto")
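This works because a quantized checkpoint saved with Transformers records its quantization settings under the quantization_config key of config.json. The sketch below inspects such a file with the standard json module; the JSON content is an abbreviated, hypothetical example rather than a real checkpoint's config, and describe_quantization is a hypothetical helper.

```python
import json

# Abbreviated, hypothetical config.json for an 8-bit bitsandbytes checkpoint.
config_text = """
{
  "model_type": "bloom",
  "quantization_config": {
    "quant_method": "bitsandbytes",
    "load_in_8bit": true,
    "load_in_4bit": false
  }
}
"""


def describe_quantization(config):
    """Return a short description of the stored quantization settings, if any."""
    qc = config.get("quantization_config")
    if qc is None:
        return "not quantized"
    if qc.get("load_in_8bit"):
        return f"{qc.get('quant_method')} 8-bit"
    if qc.get("load_in_4bit"):
        return f"{qc.get('quant_method')} 4-bit"
    return f"{qc.get('quant_method')} (unknown bit width)"


print(describe_quantization(json.loads(config_text)))  # bitsandbytes 8-bit
```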

LLM.int8()

This section explores some of the specific features of 8-bit quantization, such as offloading, outlier thresholds, skipping module conversion, and finetuning.

Offloading

8-bit models can offload weights between the CPU and GPU to fit very large models into memory. The weights dispatched to the CPU are stored in float32 and aren’t converted to 8-bit. For example, enable offloading for bigscience/bloom-1b7 through BitsAndBytesConfig.

from transformers import AutoModelForCausalLM, BitsAndBytesConfig

quantization_config = BitsAndBytesConfig(load_in_8bit=True, llm_int8_enable_fp32_cpu_offload=True)

Design a custom device map to fit everything on your GPU except for the lm_head, which is dispatched to the CPU.

device_map = {
    "transformer.word_embeddings": 0,
    "transformer.word_embeddings_layernorm": 0,
    "lm_head": "cpu",
    "transformer.h": 0,
    "transformer.ln_f": 0,
}
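Each key in a device map names a module, and every submodule under that name inherits its device. The hypothetical helper below resolves a fully qualified parameter name against a map by finding the longest matching dotted prefix. This is a simplified picture of how such a map is interpreted, not Accelerate's code, and the example map is a shortened variant of the one above.

```python
example_map = {
    "transformer.word_embeddings": 0,
    "transformer.word_embeddings_layernorm": 0,
    "transformer.h": 0,
    "lm_head": "cpu",
}


def resolve_device(param_name, device_map):
    """Return the device of the longest device_map key that prefixes param_name."""
    best = None
    for key in device_map:
        # Match whole dotted components only, so "transformer.h" doesn't match "transformer.hx".
        if param_name == key or param_name.startswith(key + "."):
            if best is None or len(key) > len(best):
                best = key
    if best is None:
        raise KeyError(f"{param_name} is not covered by the device map")
    return device_map[best]


print(resolve_device("transformer.h.3.self_attention.dense.weight", example_map))  # 0
print(resolve_device("lm_head.weight", example_map))  # cpu
```

A module that no key covers raises an error here, which is why a custom device map needs an entry for every top-level module.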

Now load your model with the custom device_map and quantization_config.

model_8bit = AutoModelForCausalLM.from_pretrained(
    "bigscience/bloom-1b7",
    dtype="auto",
    device_map=device_map,
    quantization_config=quantization_config,
)

Outlier threshold

An “outlier” is a hidden state value whose magnitude exceeds a certain threshold; these values are computed in fp16. While the values are usually normally distributed ([-3.5, 3.5]), this distribution can be very different for large models ([-60, 6] or [6, 60]). 8-bit quantization works well for values with a magnitude of up to about 5, but beyond that there is a significant performance penalty. A good default threshold value is 6, but a lower threshold may be needed for more unstable models (small models or finetuning).
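The mixed-precision decomposition behind LLM.int8() can be sketched in plain Python. Feature columns of the hidden states x that contain a value whose magnitude exceeds the threshold are multiplied in full precision, and the remaining columns go through int8 absmax quantization (per row of x, per column of w) before being rescaled and added back. This is a simplified model of the method on small lists of made-up floats, not the bitsandbytes implementation.

```python
def matmul(a, b):
    """Plain full-precision matrix product."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def int8_matmul(x, w):
    """Vector-wise absmax int8 product: scale rows of x and columns of w separately."""
    x_scales = [max(abs(v) for v in row) / 127 or 1.0 for row in x]
    w_cols = list(zip(*w))
    w_scales = [max(abs(v) for v in col) / 127 or 1.0 for col in w_cols]
    xq = [[round(v / s) for v in row] for row, s in zip(x, x_scales)]
    wq = [[round(v / s) for v in col] for col, s in zip(w_cols, w_scales)]
    # Integer dot products, then rescale by the product of the two scales.
    return [[sum(a * b for a, b in zip(xr, wc)) * sx * sw for wc, sw in zip(wq, w_scales)]
            for xr, sx in zip(xq, x_scales)]


def llm_int8_matmul(x, w, threshold=6.0):
    """Approximate x @ w, keeping outlier feature columns of x in full precision."""
    features = range(len(w))
    outliers = [j for j in features if any(abs(row[j]) > threshold for row in x)]
    regular = [j for j in features if j not in outliers]

    def select(cols):
        return [[row[j] for j in cols] for row in x], [w[j] for j in cols]

    result = [[0.0] * len(w[0]) for _ in x]
    for cols, fn in ((outliers, matmul), (regular, int8_matmul)):
        if cols:
            part = fn(*select(cols))
            result = [[r + p for r, p in zip(rr, pr)] for rr, pr in zip(result, part)]
    return result, outliers


x = [[0.5, -1.2, 20.0, 0.3], [1.1, 0.4, -18.5, -0.9]]  # feature 2 holds outliers
w = [[0.2, -0.1], [0.05, 0.3], [0.1, 0.1], [-0.4, 0.25]]
approx, outliers = llm_int8_matmul(x, w)
print(outliers)  # [2]
print(approx, matmul(x, w))
```

Isolating the outlier column keeps it from inflating the absmax scale of each row, so the remaining small values still get fine-grained int8 codes.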

To find the best threshold for your model, experiment with the llm_int8_threshold parameter in BitsAndBytesConfig. For example, setting the threshold to 0.0 significantly speeds up inference at the potential cost of some accuracy.

from transformers import AutoModelForCausalLM, BitsAndBytesConfig

model_id = "bigscience/bloom-1b7"

quantization_config = BitsAndBytesConfig(
    load_in_8bit=True,
    llm_int8_threshold=0.0,
    llm_int8_enable_fp32_cpu_offload=True
)

# device_map is the custom device map from the offloading example above
model_8bit = AutoModelForCausalLM.from_pretrained(
    model_id,
    dtype="auto",
    device_map=device_map,
    quantization_config=quantization_config,
)

Skip module conversion

For some models, like Jukebox, you don’t need to quantize every module to 8-bit, and doing so can actually cause instability. Jukebox has several lm_head modules that should be skipped with the llm_int8_skip_modules parameter in BitsAndBytesConfig.

from transformers import AutoModelForCausalLM, BitsAndBytesConfig

model_id = "bigscience/bloom-1b7"

quantization_config = BitsAndBytesConfig(
    load_in_8bit=True,
    llm_int8_skip_modules=["lm_head"],
)

model_8bit = AutoModelForCausalLM.from_pretrained(
    model_id,
    dtype="auto",
    device_map="auto",
    quantization_config=quantization_config,
)
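To check which layers a skip list would affect, the hypothetical helper below filters module names with a simple substring rule: a module is left unquantized if any entry of the skip list appears in its fully qualified name. This mirrors the intent of llm_int8_skip_modules but is a simplified assumption, not the exact matching logic in Transformers, and the module names are examples.

```python
def modules_to_quantize(module_names, skip_modules):
    """Split linear-module names into (quantized, skipped) using substring matching."""
    quantized, skipped = [], []
    for name in module_names:
        if any(skip in name for skip in skip_modules):
            skipped.append(name)
        else:
            quantized.append(name)
    return quantized, skipped


names = [
    "transformer.h.0.self_attention.query_key_value",
    "transformer.h.0.mlp.dense_4h_to_h",
    "lm_head",
]
print(modules_to_quantize(names, ["lm_head"]))
```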

Finetuning

The PEFT library supports fine-tuning large models like flan-t5-large and facebook/opt-6.7b with 8-bit quantization. You don’t need to pass the device_map parameter for training because it automatically loads your model on a GPU. However, you can still customize the device map with the device_map parameter (device_map="auto" should only be used for inference).
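Because the quantized base weights stay frozen, finetuning only updates a small set of extra parameters. The sketch below shows a LoRA-style forward pass in plain Python, assuming a frozen (dequantized) weight matrix W plus trainable low-rank factors A and B scaled by alpha / r. All values are made up, and in practice PEFT builds and trains these adapters for you.

```python
def matvec(m, v):
    """Multiply matrix m (list of rows) by vector v."""
    return [sum(a * b for a, b in zip(row, v)) for row in m]


def lora_forward(x, w_frozen, a, b, alpha):
    """y = W x + (alpha / r) * B (A x); only A and B would receive gradients."""
    r = len(a)  # rank = number of rows in A
    base = matvec(w_frozen, x)
    update = matvec(b, matvec(a, x))
    return [y + (alpha / r) * u for y, u in zip(base, update)]


w_frozen = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]  # 2x3 frozen base weight
a = [[0.1, 0.2, 0.3]]                            # 1x3, rank r = 1
b = [[0.0], [0.0]]                               # 2x1, zero-initialized
x = [1.0, 2.0, 3.0]
print(lora_forward(x, w_frozen, a, b, alpha=2.0))  # [1.0, 2.0]
```

With B initialized to zero, the adapted layer starts out identical to the frozen base layer, and training gradually learns the low-rank correction.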

QLoRA

This section explores some of the specific features of 4-bit quantization, such as changing the compute data type, the Normal Float 4 (NF4) data type, and nested quantization.

Compute data type

Change the compute data type from float32 (the default value) to bf16 in BitsAndBytesConfig to speed up computation.

import torch
from transformers import BitsAndBytesConfig

quantization_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)
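bfloat16 keeps float32's 8-bit exponent but only 7 explicit mantissa bits, so it covers the same numeric range with less precision. The sketch below emulates converting a Python float to bfloat16 by rounding its float32 bit pattern to the upper 16 bits (round to nearest, ties to even). It skips NaN handling and expects values that fit in float32; it is only meant to show the precision trade-off.

```python
import struct


def to_bfloat16(value):
    """Round a float to the nearest bfloat16 value and return it as a Python float."""
    bits = struct.unpack(">I", struct.pack(">f", value))[0]  # float32 bit pattern
    # Round to nearest, ties to even, on the 16 low bits that are dropped.
    bits += 0x7FFF + ((bits >> 16) & 1)
    bits &= 0xFFFF0000
    return struct.unpack(">f", struct.pack(">I", bits))[0]


for v in (1.0, 3.14159, 1e30):
    print(v, "->", to_bfloat16(v))
```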

Normal Float 4 (NF4)

NF4 is a 4-bit data type from the QLoRA paper, adapted for weights initialized from a normal distribution. You should use NF4 for training 4-bit base models.

from transformers import AutoModelForCausalLM, BitsAndBytesConfig

model_id = "bigscience/bloom-1b7"

nf4_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
)

model_nf4 = AutoModelForCausalLM.from_pretrained(model_id, dtype="auto", quantization_config=nf4_config)
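The idea behind NF4 can be sketched with Python's statistics.NormalDist: build 16 levels from quantiles of a standard normal distribution, normalize them to [-1, 1] with an exact zero, then store each absmax-normalized weight as the 4-bit index of its nearest level. The construction below (8 positive and 7 negative quantiles plus zero) is modeled on the QLoRA paper, but the offset value and helper names are assumptions, and the exact codebook in bitsandbytes may differ slightly.

```python
from statistics import NormalDist


def linspace(start, stop, num):
    return [start + (stop - start) * i / (num - 1) for i in range(num)]


def nf4_style_levels(offset=0.9677):
    """16 levels from N(0, 1) quantiles, normalized to [-1, 1], with an exact zero."""
    inv_cdf = NormalDist().inv_cdf
    positive = [inv_cdf(p) for p in linspace(offset, 0.5, 9)[:-1]]   # 8 values
    negative = [-inv_cdf(p) for p in linspace(offset, 0.5, 8)[:-1]]  # 7 values
    levels = sorted(positive + [0.0] + negative)
    top = max(abs(v) for v in levels)
    return [v / top for v in levels]


def quantize_block(block, levels):
    """Absmax-normalize a block, then store the index of the nearest level (4 bits)."""
    absmax = max(abs(v) for v in block) or 1.0
    codes = [min(range(len(levels)), key=lambda i: abs(levels[i] - v / absmax)) for v in block]
    return codes, absmax


def dequantize_block(codes, absmax, levels):
    return [levels[c] * absmax for c in codes]


levels = nf4_style_levels()
block = [0.8, -0.05, 0.3, -1.6, 0.0, 0.12]  # made-up weights
codes, absmax = quantize_block(block, levels)
print(codes)
print([round(v, 3) for v in dequantize_block(codes, absmax, levels)])
```

Levels are packed more densely near zero, where normally distributed weights are most common.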

For inference, bnb_4bit_quant_type doesn't have a large impact on performance. However, to remain consistent with the model weights, use the same value for bnb_4bit_compute_dtype and dtype.

Nested quantization

Nested quantization can save additional memory at no additional performance cost. This feature performs a second quantization of the already quantized weights to save an additional 0.4 bits/parameter. For example, with nested quantization, you can finetune a Llama-13b model on a 16GB NVIDIA T4 GPU with a sequence length of 1024, a batch size of 1, and enable gradient accumulation with 4 steps.

from transformers import AutoModelForCausalLM, BitsAndBytesConfig

double_quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
)

model_double_quant = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-13b-chat-hf", dtype="auto", quantization_config=double_quant_config)

Dequantizing bitsandbytes models

Once quantized, you can call dequantize() to return a model to its original precision, but this may result in some quality loss. Make sure you have enough GPU memory to fit the dequantized model.

from transformers import AutoModelForCausalLM, BitsAndBytesConfig

model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-125m", quantization_config=BitsAndBytesConfig(load_in_4bit=True)
)
model.dequantize()

Resources

Learn more about the details of 8-bit quantization in A Gentle Introduction to 8-bit Matrix Multiplication for transformers at scale using Hugging Face Transformers, Accelerate and bitsandbytes.

Try 4-bit quantization in this notebook and learn more about its details in Making LLMs even more accessible with bitsandbytes, 4-bit quantization and QLoRA.
