WCNegentropy committed on
Commit 86d8d49 · verified · 1 Parent(s): 65b2ef0

Remove nested directory: BitTransformerLM/bit_transformer/quantization.py

BitTransformerLM/bit_transformer/quantization.py DELETED
@@ -1,89 +0,0 @@
- import torch
- import torch.nn as nn
-
- from torch.ao.quantization.fake_quantize import FakeQuantize
- from torch.ao.quantization.observer import MinMaxObserver
- from torch.ao.quantization.qconfig import QConfig
- from torch.ao.quantization import convert
-
- from .model import BitTransformerLM
-
-
- def quantize_dynamic(model: BitTransformerLM, dtype: torch.dtype = torch.qint8) -> BitTransformerLM:
-     """Return a dynamically quantized copy of the model for inference."""
-     quantized = torch.quantization.quantize_dynamic(
-         model, {nn.Linear}, dtype=dtype
-     )
-     return quantized
-
-
- class FourBitObserver(MinMaxObserver):
-     """Min-max observer configured for 4-bit quantization."""
-
-     def __init__(self, **kwargs):
-         super().__init__(
-             quant_min=0,
-             quant_max=15,
-             dtype=torch.quint8,
-             qscheme=torch.per_tensor_affine,
-             **kwargs,
-         )
-
-
- FourBitFakeQuantize = FakeQuantize.with_args(observer=FourBitObserver)
-
- four_bit_qconfig = QConfig(activation=FourBitFakeQuantize, weight=FourBitFakeQuantize)
-
-
- class QATLinear(nn.Linear):
-     """Linear layer with fake quantization for QAT."""
-
-     def __init__(self, in_features: int, out_features: int, bias: bool = True) -> None:
-         super().__init__(in_features, out_features, bias)
-         self.weight_fake_quant = FourBitFakeQuantize()
-         self.activation_post_process = FourBitFakeQuantize()
-
-     @classmethod
-     def from_float(cls, mod: nn.Linear) -> "QATLinear":
-         qat = cls(mod.in_features, mod.out_features, mod.bias is not None)
-         qat.weight = mod.weight
-         qat.bias = mod.bias
-         return qat
-
-     def forward(self, x: torch.Tensor) -> torch.Tensor:
-         x = self.activation_post_process(x)
-         w = self.weight_fake_quant(self.weight)
-         return nn.functional.linear(x, w, self.bias)
-
-
- def prepare_qat_fx(model: BitTransformerLM) -> BitTransformerLM:
-     """Prepare BitTransformerLM for quantization-aware training."""
-
-     for name, module in model.named_children():
-         if isinstance(module, nn.Linear):
-             setattr(model, name, QATLinear.from_float(module))
-         else:
-             prepare_qat_fx(module)
-     return model
-
-
- def convert_qat_fx(model: BitTransformerLM) -> BitTransformerLM:
-     """Convert a QAT-prepared model to a quantized version."""
-
-     for name, module in model.named_children():
-         if isinstance(module, QATLinear):
-             w = module.weight.data
-             qmin, qmax = 0, 15
-             min_w = w.min()
-             max_w = w.max()
-             scale = (max_w - min_w) / (qmax - qmin + 1e-8)
-             zero_point = qmin - torch.round(min_w / scale)
-             q_w = torch.clamp(torch.round(w / scale + zero_point), qmin, qmax)
-             new_mod = nn.Linear(module.in_features, module.out_features, module.bias is not None)
-             new_mod.weight = nn.Parameter((q_w - zero_point) * scale)
-             if module.bias is not None:
-                 new_mod.bias = nn.Parameter(module.bias.data)
-             setattr(model, name, new_mod)
-         else:
-             convert_qat_fx(module)
-     return model
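
The conversion step in the deleted convert_qat_fx folds the fake-quantized weights back into plain float tensors via a per-tensor affine 4-bit round-trip (qmin=0, qmax=15). A minimal standalone sketch of that arithmetic in plain PyTorch follows; the helper name fourbit_roundtrip and the random test tensor are illustrative and not part of the repository.

import torch

def fourbit_roundtrip(w: torch.Tensor) -> torch.Tensor:
    # Per-tensor affine 4-bit quantize/dequantize, mirroring the math in convert_qat_fx above.
    qmin, qmax = 0, 15
    min_w, max_w = w.min(), w.max()
    scale = (max_w - min_w) / (qmax - qmin + 1e-8)
    zero_point = qmin - torch.round(min_w / scale)
    q_w = torch.clamp(torch.round(w / scale + zero_point), qmin, qmax)
    return (q_w - zero_point) * scale  # dequantized float weights

w = torch.randn(64, 64)
print((w - fourbit_roundtrip(w)).abs().max())  # worst-case 4-bit quantization error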