File size: 1,031 Bytes
89f541f
 
 
 
 
 
 
9dd1412
 
89f541f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import torch.nn.functional as F
from torch import Tensor, nn


def weight_quant(w):
    """
    Ternarize a weight tensor with a single per-tensor absmean scale.

    Steps:
      1. Take the mean of ``|w|`` over the whole tensor, floored at 1e-5 so an
         all-zero tensor does not cause a division by zero.
      2. Multiply ``w`` by the inverse of that mean, round to the nearest
         integer, and clamp to ``[-1, 1]``. This gives values in {-1, 0, 1}.
      3. Divide by the inverse scale again, so every output element is one of
         ``{-mean, 0, +mean}``.

    Step 3 follows the reference code released by the paper authors
    (https://github.com/microsoft/unilm/blob/master/bitnet/The-Era-of-1-bit-LLMs__Training_Tips_Code_FAQ.pdf).
    It differs slightly from the paper's formula. The authors note it is
    crucial for training, reporting a loss of about 2.5 with it versus 7.5
    without it.

    Args:
        w: Weight tensor to quantize. It is not modified.

    Returns:
        A tensor with the same shape as ``w``, holding the rescaled ternary
        weights.
    """
    mean_abs = w.abs().mean().clamp(min=1e-5)
    inv_scale = 1.0 / mean_abs
    ternary = (w * inv_scale).round().clamp(-1, 1)
    return ternary / inv_scale


class BitLinear(nn.Linear):
    """
    ``nn.Linear`` variant that quantizes only its weight.

    Each forward pass uses the ``weight_quant`` version of ``self.weight``.
    Inputs and ``self.bias`` pass through unchanged. Construction and
    parameters are inherited from ``nn.Linear`` as-is.
    """

    def forward(self, x: Tensor) -> Tensor:
        """
        Apply the linear map using quantized weights.

        Args:
            x (Tensor): Input passed straight to ``F.linear``.

        Returns:
            Tensor: ``F.linear(x, quantized_weight, self.bias)``.
        """
        latent = self.weight
        # Straight-through estimator: the residual (quantized - latent) is
        # detached, so the forward pass sees the quantized weight. Gradients
        # reach ``self.weight`` as if the quantizer were the identity.
        residual = (weight_quant(latent) - latent).detach()
        effective_weight = latent + residual
        return F.linear(x, effective_weight, self.bias)