nickfraser commited on
Commit
1b07a9d
·
1 Parent(s): ecec5b7

Feat (math model/tests): Updated math model and tests to match use format

Browse files
Files changed (3) hide show
  1. math_model.py +14 -10
  2. test_quant_conv2d.py +3 -1
  3. test_quant_linear.py +3 -1
math_model.py CHANGED
@@ -21,8 +21,10 @@ class QuantLinear(nn.Module):
21
  self.linear = nn.Linear(in_ch, out_ch)
22
  weight_scale = torch.tensor(quant_param['weight_scale']).view(quant_param['weight_scale_shape'])
23
  weight_zp = torch.tensor(quant_param['weight_zp']).view(quant_param['weight_zp_shape'])
 
24
  input_scale = torch.tensor(quant_param['input_scale']).view(quant_param['input_scale_shape'])
25
  input_zp = torch.tensor(quant_param['input_zp']).view(quant_param['input_zp_shape'])
 
26
  self.register_buffer('weight_scale', weight_scale)
27
  self.register_buffer('weight_zp', weight_zp)
28
  self.register_buffer('input_scale', input_scale)
@@ -31,9 +33,10 @@ class QuantLinear(nn.Module):
31
  # I.e., "fake quantization"
32
  def qdq_forward(self, x):
33
  scaled_x = x * self.mul_factor
34
- quant_weight = quantize(self.linear.weight, self.weight_scale, self.weight_zp, is_asym=True)
 
35
  quant_input = quantize(scaled_x, self.input_scale, self.input_zp, is_asym=False)
36
- dequantized_weight = dequantize(quant_weight, self.weight_scale, self.weight_zp)
37
  dequantized_input = dequantize(quant_input, self.input_scale, self.input_zp)
38
  out = torch.nn.functional.linear(dequantized_input, dequantized_weight, self.linear.bias)
39
  return out
@@ -47,12 +50,11 @@ class QuantLinear(nn.Module):
47
  # - multiply this sum with every weight zero-point (e.g., `torch.sum(quant_input, dim=-1) * self.weight_zp`
48
  # - Subtract from previous output (e.g., `quant_output -= torch.sum(quant_input, dim=-1) * self.weight_zp`)
49
  # - All other code is just to make sure the broadcasting semantics work correctly
50
- weight_zp_int8 = (self.weight_zp - 128).to(torch.int8).to(torch.float32) # Conversion from uint8 -> int8, can be computed offline
51
- quant_weight = quantize(self.linear.weight, self.weight_scale, weight_zp_int8, is_asym=False).to(torch.int8)
52
  fused_input_scale = self.input_scale / self.mul_factor # Fuse SmoothQuant and input scales, can be computed offline
53
  quant_input = quantize(x, fused_input_scale, self.input_zp, is_asym=False).to(torch.int8)
54
  quant_output = torch.nn.functional.linear(quant_input.to(torch.float32), quant_weight.to(torch.float32), None).to(torch.int32) # Convert inputs to FP32 to avoid F.linear quantizing the output to int8
55
- correction = torch.sum(quant_input, dim=-1, keepdim=True).to(torch.int32) * weight_zp_int8.to(torch.int8).view([1]*(quant_input.ndim-1) + [self.weight_zp.nelement()]) # Correct for weight zero-point
56
  quant_output = quant_output - correction
57
  output = dequantize(quant_output, (self.weight_scale * self.input_scale).view([1]*(quant_output.ndim-1) + [(self.weight_scale * self.input_scale).nelement()]), 0.0)
58
  output += self.linear.bias
@@ -72,8 +74,10 @@ class QuantConv2d(nn.Module):
72
  self.conv2d = nn.Conv2d(in_ch, out_ch, kernel_size)
73
  weight_scale = torch.tensor(quant_param['weight_scale']).view(quant_param['weight_scale_shape'])
74
  weight_zp = torch.tensor(quant_param['weight_zp']).view(quant_param['weight_zp_shape'])
 
75
  input_scale = torch.tensor(quant_param['input_scale']).view(quant_param['input_scale_shape'])
76
  input_zp = torch.tensor(quant_param['input_zp']).view(quant_param['input_zp_shape'])
 
77
  self.register_buffer('weight_scale', weight_scale)
78
  self.register_buffer('weight_zp', weight_zp)
79
  self.register_buffer('input_scale', input_scale)
@@ -82,9 +86,10 @@ class QuantConv2d(nn.Module):
82
  # I.e., "fake quantization"
