patch quantize

quantization.py (+11 -7)

@@ -80,22 +80,26 @@ class MixedPrecisionQuantizer:
         if bits == 32:
             return layer
 
-        weight = layer.weight.data
-        bias = layer.bias.data if layer.bias is not None else None
+        weight = layer.weight.data.clone()
 
         # Symmetric quantization
         qmin = -(2 ** (bits - 1))
         qmax = 2 ** (bits - 1) - 1
 
-        # Calculate scale
-        max_val = torch.max(torch.abs(weight))
+        # Calculate scale per-channel (per output channel)
+        # This provides better accuracy than per-tensor quantization
+        max_val = torch.max(torch.abs(weight), dim=1, keepdim=True)[0]
+        max_val = torch.clamp(max_val, min=1e-5)  # Avoid division by zero
         scale = max_val / qmax
 
-        # Quantize
+        # Quantize and dequantize (fake quantization)
         weight_q = torch.clamp(torch.round(weight / scale), qmin, qmax)
+        weight_dq = weight_q * scale
 
-        # Store
-        layer.weight.data = weight_q
+        # Store dequantized weights as float (required for autograd)
+        layer.weight.data = weight_dq.contiguous()
+
+        # Store quantization metadata as layer attributes
         layer.weight_scale = scale
         layer.quantized = True
         layer.bits = bits