Bitsandbytes documentation
Overview
The bitsandbytes.functional API provides the low-level building blocks for the library’s features.
When to Use bitsandbytes.functional
- When you need direct control over quantized operations and their parameters.
- To build custom layers or operations leveraging low-bit arithmetic.
- To integrate with other ecosystem tooling.
- For experimental or research purposes requiring non-standard quantization or performance optimizations.
LLM.int8()
bitsandbytes.functional.int8_linear_matmul
(A: Tensor, B: Tensor, out: typing.Optional[torch.Tensor] = None, dtype = torch.int32) → torch.Tensor
Parameters
- A (torch.Tensor): The first matrix operand with the data type torch.int8.
- B (torch.Tensor): The second matrix operand with the data type torch.int8.
- out (torch.Tensor, optional): A pre-allocated tensor used to store the result.
- dtype (torch.dtype, optional): The expected data type of the output. Defaults to torch.int32.
Returns
torch.Tensor
The result of the operation.
Raises
NotImplementedError or RuntimeError
- NotImplementedError: The operation is not supported in the current environment.
- RuntimeError: Raised when the operation cannot be completed for any other reason.
Performs an 8-bit integer matrix multiplication.
A linear transformation is applied such that out = A @ B.T. When possible, integer tensor core hardware is
utilized to accelerate the operation.
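The arithmetic behind int8_linear_matmul can be illustrated with a minimal pure-Python sketch. It assumes A and B are nested lists of int8 values with shapes (m, k) and (n, k), and uses Python integers to stand in for the int32 accumulator. The name int8_linear_matmul_ref is hypothetical; this shows only the math, not the library's tensor-core kernel.

```python
from typing import List

Matrix = List[List[int]]


def int8_linear_matmul_ref(A: Matrix, B: Matrix) -> Matrix:
    """Hypothetical reference for out = A @ B.T on int8 operands.

    A has shape (m, k) and B has shape (n, k); the result has shape (m, n).
    Python ints stand in for the int32 accumulator.
    """
    for row in A + B:
        if len(row) != len(A[0]):
            raise ValueError("A and B must share the inner dimension k")
        if any(not -128 <= v <= 127 for v in row):
            raise ValueError("operands must hold int8 values in [-128, 127]")
    # Each output element is the dot product of a row of A with a row of B,
    # which is what multiplying by the transpose of B computes.
    return [[sum(a * b for a, b in zip(row_a, row_b)) for row_b in B] for row_a in A]


out = int8_linear_matmul_ref([[1, 2, 3], [-4, 5, 6]], [[1, 0, -1], [2, 2, 2]])
print(out)  # [[-2, 12], [-10, 14]]
```

One consequence of the A @ B.T convention is that a weight matrix stored as (out_features, in_features), as in torch.nn.Linear, can be passed as B without transposing it first.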
bitsandbytes.functional.int8_mm_dequant
(A: Tensor, row_stats: Tensor, col_stats: Tensor, out: typing.Optional[torch.Tensor] = None, bias: typing.Optional[torch.Tensor] = None) → torch.Tensor
Parameters
- A (torch.Tensor with dtype torch.int32): The result of a quantized int8 matrix multiplication.
- row_stats (torch.Tensor): The row-wise quantization statistics for the lhs operand of the matrix multiplication.
- col_stats (torch.Tensor): The column-wise quantization statistics for the rhs operand of the matrix multiplication.
- out (torch.Tensor, optional): A pre-allocated tensor to store the output of the operation.
- bias (torch.Tensor, optional): An optional bias vector to add to the result.
Returns
torch.Tensor
The dequantized result with an optional bias, with dtype torch.float16.
Performs dequantization on the result of a quantized int8 matrix multiplication.
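The dequantization step can be sketched in pure Python under an assumption this page does not state outright: that row_stats holds the absolute maximum of each row of the lhs operand, col_stats holds the absolute maximum of each row of the rhs operand (those rows become output columns because the product is A @ B.T), and both operands were scaled to [-127, 127] before multiplying. Each int32 entry is then rescaled by row_stats[i] * col_stats[j] / (127 * 127). The function name is hypothetical, and Python floats stand in for the float16 output.

```python
from typing import List, Optional, Sequence


def int8_mm_dequant_ref(
    acc: Sequence[Sequence[int]],
    row_stats: Sequence[float],
    col_stats: Sequence[float],
    bias: Optional[Sequence[float]] = None,
) -> List[List[float]]:
    """Hypothetical reference: rescale an int32 matmul result back to floats.

    Assumes both operands were quantized with per-row absmax scaling to
    [-127, 127], so every product carries a factor of 127 * 127.
    """
    inv_scale = 1.0 / (127.0 * 127.0)
    out = []
    for i, row in enumerate(acc):
        out_row = []
        for j, value in enumerate(row):
            x = value * row_stats[i] * col_stats[j] * inv_scale
            if bias is not None:
                x += bias[j]  # one bias entry per output column, shared by all rows
            out_row.append(x)
        out.append(out_row)
    return out


# Approximately [[7.0, -2.0]], up to floating-point rounding.
print(int8_mm_dequant_ref([[16129, -16129]], [2.0], [3.0, 1.0], bias=[1.0, 0.0]))
```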
bitsandbytes.functional.int8_vectorwise_dequant
(A: Tensor, stats: Tensor) → torch.Tensor with dtype torch.float32
Dequantizes a tensor with dtype torch.int8 to torch.float32.
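A matching sketch for int8_vectorwise_dequant follows. It assumes, as in the LLM.int8() scheme but not stated on this page, that stats[i] is the absolute maximum of row i before quantization and that each row was scaled to [-127, 127]. The function name is hypothetical.

```python
from typing import List, Sequence


def int8_vectorwise_dequant_ref(
    A: Sequence[Sequence[int]], stats: Sequence[float]
) -> List[List[float]]:
    """Hypothetical reference: map int8 rows back to float values.

    Row i is assumed to have been quantized as round(x * 127 / stats[i]),
    so multiplying by stats[i] / 127 undoes the scaling.
    """
    if len(A) != len(stats):
        raise ValueError("expected one statistic per row")
    return [[q * s / 127.0 for q in row] for row, s in zip(A, stats)]


print(int8_vectorwise_dequant_ref([[127, -127, 0]], [2.0]))  # [[2.0, -2.0, 0.0]]
```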
bitsandbytes.functional.int8_vectorwise_quant
(A: Tensor, threshold = 0.0) → Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]
Parameters
- A (torch.Tensor with dtype torch.float16): The input tensor.
- threshold (float, optional): An optional threshold for sparse decomposition of outlier features. No outliers are held back when 0.0. Defaults to 0.0.
Returns
Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]
A tuple containing the quantized tensor and relevant statistics.
- torch.Tensor with dtype torch.int8: The quantized data.
- torch.Tensor with dtype torch.float32: The quantization scales.
- torch.Tensor with dtype torch.int32, optional: A list of column indices which contain outlier features.
Quantizes a tensor with dtype torch.float16 to torch.int8 in accordance with the LLM.int8() algorithm.
For more information, see the LLM.int8() paper.
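The quantization side can be sketched for the default threshold=0.0 case, where no outlier columns are held back and the third return value is None. The sketch assumes per-row absmax scaling to [-127, 127] with round-half-to-even (Python's round), and treats an all-zero row as having a scale of 0. The function name is hypothetical, and outlier handling for threshold > 0 is deliberately left out.

```python
from typing import List, Optional, Sequence, Tuple


def int8_vectorwise_quant_ref(
    A: Sequence[Sequence[float]],
) -> Tuple[List[List[int]], List[float], Optional[List[int]]]:
    """Hypothetical reference for row-wise absmax int8 quantization (threshold=0.0)."""
    quantized: List[List[int]] = []
    stats: List[float] = []
    for row in A:
        absmax = max(abs(v) for v in row)
        stats.append(absmax)
        # Map the largest magnitude in the row to 127; guard against all-zero rows.
        scale = 127.0 / absmax if absmax > 0 else 0.0
        quantized.append([int(round(v * scale)) for v in row])
    # No outlier columns are reported when threshold is 0.0.
    return quantized, stats, None


q, stats, outlier_cols = int8_vectorwise_quant_ref([[1.0, -2.0, 0.5]])
print(q, stats, outlier_cols)  # [[64, -127, 32]] [2.0] None
```

Under these assumptions, multiplying q[i][j] by stats[i] / 127 recovers each input to within half a quantization step, that is within stats[i] / 254.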
4-bit
bitsandbytes.functional.dequantize_4bit
(A: Tensor, quant_state: typing.Optional[bitsandbytes.functional.QuantState] = None, absmax: typing.Optional[torch.Tensor] = None, out: typing.Optional[torch.Tensor] = None, blocksize: typing.Optional[int] = None, quant_type = 'fp4') → torch.Tensor
Parameters
- A (torch.Tensor): The quantized input tensor.
- quant_state (QuantState, optional): The quantization state as returned by quantize_4bit. Required if absmax is not provided.
- absmax (torch.Tensor, optional): A tensor containing the scaling values. Required if quant_state is not provided and ignored otherwise.
- out (torch.Tensor, optional): A tensor to use to store the result.
- blocksize (int, optional): The size of the blocks. Defaults to 128 on ROCm and 64 otherwise. Valid values are 64, 128, 256, 512, 1024, 2048, and 4096.
- quant_type (str, optional): The data type to use: nf4 or fp4. Defaults to fp4.
Returns
torch.Tensor
The dequantized tensor.
Raises
ValueError
- ValueError: Raised when the input data type or blocksize is not supported.
Dequantizes a packed 4-bit quantized tensor.
The input tensor is dequantized by dividing it into blocks of blocksize values.
The absolute maximum value within each block is used to scale
the non-linear dequantization.
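The blockwise 4-bit dequantization described above can be sketched in pure Python. Several assumptions apply: the 16-entry code table is a hypothetical evenly spaced stand-in (not the actual NF4 or FP4 tables), two 4-bit indices are packed per byte with the first value in the high nibble, and the function name is hypothetical.

```python
from typing import List, Sequence

# Hypothetical 16-entry code: evenly spaced values in [-1, 1].
# This is NOT the NF4 or FP4 table used by bitsandbytes.
LINEAR_CODE_4BIT = [-1.0 + 2.0 * i / 15 for i in range(16)]


def dequantize_4bit_ref(
    packed: Sequence[int],
    absmax: Sequence[float],
    blocksize: int,
    numel: int,
    code: Sequence[float] = LINEAR_CODE_4BIT,
) -> List[float]:
    """Hypothetical reference: unpack 4-bit indices and rescale them per block."""
    indices: List[int] = []
    for byte in packed:
        indices.append(byte >> 4)    # assumed: first value in the high nibble
        indices.append(byte & 0x0F)  # second value in the low nibble
    indices = indices[:numel]  # drop the padding nibble when numel is odd
    # Look up each index in the code table, then scale by its block's absmax.
    return [code[idx] * absmax[i // blocksize] for i, idx in enumerate(indices)]


print(dequantize_4bit_ref([0x0F, 0xF0], absmax=[2.0, 5.0], blocksize=2, numel=4))
# [-2.0, 2.0, 5.0, -5.0]
```

The tiny blocksize of 2 is only for illustration; the library accepts the block sizes listed above.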
bitsandbytes.functional.dequantize_fp4
(A: Tensor, quant_state: typing.Optional[bitsandbytes.functional.QuantState] = None, absmax: typing.Optional[torch.Tensor] = None, out: typing.Optional[torch.Tensor] = None, blocksize: typing.Optional[int] = None)
bitsandbytes.functional.dequantize_nf4
(A: Tensor, quant_state: typing.Optional[bitsandbytes.functional.QuantState] = None, absmax: typing.Optional[torch.Tensor] = None, out: typing.Optional[torch.Tensor] = None, blocksize: typing.Optional[int] = None)
bitsandbytes.functional.gemv_4bit
(A: Tensor, B: Tensor, out: typing.Optional[torch.Tensor] = None, transposed_A = False, transposed_B = False, state = None)
bitsandbytes.functional.quantize_4bit
(A: Tensor, absmax: typing.Optional[torch.Tensor] = None, out: typing.Optional[torch.Tensor] = None, blocksize = None, compress_statistics = False, quant_type = 'fp4', quant_storage = torch.uint8) → Tuple[torch.Tensor, QuantState]
Parameters
- A (torch.Tensor): The input tensor. Supports float16, bfloat16, or float32 data types.
- absmax (torch.Tensor, optional): A tensor to use to store the absmax values.
- out (torch.Tensor, optional): A tensor to use to store the result.
- blocksize (int, optional): The size of the blocks. Defaults to 128 on ROCm and 64 otherwise. Valid values are 64, 128, 256, 512, 1024, 2048, and 4096.
- compress_statistics (bool, optional): Whether to additionally quantize the absmax values. Defaults to False.
- quant_type (str, optional): The data type to use: nf4 or fp4. Defaults to fp4.
- quant_storage (torch.dtype, optional): The dtype of the tensor used to store the result. Defaults to torch.uint8.
Returns
Tuple[torch.Tensor, QuantState]
A tuple containing the quantization results.
- torch.Tensor: The quantized tensor with packed 4-bit values.
- QuantState: The state object used to undo the quantization.
Raises
ValueError
- ValueError: Raised when the input data type is not supported.
Quantizes tensor A in blocks of 4-bit values.
Tensor A is divided into blocks of blocksize values, each of which is quantized independently.
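The quantization direction can be sketched the same way: compute one absmax per block, normalize each value into [-1, 1], snap it to the nearest code entry, and pack two 4-bit indices per byte. As in any sketch here, the evenly spaced 16-entry code is a hypothetical stand-in for the NF4 and FP4 tables, the high-nibble-first packing order is an assumption, and the function names are hypothetical.

```python
from typing import List, Sequence, Tuple

# Hypothetical 16-entry code: evenly spaced values in [-1, 1].
# This is NOT the NF4 or FP4 table used by bitsandbytes.
LINEAR_CODE_4BIT = [-1.0 + 2.0 * i / 15 for i in range(16)]


def nearest_index(code: Sequence[float], x: float) -> int:
    """Return the index of the code entry closest to x."""
    return min(range(len(code)), key=lambda i: abs(code[i] - x))


def quantize_4bit_ref(
    values: Sequence[float],
    blocksize: int,
    code: Sequence[float] = LINEAR_CODE_4BIT,
) -> Tuple[List[int], List[float]]:
    """Hypothetical reference: blockwise absmax 4-bit quantization with packing."""
    absmax: List[float] = []
    indices: List[int] = []
    for start in range(0, len(values), blocksize):
        block = values[start:start + blocksize]
        m = max(abs(v) for v in block)
        absmax.append(m)  # one scale per block, kept for dequantization
        for v in block:
            # Normalize into [-1, 1], then snap to the nearest code entry.
            indices.append(nearest_index(code, v / m if m > 0 else 0.0))
    if len(indices) % 2:
        indices.append(0)  # pad so every byte holds two nibbles
    # Assumed layout: first value in the high nibble, second in the low nibble.
    packed = [(indices[i] << 4) | indices[i + 1] for i in range(0, len(indices), 2)]
    return packed, absmax


print(quantize_4bit_ref([-2.0, 2.0, 5.0, -5.0], blocksize=2))
# ([15, 240], [2.0, 5.0])
```

Keeping a separate absmax per block limits the damage a single large value can do: it only coarsens the quantization grid of its own block.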
bitsandbytes.functional.quantize_fp4
(A: Tensor, absmax: typing.Optional[torch.Tensor] = None, out: typing.Optional[torch.Tensor] = None, blocksize = None, compress_statistics = False, quant_storage = torch.uint8)
bitsandbytes.functional.quantize_nf4
(A: Tensor, absmax: typing.Optional[torch.Tensor] = None, out: typing.Optional[torch.Tensor] = None, blocksize = None, compress_statistics = False, quant_storage = torch.uint8)
class bitsandbytes.functional.QuantState
(absmax, shape = None, code = None, blocksize = None, quant_type = None, dtype = None, offset = None, state2 = None)
Container for quantization state components, for use with Params4bit and similar classes.
Returns a dict of tensors and strings to use in serialization via _save_to_state_dict(). packed: when set, returns a dict[str, torch.Tensor] for a state_dict suitable for saving with safetensors.
Unpacks the components of a state_dict into a QuantState, converting them where necessary into strings, torch.dtype, ints, etc.
qs_dict: based on the state_dict, containing only the relevant keys, stripped of prefixes.
The item with key quant_state.bitsandbytes__[nf4/fp4] may contain minor, non-tensor quant state items.
Dynamic 8-bit Quantization
Primitives used in the 8-bit optimizer quantization.
For more details, see 8-Bit Approximations for Parallelism in Deep Learning.
bitsandbytes.functional.dequantize_blockwise
(A: Tensor, quant_state: typing.Optional[bitsandbytes.functional.QuantState] = None, absmax: typing.Optional[torch.Tensor] = None, code: typing.Optional[torch.Tensor] = None, out: typing.Optional[torch.Tensor] = None, blocksize: int = 4096, nested = False) → torch.Tensor
Parameters
- A (torch.Tensor): The quantized input tensor.
- quant_state (QuantState, optional): The quantization state as returned by quantize_blockwise. Required if absmax is not provided.
- absmax (torch.Tensor, optional): A tensor containing the scaling values. Required if quant_state is not provided and ignored otherwise.
- code (torch.Tensor, optional): A mapping describing the low-bit data type. Defaults to a signed 8-bit dynamic type. For more details, see 8-Bit Approximations for Parallelism in Deep Learning (https://arxiv.org/abs/1511.04561). Ignored when quant_state is provided.
- out (torch.Tensor, optional): A tensor to use to store the result.
- blocksize (int, optional): The size of the blocks. Defaults to 4096. Valid values are 64, 128, 256, 512, 1024, 2048, and 4096. Ignored when quant_state is provided.
Returns
torch.Tensor
The dequantized tensor. The datatype is indicated by quant_state.dtype and defaults to torch.float32.
Raises
ValueError
- ValueError: Raised when the input data type is not supported.
Dequantize a tensor in blocks of values.
The input tensor is dequantized by dividing it into blocks of blocksize values.
The absolute maximum value within each block is used to scale
the non-linear dequantization.
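Blockwise 8-bit dequantization reduces to a table lookup followed by a per-block rescale, which a short pure-Python sketch makes concrete. The code table here is a hypothetical evenly spaced stand-in for the default signed dynamic type, which is non-linear, and the function name is hypothetical.

```python
from typing import List, Sequence

# Hypothetical stand-in for the default dynamic code: 255 evenly spaced values
# in [-1, 1]. The real default is a non-linear signed 8-bit dynamic type.
LINEAR_CODE_8BIT = [-1.0 + 2.0 * i / 254 for i in range(255)]


def dequantize_blockwise_ref(
    A: Sequence[int],
    absmax: Sequence[float],
    code: Sequence[float] = LINEAR_CODE_8BIT,
    blocksize: int = 4096,
) -> List[float]:
    """Hypothetical reference: out[i] = code[A[i]] * absmax[i // blocksize]."""
    # Each stored byte is an index into the code table; the block's absmax
    # restores the original magnitude.
    return [code[idx] * absmax[i // blocksize] for i, idx in enumerate(A)]


print(dequantize_blockwise_ref([0, 127, 254], absmax=[4.0]))  # [-4.0, 0.0, 4.0]
```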
bitsandbytes.functional.quantize_blockwise
(A: Tensor, code: typing.Optional[torch.Tensor] = None, absmax: typing.Optional[torch.Tensor] = None, out: typing.Optional[torch.Tensor] = None, blocksize = 4096, nested = False) → Tuple[torch.Tensor, QuantState]
Parameters
- A (torch.Tensor): The input tensor. Supports float16, bfloat16, or float32 data types.
- code (torch.Tensor, optional): A mapping describing the low-bit data type. Defaults to a signed 8-bit dynamic type. For more details, see 8-Bit Approximations for Parallelism in Deep Learning (https://arxiv.org/abs/1511.04561).
- absmax (torch.Tensor, optional): A tensor to use to store the absmax values.
- out (torch.Tensor, optional): A tensor to use to store the result.
- blocksize (int, optional): The size of the blocks. Defaults to 4096. Valid values are 64, 128, 256, 512, 1024, 2048, and 4096.
- nested (bool, optional): Whether to additionally quantize the absmax values. Defaults to False.
Returns
Tuple[torch.Tensor, QuantState]
A tuple containing the quantization results.
- torch.Tensor: The quantized tensor.
- QuantState: The state object used to undo the quantization.
Raises
ValueError
- ValueError: Raised when the input data type is not supported.
Quantize a tensor in blocks of values.
The input tensor is quantized by dividing it into blocks of blocksize values.
The absolute maximum value within each block is calculated and used to scale
the non-linear quantization.
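The quantization direction can be sketched as: find each block's absolute maximum, normalize the block into [-1, 1], and store the index of the nearest code entry. The evenly spaced code is a hypothetical stand-in for the default non-linear dynamic type, and the function names are hypothetical; this illustrates the technique, not the library's CUDA kernel.

```python
from typing import List, Sequence, Tuple

# Hypothetical stand-in for the default dynamic code: 255 evenly spaced values
# in [-1, 1]. The real default is a non-linear signed 8-bit dynamic type.
LINEAR_CODE_8BIT = [-1.0 + 2.0 * i / 254 for i in range(255)]


def nearest_index(code: Sequence[float], x: float) -> int:
    """Return the index of the code entry closest to x."""
    return min(range(len(code)), key=lambda i: abs(code[i] - x))


def quantize_blockwise_ref(
    values: Sequence[float],
    code: Sequence[float] = LINEAR_CODE_8BIT,
    blocksize: int = 4096,
) -> Tuple[List[int], List[float]]:
    """Hypothetical reference: blockwise absmax quantization to code indices."""
    absmax: List[float] = []
    out: List[int] = []
    for start in range(0, len(values), blocksize):
        block = values[start:start + blocksize]
        m = max(abs(v) for v in block)
        absmax.append(m)  # one scaling value per block
        for v in block:
            # Normalize into [-1, 1]; an all-zero block maps every value to 0.0.
            out.append(nearest_index(code, v / m if m > 0 else 0.0))
    return out, absmax


print(quantize_blockwise_ref([-4.0, 0.0, 4.0]))  # ([0, 127, 254], [4.0])
```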
Utility
bitsandbytes.functional.get_ptr
(A: typing.Optional[torch.Tensor]) → Optional[ct.c_void_p]
Gets the memory address of the first element of a tenso