83
  def qdq_forward(self, x):
84
  scaled_x = x * self.mul_factor
85
- quant_weight = quantize(self.conv2d.weight, self.weight_scale, self.weight_zp, is_asym=True)
 
86
  quant_input = quantize(scaled_x, self.input_scale, self.input_zp, is_asym=False)
87
- dequantized_weight = dequantize(quant_weight, self.weight_scale, self.weight_zp)
88
  dequantized_input = dequantize(quant_input, self.input_scale, self.input_zp)
89
  out = torch.nn.functional.conv2d(dequantized_input, dequantized_weight, self.conv2d.bias)
90
  return out
@@ -104,8 +109,7 @@ class QuantConv2d(nn.Module):
104
  # - multiply this sum with every weight zero-point (e.g., `sum * self.weight_zp`
105
  # - Subtract from previous output (e.g., `quant_output -= sum * self.weight_zp`)
106
  # - All other code is just to make sure the broadcasting semantics work correctly
107
- weight_zp_int8 = (self.weight_zp - 128).to(torch.int8).to(torch.float32) # Conversion from uint8 -> int8, can be computed offline
108
- quant_weight = quantize(self.conv2d.weight, self.weight_scale, weight_zp_int8, is_asym=False).to(torch.int8)
109
  b_shape = list(quant_weight.shape) # Used for weight zero-point correction
110
  b_shape[0] = 1 # Used for weight zero-point correction
111
  weight_cat = torch.ones((1,1,1,1)).broadcast_to(b_shape).to(torch.int8) # Used for weight zero-point correction
@@ -113,7 +117,7 @@ class QuantConv2d(nn.Module):
113
  fused_input_scale = self.input_scale / self.mul_factor # Fuse SmoothQuant and input scales, can be computed offline
114
  quant_input = quantize(x, fused_input_scale, self.input_zp, is_asym=False).to(torch.int8)
115
  quant_output = torch.nn.functional.conv2d(quant_input.to(torch.float32), quant_weight.to(torch.float32), None).to(torch.int32) # Convert inputs to FP32 to avoid F.conv2d quantizing the output to int8
116
- correction = quant_output[:,-1,:,:] * weight_zp_int8.to(torch.int8).view([1, self.weight_zp.nelement()] + [1]*(quant_output.ndim-2)) # Correct zero-point for weight
117
  quant_output = quant_output[:,:-1,:,:] - correction
118
  output = dequantize(quant_output, (self.weight_scale * self.input_scale).view([1, (self.weight_scale * self.input_scale).nelement()] + [1]*(quant_output.ndim-2)), 0.0)
119
  output += self.conv2d.bias.view([1, self.conv2d.bias.nelement()] + [1]*(quant_output.ndim-2))
 
21
  self.linear = nn.Linear(in_ch, out_ch)
22
  weight_scale = torch.tensor(quant_param['weight_scale']).view(quant_param['weight_scale_shape'])
23
  weight_zp = torch.tensor(quant_param['weight_zp']).view(quant_param['weight_zp_shape'])
24
+ assert quant_param['weight_zp_dtype'] == 'torch.int8', f"Weight Zero-Point dtype should be 'torch.int8', found: {quant_param['weight_zp_dype']}"
25
  input_scale = torch.tensor(quant_param['input_scale']).view(quant_param['input_scale_shape'])
26
  input_zp = torch.tensor(quant_param['input_zp']).view(quant_param['input_zp_shape'])
27
+ assert quant_param['input_zp_dtype'] == 'torch.int8', f"Input Zero-Point dtype should be 'torch.int8', found: {quant_param['input_zp_dype']}"
28
  self.register_buffer('weight_scale', weight_scale)
29
  self.register_buffer('weight_zp', weight_zp)
30
  self.register_buffer('input_scale', input_scale)
 
33
  # I.e., "fake quantization"
34
  def qdq_forward(self, x):
35
  scaled_x = x * self.mul_factor
36
+ weight_zp_uint8 = (self.weight_zp + 128).to(torch.uint8).to(torch.float32)
37
+ quant_weight = quantize(self.linear.weight, self.weight_scale, weight_zp_uint8, is_asym=True)
38
  quant_input = quantize(scaled_x, self.input_scale, self.input_zp, is_asym=False)
39
+ dequantized_weight = dequantize(quant_weight, self.weight_scale, weight_zp_uint8)
40
  dequantized_input = dequantize(quant_input, self.input_scale, self.input_zp)
