Mrw33554432 committed
Commit 89f541f
1 Parent(s): d013df6

Upload 2 files

Files changed (2)
  1. bitlinear.py +31 -0
  2. replace_hf.py +49 -0
bitlinear.py ADDED
@@ -0,0 +1,31 @@
+ import torch.nn.functional as F
+ from torch import Tensor, nn
+
+
+ def weight_quant(w):
+     """
+     From https://github.com/microsoft/unilm/blob/master/bitnet/The-Era-of-1-bit-LLMs__Training_Tips_Code_FAQ.pdf.
+     This differs slightly from the paper by adding '/ scale' at the end,
+     which is crucial for training (7.5 loss vs. 2.5).
+     """
+     scale = 1.0 / w.abs().mean().clamp_(min=1e-5)
+     u = (w * scale).round().clamp_(-1, 1) / scale
+     return u
+
+
+ class BitLinear(nn.Linear):
+     """
+     A modified version of BitLinear that applies bit quantization to the weights only.
+     """
+
+     def forward(self, x: Tensor) -> Tensor:
+         """
+         Forward pass of the BitLinear layer, applying quantization to the weights.
+         Args:
+             x (Tensor): The input tensor.
+         Returns:
+             Tensor: The output tensor.
+         """
+         w = self.weight
+         w_quant = w + (weight_quant(w) - w).detach()  # straight-through estimator: quantized forward, full-precision gradient
+         return F.linear(x, w_quant, self.bias)
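
For illustration, a minimal sanity-check sketch (not part of the commit), assuming bitlinear.py above is on the import path; the layer sizes and seed are arbitrary. It shows that the forward pass sees at most three distinct weight values per tensor ({-1, 0, +1} scaled back by mean(|w|)), while gradients still reach the full-precision weights through the straight-through estimator:

import torch
from bitlinear import BitLinear, weight_quant

torch.manual_seed(0)
layer = BitLinear(in_features=4, out_features=2, bias=True)

# Quantized weights collapse to at most three values: {-1, 0, +1} / scale.
with torch.no_grad():
    print(torch.unique(weight_quant(layer.weight)))

# (weight_quant(w) - w).detach() is a constant to autograd, so the backward
# pass behaves as if the full-precision w itself were used (straight-through).
loss = layer(torch.randn(3, 4)).sum()
loss.backward()
print(layer.weight.grad.shape)  # torch.Size([2, 4])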
replace_hf.py ADDED
@@ -0,0 +1,49 @@
+ import gc
+
+ import torch
+ from torch import nn
+
+ from bitlinear import BitLinear
+
+
+ # Adapted from https://github.com/kyegomez/BitNet/blob/main/bitnet/replace_hf.py
+ def replace_linear_in_hf(model, keep_param: bool):
+     """
+     Replaces all instances of nn.Linear in the given model with BitLinear, except lm_head.
+
+     Args:
+         model (nn.Module): The model to modify.
+         keep_param (bool): If True, the new layers keep the parameters of the
+             initial model; if False, they use randomly initialized weights
+             (for training).
+
+     Returns:
+         None
+     """
+     for name, module in model.named_children():
+         if isinstance(module, nn.Linear):
+             if 'head' in name:
+                 continue
+             # Create a new BitLinear layer with random parameters
+             bit_linear = BitLinear(
+                 in_features=module.in_features,
+                 out_features=module.out_features,
+                 bias=module.bias is not None,
+             )
+
+             if keep_param:
+                 # Transfer the weights and bias from the original nn.Linear to the new BitLinear
+                 bit_linear.weight.data.copy_(module.weight.data)
+                 if module.bias is not None:
+                     bit_linear.bias.data.copy_(module.bias.data)
+
+             del module
+
+             # Replace the nn.Linear with the new BitLinear
+             setattr(model, name, bit_linear)
+         else:
+             # Recursively apply to child modules
+             replace_linear_in_hf(module, keep_param)
+     gc.collect()
+     if torch.cuda.is_available():
+         torch.cuda.empty_cache()
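
For illustration, a minimal usage sketch (not part of the commit), assuming transformers is installed; facebook/opt-125m is only an example checkpoint whose projection layers are plain nn.Linear, not a model this commit prescribes:

from transformers import AutoModelForCausalLM

from replace_hf import replace_linear_in_hf

# Example checkpoint; any model built from nn.Linear layers works the same way.
model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")

# keep_param=True copies the pretrained weights into the new BitLinear layers;
# keep_param=False leaves them randomly initialized (for training from scratch).
replace_linear_in_hf(model, keep_param=True)
print(model)  # nn.Linear modules (except those named *head*) now show as BitLinear

Note that the 'head' in name check relies on Hugging Face naming conventions (e.g. lm_head), so the output projection keeps full precision.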