import math

import mlx.core as mx
from mlx.nn.layers.base import Module
from mlx.nn.layers.linear import Linear
from mlx.utils import tree_map


class QuantizedLinear(Module):
    """Applies an affine transformation to the input using a quantized weight matrix.

    It is the quantized equivalent of :class:`mlx.nn.Linear`. For now its
    parameters are frozen and will not be included in any gradient computation,
    but this will probably change in the future.

    QuantizedLinear also provides two useful classmethods to convert linear
    layers to QuantizedLinear layers:

    - :meth:`from_linear` returns a QuantizedLinear layer that applies the same
      linear transformation up to the quantization error.
    - :meth:`quantize_module` swaps all the linear layers of the passed module
      with QuantizedLinear ones.

    Args:
        input_dims (int): The dimensionality of the input features.
        output_dims (int): The dimensionality of the output features.
        bias (bool, optional): If set to ``False``, the layer will not use a
            bias. (default: ``True``)
        group_size (int, optional): The group size to use for the quantized
            weight. See :func:`~mlx.core.quantize`. (default: 64)
        bits (int, optional): The bit width to use for the quantized weight.
            See :func:`~mlx.core.quantize`. (default: 4)
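
    Example:
        A minimal sketch of converting an existing layer with
        :meth:`from_linear`; the sizes and data are illustrative only::

            import mlx.core as mx
            import mlx.nn as nn

            lin = nn.Linear(64, 64)

            # Applies the same transformation as ``lin`` up to the
            # quantization error
            qlin = nn.QuantizedLinear.from_linear(lin)

            x = mx.random.uniform(shape=(2, 64))
            y = qlin(x)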
""" |
|
|
|
def __init__( |
|
self, |
|
input_dims: int, |
|
output_dims: int, |
|
bias: bool = True, |
|
group_size: int = 64, |
|
bits: int = 4, |
|
): |
|
super().__init__() |
|
|
|
|
|
self.group_size = group_size |
|
self.bits = bits |
|
|
|
|
|
scale = math.sqrt(1 / input_dims) |
|
weight = mx.random.uniform( |
|
low=-scale, |
|
high=scale, |
|
shape=(output_dims, input_dims), |
|
) |
|
self.weight, self.scales, self.biases = mx.quantize(weight, group_size, bits) |
|
|
|
|
|
if bias: |
|
self.bias = mx.zeros((output_dims,)) |
|
|
|
|
|
self.freeze() |
|
|

    def unfreeze(self, *args, **kwargs):
        """Wrap unfreeze so that we unfreeze any layers this module might
        contain, while its own parameters remain frozen."""
        super().unfreeze(*args, **kwargs)
        self.freeze(recurse=False)

    def _extra_repr(self):
        out_dims, in_dims = self.weight.shape
        # The packed weight stores 32 // bits values per uint32 entry, so
        # recover the original input dimension for the repr
        in_dims *= 32 // self.bits
        return (
            f"input_dims={in_dims}, output_dims={out_dims}, bias={'bias' in self}, "
            f"group_size={self.group_size}, bits={self.bits}"
        )

    def __call__(self, x):
        # Multiply directly against the packed weight; transpose=True computes
        # x @ weight.T to match the convention of mlx.nn.Linear
        x = mx.quantized_matmul(
            x,
            self.weight,
            scales=self.scales,
            biases=self.biases,
            transpose=True,
            group_size=self.group_size,
            bits=self.bits,
        )
        if "bias" in self:
            x = x + self.bias
        return x

    @classmethod
    def from_linear(cls, linear_layer: Module, group_size: int = 64, bits: int = 4):
        """Create a QuantizedLinear layer from the parameters of a provided
        linear layer."""
        output_dims, input_dims = linear_layer.weight.shape
        ql = cls(input_dims, output_dims, False, group_size, bits)
        ql.weight, ql.scales, ql.biases = mx.quantize(
            linear_layer.weight, group_size, bits
        )
        if "bias" in linear_layer:
            ql.bias = linear_layer.bias

        return ql

    @classmethod
    def quantize_module(
        cls,
        model: Module,
        group_size: int = 64,
        bits: int = 4,
        linear_class_predicate=lambda m: isinstance(m, Linear),
    ):
        """Replace, in place, every layer of ``model`` for which
        ``linear_class_predicate`` returns ``True`` with an equivalent
        QuantizedLinear layer."""

        def _quantize_if_linear(m):
            if linear_class_predicate(m):
                return cls.from_linear(m, group_size, bits)
            else:
                return m

        leaves = model.leaf_modules()
        leaves = tree_map(_quantize_if_linear, leaves, is_leaf=Module.is_module)
        model.update_modules(leaves)
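

# A minimal usage sketch of ``quantize_module``. The toy ``MLP`` below and its
# sizes are illustrative assumptions, not part of the layer's API.
if __name__ == "__main__":

    class MLP(Module):
        def __init__(self):
            super().__init__()
            self.l1 = Linear(64, 128)
            self.l2 = Linear(128, 64)

        def __call__(self, x):
            return self.l2(mx.maximum(self.l1(x), 0))

    mlp = MLP()

    # Swap both Linear layers for QuantizedLinear ones in place; group_size
    # must divide each layer's input dimension (64 and 128 both work with 64)
    QuantizedLinear.quantize_module(mlp, group_size=64, bits=4)

    x = mx.random.uniform(shape=(2, 64))
    print(mlp(x).shape)  # (2, 64)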