import torch

class QLinear(torch.nn.Module):
    """Linear layer with symmetric per-output-channel int8 weight quantization."""

    def __init__(self, bits: int, weight: torch.Tensor, bias=None):
        super().__init__()
        self.quant_bits = bits
        if self.quant_bits != 8:
            raise ValueError(
                f'Only int8 quantization is supported in the current version, got bits={bits}'
            )
        # Per-output-channel scale: map each row's max absolute weight to the int8 range.
        self.scale = weight.abs().max(dim=-1).values / ((2 ** (bits - 1)) - 1)
        # Quantize to int8 and store the transposed weight for the matmul in forward().
        self.weight = torch.round(weight / self.scale[:, None]).to(torch.int8)
        self.weight = self.weight.T
        self.bias = bias

    def forward(self, input):
        # Lazily move the quantized weight and scale onto the input's device.
        if self.weight.device != input.device:
            self.weight = self.weight.to(input.device)
            self.scale = self.scale.to(input.device)

        # Dequantize on the fly: matmul in the input dtype, then rescale per output channel.
        output = torch.matmul(input, self.weight.to(input.dtype)) * self.scale.to(input.dtype)[None, None, :]
        if self.bias is not None:
            output = output + self.bias
        return output
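

# Minimal usage sketch (illustrative only, not part of the original module):
# quantize the weight of a small fp32 linear layer and compare against the
# full-precision output on a 3D activation tensor.
if __name__ == '__main__':
    torch.manual_seed(0)
    fp_linear = torch.nn.Linear(16, 32, bias=False)
    qlinear = QLinear(8, fp_linear.weight.data)
    x = torch.randn(2, 4, 16)  # (batch, seq_len, in_features)
    max_err = (qlinear(x) - fp_linear(x)).abs().max()
    print(f'max abs error after int8 weight quantization: {max_err.item():.4f}')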