# Copyright © 2023 Apple Inc.

import math

import mlx.core as mx
from mlx.nn.layers.base import Module
from mlx.nn.layers.linear import Linear
from mlx.utils import tree_flatten, tree_map


class QuantizedLinear(Module):
"""Applies an affine transformation to the input using a quantized weight matrix.
It is the quantized equivalent of :class:`mlx.nn.Linear`. For now its
parameters are frozen and will not be included in any gradient computation
but this will probably change in the future.
QuantizedLinear also provides two useful classmethods to convert linear
layers to QuantizedLinear layers.
- :meth:`from_linear` returns a QuantizedLinear layer that applies the same
linear transformation up to the quantization error.
- :meth:`quantize_module` swaps all the linear layers of the passed module
with QuantizedLinear ones.
Args:
input_dims (int): The dimensionality of the input features
output_dims (int): The dimensionality of the output features
bias (bool, optional): If set to ``False`` then the layer will not use
a bias. (default: True).
group_size (int, optional): The group size to use for the quantized
weight. See :func:`~mlx.core.quantize`. (default: 64)
bits (int, optional): The bit width to use for the quantized weight.
See :func:`~mlx.core.quantize`. (default: 4)
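
    Example:
        An illustrative sketch (the shapes, group size and bit width below are
        arbitrary choices, not requirements of the layer)::

            layer = QuantizedLinear(256, 512, bias=True, group_size=64, bits=4)
            x = mx.random.normal((2, 256))
            y = layer(x)  # y.shape == (2, 512)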
"""
def __init__(
self,
input_dims: int,
output_dims: int,
bias: bool = True,
group_size: int = 64,
bits: int = 4,
):
super().__init__()
# Quantization config
self.group_size = group_size
self.bits = bits
# Initialize the quantized weight
scale = math.sqrt(1 / input_dims)
weight = mx.random.uniform(
low=-scale,
high=scale,
shape=(output_dims, input_dims),
)
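        # Note: mx.quantize packs the elements into 32-bit words (32 // bits
        # values per word) and returns per-group scales and biases that are
        # needed to dequantize the weight.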
self.weight, self.scales, self.biases = mx.quantize(weight, group_size, bits)
# And bias if needed
if bias:
self.bias = mx.zeros((output_dims,))
# Freeze this model's parameters
        self.freeze()

    def unfreeze(self, *args, **kwargs):
"""Wrap unfreeze so that we unfreeze any layers we might contain but
our parameters will remain frozen."""
super().unfreeze(*args, **kwargs)
        self.freeze(recurse=False)

    def _extra_repr(self):
        out_dims, in_dims = self.weight.shape
        # The stored weight is packed, so each element holds 32 // bits
        # quantized values; recover the logical input dimensionality.
        in_dims *= 32 // self.bits
        return (
            f"input_dims={in_dims}, output_dims={out_dims}, bias={'bias' in self}, "
            f"group_size={self.group_size}, bits={self.bits}"
        )

    def __call__(self, x):
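        # Matrix multiply with the packed weight, dequantizing on the fly;
        # transpose=True computes x @ W.T since the weight is stored with
        # shape (output_dims, input_dims).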
x = mx.quantized_matmul(
x,
self.weight,
scales=self.scales,
biases=self.biases,
transpose=True,
group_size=self.group_size,
bits=self.bits,
)
if "bias" in self:
x = x + self.bias
        return x

    @classmethod
def from_linear(cls, linear_layer: Module, group_size: int = 64, bits: int = 4):
"""Create a QuantizedLinear layer from the parameters of a provided
linear layer."""
output_dims, input_dims = linear_layer.weight.shape
ql = cls(input_dims, output_dims, False, group_size, bits)
ql.weight, ql.scales, ql.biases = mx.quantize(
linear_layer.weight, group_size, bits
)
if "bias" in linear_layer:
ql.bias = linear_layer.bias
return ql
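
    # Illustrative sketch (comment only, not part of the API): quantizing a
    # single existing layer; the dimensions below are arbitrary.
    #
    #     lin = Linear(256, 512)
    #     qlin = QuantizedLinear.from_linear(lin)
    #     # qlin(x) matches lin(x) up to quantization error.
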
    @classmethod
    def quantize_module(
        cls,
        model: Module,
        group_size: int = 64,
        bits: int = 4,
        linear_class_predicate=lambda m: isinstance(m, Linear),
    ):
        """Swap, in place, every leaf module of ``model`` that matches
        ``linear_class_predicate`` (by default any :class:`Linear`) with an
        equivalent :class:`QuantizedLinear` layer."""

        def _quantize_if_linear(m):
            if linear_class_predicate(m):
                return cls.from_linear(m, group_size, bits)
            else:
                return m

        leaves = model.leaf_modules()
        leaves = tree_map(_quantize_if_linear, leaves, is_leaf=Module.is_module)
        model.update_modules(leaves)
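

# Illustrative usage sketch (`MyModel` is a hypothetical Module containing
# regular Linear layers; the group size and bit width shown are the defaults):
#
#     model = MyModel()
#     QuantizedLinear.quantize_module(model, group_size=64, bits=4)
#     # Every leaf matching `linear_class_predicate` (any Linear by default)
#     # has now been replaced in place by a QuantizedLinear layer.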