41
  out = torch.nn.functional.linear(dequantized_input, dequantized_weight, self.linear.bias)
42
  return out
 
50
  # - multiply this sum with every weight zero-point (e.g., `torch.sum(quant_input, dim=-1) * self.weight_zp`
51
  # - Subtract from previous output (e.g., `quant_output -= torch.sum(quant_input, dim=-1) * self.weight_zp`)
52
  # - All other code is just to make sure the broadcasting semantics work correctly
53
+ quant_weight = quantize(self.linear.weight, self.weight_scale, self.weight_zp, is_asym=False).to(torch.int8)
 
54
  fused_input_scale = self.input_scale / self.mul_factor # Fuse SmoothQuant and input scales, can be computed offline
55
  quant_input = quantize(x, fused_input_scale, self.input_zp, is_asym=False).to(torch.int8)
56
  quant_output = torch.nn.functional.linear(quant_input.to(torch.float32), quant_weight.to(torch.float32), None).to(torch.int32) # Convert inputs to FP32 to avoid F.linear quantizing the output to int8
57
+ correction = torch.sum(quant_input, dim=-1, keepdim=True).to(torch.int32) * self.weight_zp.to(torch.int8).view([1]*(quant_input.ndim-1) + [self.weight_zp.nelement()]) # Correct for weight zero-point
58
  quant_output = quant_output - correction
59
  output = dequantize(quant_output, (self.weight_scale * self.input_scale).view([1]*(quant_output.ndim-1) + [(self.weight_scale * self.input_scale).nelement()]), 0.0)
60
  output += self.linear.bias
 
74
  self.conv2d = nn.Conv2d(in_ch, out_ch, kernel_size)
75
  weight_scale = torch.tensor(quant_param['weight_scale']).view(quant_param['weight_scale_shape'])
76
  weight_zp = torch.tensor(quant_param['weight_zp']).view(quant_param['weight_zp_shape'])
77
+ assert quant_param['weight_zp_dtype'] == 'torch.int8', f"Weight Zero-Point dtype should be 'torch.int8', found: {quant_param['weight_zp_dype']}"
78
  input_scale = torch.tensor(quant_param['input_scale']).view(quant_param['input_scale_shape'])
79
  input_zp = torch.tensor(quant_param['input_zp']).view(quant_param['input_zp_shape'])
80
+ assert quant_param['input_zp_dtype'] == 'torch.int8', f"Input Zero-Point dtype should be 'torch.int8', found: {quant_param['input_zp_dype']}"
81
  self.register_buffer('weight_scale', weight_scale)
82
  self.register_buffer('weight_zp', weight_zp)
83
  self.register_buffer('input_scale', input_scale)
 
86
  # I.e., "fake quantization"
87
  def qdq_forward(self, x):
88
  scaled_x = x * self.mul_factor
89
+ weight_zp_uint8 = (self.weight_zp + 128).to(torch.uint8).to(torch.float32)
90
+ quant_weight = quantize(self.conv2d.weight, self.weight_scale, weight_zp_uint8, is_asym=True)
91
  quant_input = quantize(scaled_x, self.input_scale, self.input_zp, is_asym=False)
92
+ dequantized_weight = dequantize(quant_weight, self.weight_scale, weight_zp_uint8)
93
  dequantized_input = dequantize(quant_input, self.input_scale, self.input_zp)
94
  out = torch.nn.functional.conv2d(dequantized_input, dequantized_weight, self.conv2d.bias)
95
  return out
 
109
  # - multiply this sum with every weight zero-point (e.g., `sum * self.weight_zp`
110
  # - Subtract from previous output (e.g., `quant_output -= sum * self.weight_zp`)
111
  # - All other code is just to make sure the broadcasting semantics work correctly
112
+ quant_weight = quantize(self.conv2d.weight, self.weight_scale, self.weight_zp, is_asym=False).to(torch.int8)
 
113
  b_shape = list(quant_weight.shape) # Used for weight zero-point correction
114
  b_shape[0] = 1 # Used for weight zero-point correction
115
  weight_cat = torch.ones((1,1,1,1)).broadcast_to(b_shape).to(torch.int8) # Used for weight zero-point correction
 
117
  fused_input_scale = self.input_scale / self.mul_factor # Fuse SmoothQuant and input scales, can be computed offline
118
  quant_input = quantize(x, fused_input_scale, self.input_zp, is_asym=False).to(torch.int8)
