bitLinear-phi-1.5 / bitlinear.py
Mrw33554432's picture
Update bitlinear.py
9dd1412 verified
raw
history blame
1.03 kB
import torch.nn.functional as F
from torch import Tensor, nn
def weight_quant(w):
"""
from https://github.com/microsoft/unilm/blob/master/bitnet/The-Era-of-1-bit-LLMs__Training_Tips_Code_FAQ.pdf,
This is a little bit different from paper by adding '/ scale' in the end, as released by the paper author.
which is super crucial for training (7.5 loss vs 2.5).
"""
scale = 1.0 / w.abs().mean().clamp_(min=1e-5)
u = (w * scale).round().clamp_(-1, 1) / scale
return u
class BitLinear(nn.Linear):
"""
A modified version of bit linear, only apply bit quant to weight.
"""
def forward(self, x: Tensor) -> Tensor:
"""
Forward pass of the BitLinear layer, applying quantization to weights.
Args:
x (Tensor): The input tensor.
Returns:
Tensor: The output tensor.
"""
w = self.weight
w_quant = w + (weight_quant(w) - w).detach() # Apply quantization adjustments
return F.linear(x, w_quant, self.bias)