duzx16 committed
Commit · f433128 · Parent(s): aee628d

Add logger
Fix backward for W8A16LinearCPU

quantization.py CHANGED (+4 -2)
@@ -12,6 +12,8 @@ from transformers.utils import logging
 from typing import List
 from functools import partial
 
+logger = logging.get_logger(__name__)
+
 try:
     from cpm_kernels.kernels.base import LazyKernelCModule, KernelFunction, round_up
 
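For context, `logging.get_logger` is the `transformers` logging wrapper already imported at the top of this file (`from transformers.utils import logging`, visible in the hunk header); it returns a standard `logging.Logger` that follows the library-wide verbosity setting. A minimal usage sketch (the log message is illustrative, not from this commit):

    from transformers.utils import logging

    logger = logging.get_logger(__name__)
    logger.info("loading CPU quantization kernels")  # hypothetical message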
@@ -68,11 +70,11 @@ class W8A16LinearCPU(torch.autograd.Function):
     @staticmethod
     def forward(ctx, inp: torch.Tensor, quant_w: torch.Tensor, scale_w: torch.Tensor, weight_bit_width, quantization_cache=None):
         ctx.inp_shape = inp.size()
-        ctx.weight_shape = quant_w.size()
         ctx.weight_bit_width = weight_bit_width
         out_features = quant_w.size(0)
         inp = inp.contiguous().view(-1, inp.size(-1))
         weight = extract_weight_to_float(quant_w, scale_w, weight_bit_width, quantization_cache=quantization_cache)
+        ctx.weight_shape = weight.size()
         output = inp.mm(weight.t())
         ctx.save_for_backward(inp, quant_w, scale_w)
         return output.view(*(ctx.inp_shape[:-1] + (out_features,)))
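The forward change moves `ctx.weight_shape` from the packed tensor to the dequantized one. `grad_weight` in backward has the shape of the float weight, and with 4-bit quantization `quant_w.size()` differs from `weight.size()` (presumably two 4-bit values are packed per int8, halving the last dimension), so the later `grad_weight.view(ctx.weight_shape)` would fail. A minimal sketch of the mismatch, with assumed shapes:

    import torch

    out_features, in_features = 8, 16
    weight = torch.empty(out_features, in_features)        # dequantized float weight
    quant_w = torch.empty(out_features, in_features // 2,  # packed int4 storage (assumed layout)
                          dtype=torch.int8)

    grad_weight = torch.randn(out_features, in_features)   # same shape as the float weight
    grad_weight.view(weight.size())                        # works: shapes match
    # grad_weight.view(quant_w.size())                     # RuntimeError: shape '[8, 8]' is
    #                                                      # invalid for input of size 128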
@@ -84,7 +86,7 @@ class W8A16LinearCPU(torch.autograd.Function):
         grad_output = grad_output.contiguous().view(-1, weight.size(0))
         grad_input = grad_output.mm(weight)
         grad_weight = grad_output.t().mm(inp)
-        return grad_input.view(ctx.inp_shape), grad_weight.view(ctx.weight_shape), None
+        return grad_input.view(ctx.inp_shape), grad_weight.view(ctx.weight_shape), None, None
 
 
 default_cpu_kernel_code_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "quantization_kernels.c")
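The extra `None` follows the `torch.autograd.Function` contract: `backward` must return one value per argument of `forward`, with `None` for non-differentiable inputs such as `weight_bit_width`. A self-contained sketch of that rule (the `Scale` function is illustrative, not from this file):

    import torch

    class Scale(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x, factor, unused=None):
            ctx.factor = factor
            return x * factor

        @staticmethod
        def backward(ctx, grad_out):
            # one entry per forward input: x gets a gradient tensor,
            # the non-differentiable factor and unused each get None
            return grad_out * ctx.factor, None, None

    x = torch.randn(3, requires_grad=True)
    Scale.apply(x, 2.0).sum().backward()
    print(x.grad)  # tensor([2., 2., 2.])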