119
  quant_output = torch.nn.functional.conv2d(quant_input.to(torch.float32), quant_weight.to(torch.float32), None).to(torch.int32) # Convert inputs to FP32 to avoid F.conv2d quantizing the output to int8
120
+ correction = quant_output[:,-1,:,:] * self.weight_zp.to(torch.int8).view([1, self.weight_zp.nelement()] + [1]*(quant_output.ndim-2)) # Correct zero-point for weight
121
  quant_output = quant_output[:,:-1,:,:] - correction
122
  output = dequantize(quant_output, (self.weight_scale * self.input_scale).view([1, (self.weight_scale * self.input_scale).nelement()] + [1]*(quant_output.ndim-2)), 0.0)
123
  output += self.conv2d.bias.view([1, self.conv2d.bias.nelement()] + [1]*(quant_output.ndim-2))
test_quant_conv2d.py CHANGED
@@ -20,12 +20,14 @@ quant_params = {
20
  'weight_scale': torch.rand((out_ch,)),
21
  'weight_scale': torch.max(torch.abs(torch.flatten(l.weight, start_dim=1)), dim=1).values / 128.,
22
  'weight_scale_shape': (out_ch,1,1,1),
23
- 'weight_zp': torch.clamp(torch.round((torch.mean((l.weight), dim=(1,2,3))) * (128 / torch.max(torch.abs(torch.flatten(l.weight, start_dim=1)), dim=1).values)) + 128, 0, 255),
24
  'weight_zp_shape': (out_ch,1,1,1),
 
25
  'input_scale': torch.max(torch.abs(i)) / 128.,
26
  'input_scale_shape': tuple(),
27
  'input_zp': torch.zeros((1,)),
28
  'input_zp_shape': tuple(),
 
29
  }
30
 
31
  print(quant_params)
 
20
  'weight_scale': torch.rand((out_ch,)),
21
  'weight_scale': torch.max(torch.abs(torch.flatten(l.weight, start_dim=1)), dim=1).values / 128.,
22
  'weight_scale_shape': (out_ch,1,1,1),
23
+ 'weight_zp': torch.clamp(torch.round((torch.mean((l.weight), dim=(1,2,3))) * (128 / torch.max(torch.abs(torch.flatten(l.weight, start_dim=1)), dim=1).values)), -128, 127),
24
  'weight_zp_shape': (out_ch,1,1,1),
25
+ 'weight_zp_dtype': 'torch.int8',
26
  'input_scale': torch.max(torch.abs(i)) / 128.,
27
  'input_scale_shape': tuple(),
28
  'input_zp': torch.zeros((1,)),
29
  'input_zp_shape': tuple(),
30
+ 'input_zp_dtype': 'torch.int8',
31
  }
32
 
33
  print(quant_params)
test_quant_linear.py CHANGED
@@ -16,12 +16,14 @@ quant_params = {
16
  'smoothquant_mul_shape': (1,in_ch),
17
  'weight_scale': torch.max(torch.abs(l.weight), dim=1).values / 128.,
18
  'weight_scale_shape': (out_ch,1),
19
- 'weight_zp': torch.clamp(torch.round((torch.mean((l.weight), dim=1)) * (128 / torch.max(torch.abs(l.weight), dim=1).values)) + 128, 0, 255),
20
  'weight_zp_shape': (out_ch,1),
 
21
  'input_scale': torch.max(torch.abs(i)) / 128.,
22
  'input_scale_shape': tuple(),
23
  'input_zp': torch.zeros((1,)),
24
  'input_zp_shape': tuple(),
 
25
  }
26
 
27
  print(quant_params)
 
16
  'smoothquant_mul_shape': (1,in_ch),
17
  'weight_scale': torch.max(torch.abs(l.weight), dim=1).values / 128.,
18
  'weight_scale_shape': (out_ch,1),
19
+ 'weight_zp': torch.clamp(torch.round((torch.mean((l.weight), dim=1)) * (128 / torch.max(torch.abs(l.weight), dim=1).values)), -128, 127),
20
  'weight_zp_shape': (out_ch,1),
21
+ 'weight_zp_dtype': 'torch.int8',
22
  'input_scale': torch.max(torch.abs(i)) / 128.,
23
  'input_scale_shape': tuple(),
24
  'input_zp': torch.zeros((1,)),
25
  'input_zp_shape': tuple(),
26
+ 'input_zp_dtype': 'torch.int8',
27
  }
28
 
29
  print(quant_params)