This code from BitLinear doesn't make sense

#7 opened by qmsoqm
    def forward(self, input):
        # fake-quantize the activations and the weights
        quant_input = input + (activation_quant(input, self.input_bits) - input).detach()
        quant_weight = self.weight + (weight_quant(self.weight, self.weight_bits) - self.weight).detach()

        # linear layer on the quantized tensors
        out = nn.functional.linear(quant_input, quant_weight)
        if self.bias is not None:
            out += self.bias.view(1, -1).expand_as(out)

        return out
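
For reference, I don't know the exact activation_quant / weight_quant used here (they also take a bit-width argument), but BitNet-b1.58-style implementations are usually roughly like this sketch:

    import torch

    def activation_quant(x: torch.Tensor) -> torch.Tensor:
        # per-token absmax quantization to signed 8-bit, then de-quantized
        # back to float so the tensor only carries the rounding error
        scale = 127.0 / x.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5)
        return (x * scale).round().clamp(-128, 127) / scale

    def weight_quant(w: torch.Tensor) -> torch.Tensor:
        # scale by the mean absolute value, round to {-1, 0, 1}, de-quantize
        scale = 1.0 / w.abs().mean().clamp(min=1e-5)
        return (w * scale).round().clamp(-1, 1) / scale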

First, adding and then subtracting self.weight to form quant_weight looks unnecessary, since there is no scaling factor involved. And because it calls detach(), it will also take up more memory.
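
For example, numerically the forward value is just the quantized tensor (up to float rounding), so the add/subtract cancels out:

    import torch

    w = torch.randn(4, 4, requires_grad=True)
    w_q = w.sign()                      # stand-in for weight_quant(w)
    w_ste = w + (w_q - w).detach()      # the pattern used in forward()

    # identical values in the forward pass
    assert torch.allclose(w_ste, w_q)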

Second, why does quant_weight need to be computed like this? Why not keep the weight quantized to {-1, 0, 1} to begin with?

IIUC

  1. detach() is needed for backpropagation: it implements the Straight-Through Estimator (STE). In the forward pass the expression evaluates to the quantized tensor, but in the backward pass the gradient flows straight through to the full-precision input/weight, as if the non-differentiable quant ops were the identity (see the sketch after this list).
  2. In Quantization-Aware Training, the goal is to inject quantization error during the forward pass and let the model learn to become robust to it. That is why there is a quantize/de-quantize step and why the full-precision weights are kept during training. During inference the weights will be {-1, 0, 1}.
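
A minimal sketch of point 1, using sign() as a stand-in for the real quantizer:

    import torch

    x = torch.randn(3, requires_grad=True)

    # plain quantization: sign()/round() have zero gradient almost everywhere,
    # so nothing useful flows back to x
    x.sign().sum().backward()
    print(x.grad)                        # tensor([0., 0., 0.])

    x.grad = None

    # STE trick: forward value is still sign(x), but the backward pass treats
    # the quantizer as the identity, so x receives a real gradient
    x_q = x + (x.sign() - x).detach()
    x_q.sum().backward()
    print(x.grad)                        # tensor([1., 1., 1.])

For point 2, this also means the full-precision weights only exist during training; once training is done, the output of weight_quant is just {-1, 0, 1} times a scale, so it can be materialized once (e.g. as int8 plus a float scale) for inference.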
