Bitsandbytes documentation
Quantization primitives
Below you will find the docstrings of the quantization primitives exposed in bitsandbytes.
Linear4bit (QLoRA)
class bitsandbytes.nn.Linear4bit
( input_features, output_features, bias = True, compute_dtype = None, compress_statistics = True, quant_type = 'fp4', quant_storage = torch.uint8, device = None )
This class is the base module for the 4-bit quantization algorithm presented in QLoRA. QLoRA 4-bit linear layers use blockwise k-bit quantization under the hood, with a choice of 4-bit data types such as FP4 and NF4.
To quantize a linear layer, first load the original fp16 / bf16 weights into the Linear4bit module, then call quantized_module.to("cuda") to quantize them.
Example:
import torch
import torch.nn as nn

from bitsandbytes.nn import Linear4bit

fp16_model = nn.Sequential(
    nn.Linear(64, 64),
    nn.Linear(64, 64)
)

# Linear4bit defaults to the FP4 quantization data type (quant_type='fp4')
quantized_model = nn.Sequential(
    Linear4bit(64, 64),
    Linear4bit(64, 64)
)

quantized_model.load_state_dict(fp16_model.state_dict())
quantized_model = quantized_model.to(0)  # Quantization happens here
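The constructor arguments in the signature above control how the quantization behaves. Below is a minimal sketch, using the same toy 64-dimensional layers, of selecting the NF4 data type, a bfloat16 compute dtype, and double quantization of the blockwise statistics; the variable name nf4_model is illustrative:

import torch
import torch.nn as nn

from bitsandbytes.nn import Linear4bit

# NF4 quantization data type, bfloat16 matmul dtype, and double quantization
# of the blockwise statistics via compress_statistics=True
nf4_model = nn.Sequential(
    Linear4bit(64, 64, quant_type="nf4", compute_dtype=torch.bfloat16, compress_statistics=True),
    Linear4bit(64, 64, quant_type="nf4", compute_dtype=torch.bfloat16, compress_statistics=True)
)
nf4_model = nf4_model.to("cuda")  # quantization happens on the device transfer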
__init__
( input_features, output_features, bias = True, compute_dtype = None, compress_statistics = True, quant_type = 'fp4', quant_storage = torch.uint8, device = None )
Initialize Linear4bit class.
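After the device transfer, the module behaves like a regular nn.Linear for inference. A minimal sketch, continuing from the quantized_model built in the example above (the input shape is illustrative):

import torch

x = torch.randn(1, 64, dtype=torch.float16, device="cuda")
with torch.no_grad():
    out = quantized_model(x)  # weights are dequantized blockwise inside the matmul
print(out.shape)  # torch.Size([1, 64])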
Linear8bitLt
class bitsandbytes.nn.Linear8bitLt
( input_features, output_features, bias = True, has_fp16_weights = True, memory_efficient_backward = False, threshold = 0.0, index = None, device = None )
This class is the base module for the LLM.int8() algorithm. To read more about it, have a look at the LLM.int8() paper (Dettmers et al., 2022).
To quantize a linear layer, first load the original fp16 / bf16 weights into the Linear8bitLt module, then call int8_module.to("cuda") to quantize them.
Example:
import torch
import torch.nn as nn

from bitsandbytes.nn import Linear8bitLt

fp16_model = nn.Sequential(
    nn.Linear(64, 64),
    nn.Linear(64, 64)
)

# has_fp16_weights=False keeps only the quantized int8 weights (inference mode)
int8_model = nn.Sequential(
    Linear8bitLt(64, 64, has_fp16_weights=False),
    Linear8bitLt(64, 64, has_fp16_weights=False)
)

int8_model.load_state_dict(fp16_model.state_dict())
int8_model = int8_model.to(0)  # Quantization happens here
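The threshold argument enables the mixed-precision decomposition from the LLM.int8() paper: activation outliers with a magnitude above the threshold are computed in fp16, while the remaining dimensions go through int8 matrix multiplication. A minimal sketch, assuming the paper's suggested value of 6.0 and the same toy layer sizes (the variable name outlier_model is illustrative):

import torch.nn as nn

from bitsandbytes.nn import Linear8bitLt

# Activation outliers above magnitude 6.0 stay in fp16;
# everything else runs through int8 matrix multiplication.
outlier_model = nn.Sequential(
    Linear8bitLt(64, 64, has_fp16_weights=False, threshold=6.0),
    Linear8bitLt(64, 64, has_fp16_weights=False, threshold=6.0)
)
outlier_model = outlier_model.to("cuda")  # quantization happens here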
__init__
( input_features, output_features, bias = True, has_fp16_weights = True, memory_efficient_backward = False, threshold = 0.0, index = None, device = None )
Initialize Linear8bitLt class.