Why is 4-bit quantized performance slower than fp16?

#17
opened by kapil1611

I am trying to wrap my head around this: why is A faster than B?

A.

import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

tokenizer_large = AutoTokenizer.from_pretrained("google/flan-t5-large")
model_large = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-large", torch_dtype=torch.float16, device_map="auto")

is faster than

B.

from transformers import BitsAndBytesConfig, AutoModelForSeq2SeqLM, AutoTokenizer

model_id = "google/flan-t5-large"
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=False
)
model_large = AutoModelForSeq2SeqLM.from_pretrained(model_id, quantization_config=quantization_config)
tokenizer_large = AutoTokenizer.from_pretrained(model_id)
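
(A minimal sketch of how the two setups could be timed; the prompt, max_new_tokens, and timing loop below are illustrative assumptions, not part of the original post, and a CUDA GPU is assumed.)

import time
import torch

def time_generation(model, tokenizer, prompt, n_runs=5):
    # Tokenize the prompt and move it to the model's device (assumes CUDA)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    # Warm-up run so one-time setup cost is not counted
    model.generate(**inputs, max_new_tokens=64)
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(n_runs):
        model.generate(**inputs, max_new_tokens=64)
    torch.cuda.synchronize()
    return (time.perf_counter() - start) / n_runs

# e.g. time_generation(model_large, tokenizer_large, "Translate English to German: The house is wonderful.")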

Google org

Hi @kapil1611

The load_in_4bit flag activates the 4-bit quantization described in this paper: https://arxiv.org/abs/2305.14314 - that method quantizes the linear layers to 4-bit and de-quantizes them at each forward pass, performing the matmul computation either in float32 (the default) or in half precision. The quantization / de-quantization adds some overhead, making it slower than half-precision models in most cases.

By default we use bnb_4bit_compute_dtype=torch.float32: https://github.com/huggingface/transformers/blob/main/src/transformers/utils/quantization_config.py#L204
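
As a quick check, here is a minimal sketch (assuming the 4-bit model from snippet B is loaded as model_large) that inspects one of the quantized linear layers via bitsandbytes; with the default config it should report torch.float32:

import bitsandbytes as bnb

# Print the compute dtype of the first 4-bit linear layer found in the model;
# with the default BitsAndBytesConfig this is expected to be torch.float32.
for name, module in model_large.named_modules():
    if isinstance(module, bnb.nn.Linear4bit):
        print(name, module.compute_dtype)
        break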
For faster generation, you can benefit from the optimized kernels described here: https://twitter.com/Tim_Dettmers/status/1683118705956491264?s=20 - first make sure you are using the latest stable bitsandbytes package (pip install -U bitsandbytes), then run:

import torch
from transformers import BitsAndBytesConfig, AutoModelForSeq2SeqLM, AutoTokenizer

model_id = "google/flan-t5-large"
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=False,
    bnb_4bit_compute_dtype=torch.float16
)
model_large = AutoModelForSeq2SeqLM.from_pretrained(model_id, quantization_config=quantization_config)
tokenizer_large = AutoTokenizer.from_pretrained(model_id)

That should hopefully lead to much faster inference than the default 4-bit setup, and possibly similar or even faster speed than the half-precision model at batch_size=1, depending on the hardware.
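
For completeness, a minimal usage sketch with the configuration above (the prompt and max_new_tokens are illustrative, and a CUDA GPU is assumed):

# Run a quick generation with the 4-bit model using float16 compute
inputs = tokenizer_large("Translate English to German: The house is wonderful.", return_tensors="pt").to(model_large.device)
outputs = model_large.generate(**inputs, max_new_tokens=32)
print(tokenizer_large.decode(outputs[0], skip_special_tokens=True))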
