Charlie81 committed on
Commit d5537e3 · 1 Parent(s): 62fa95e

patch quantize

Files changed (1)
  1. quantization.py +11 -7
quantization.py CHANGED
@@ -80,22 +80,26 @@ class MixedPrecisionQuantizer:
         if bits == 32:
             return layer
 
-        weight = layer.weight.data
-        bias = layer.bias.data if layer.bias is not None else None
+        weight = layer.weight.data.clone()
 
         # Symmetric quantization
         qmin = -(2 ** (bits - 1))
         qmax = 2 ** (bits - 1) - 1
 
-        # Calculate scale
-        max_val = torch.max(torch.abs(weight))
+        # Calculate scale per-channel (per output channel)
+        # This provides better accuracy than per-tensor quantization
+        max_val = torch.max(torch.abs(weight), dim=1, keepdim=True)[0]
+        max_val = torch.clamp(max_val, min=1e-5)  # Avoid division by zero
         scale = max_val / qmax
 
-        # Quantize
+        # Quantize and dequantize (fake quantization)
         weight_q = torch.clamp(torch.round(weight / scale), qmin, qmax)
+        weight_dq = weight_q * scale
 
-        # Store quantized weights and scale
-        layer.weight.data = weight_q.to(torch.int8 if bits <= 8 else torch.int16)
+        # Store dequantized weights as float (required for autograd)
+        layer.weight.data = weight_dq.contiguous()
+
+        # Store quantization metadata as layer attributes
         layer.weight_scale = scale
         layer.quantized = True
         layer.bits = bits
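
For context, here is a minimal self-contained sketch of the per-channel symmetric fake quantization this patch introduces, applied to a torch.nn.Linear layer. The function name fake_quantize_linear and the bits=8 usage are illustrative assumptions; the commit does not show the enclosing method's signature. Note that keeping the weights dequantized in float (instead of casting to int8/int16 as the old code did) keeps autograd and ordinary float kernels working, at the cost of no memory savings at rest.

import torch
import torch.nn as nn

def fake_quantize_linear(layer: nn.Linear, bits: int = 8) -> nn.Linear:
    # Illustrative helper, not part of this commit.
    if bits == 32:
        return layer  # full precision: nothing to do

    weight = layer.weight.data.clone()

    # Symmetric signed integer range, e.g. [-128, 127] for 8 bits
    qmin = -(2 ** (bits - 1))
    qmax = 2 ** (bits - 1) - 1

    # One scale per output channel (one per row of the weight matrix)
    max_val = torch.max(torch.abs(weight), dim=1, keepdim=True)[0]
    max_val = torch.clamp(max_val, min=1e-5)  # avoid division by zero
    scale = max_val / qmax

    # Round to the integer grid, then map back to float (fake quantization)
    weight_q = torch.clamp(torch.round(weight / scale), qmin, qmax)
    layer.weight.data = (weight_q * scale).contiguous()

    # Quantization metadata, mirroring the attributes set in the patch
    layer.weight_scale = scale
    layer.quantized = True
    layer.bits = bits
    return layer

# Usage (illustrative):
layer = nn.Linear(16, 4)
fake_quantize_linear(layer, bits=8)
print(layer.weight_scale.shape)  # torch.Size([4, 1]): one scale per output channel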