duzx16 committed on
Commit
f433128
·
1 Parent(s): aee628d

Add logger

Browse files

Fix backward for W8A16LinearCPU

Files changed (1) hide show
  1. quantization.py +4 -2
quantization.py CHANGED
@@ -12,6 +12,8 @@ from transformers.utils import logging
12
  from typing import List
13
  from functools import partial
14
 
 
 
15
  try:
16
  from cpm_kernels.kernels.base import LazyKernelCModule, KernelFunction, round_up
17
 
@@ -68,11 +70,11 @@ class W8A16LinearCPU(torch.autograd.Function):
68
  @staticmethod
69
  def forward(ctx, inp: torch.Tensor, quant_w: torch.Tensor, scale_w: torch.Tensor, weight_bit_width, quantization_cache=None):
70
  ctx.inp_shape = inp.size()
71
- ctx.weight_shape = quant_w.size()
72
  ctx.weight_bit_width = weight_bit_width
73
  out_features = quant_w.size(0)
74
  inp = inp.contiguous().view(-1, inp.size(-1))
75
  weight = extract_weight_to_float(quant_w, scale_w, weight_bit_width, quantization_cache=quantization_cache)
 
76
  output = inp.mm(weight.t())
77
  ctx.save_for_backward(inp, quant_w, scale_w)
78
  return output.view(*(ctx.inp_shape[:-1] + (out_features,)))
@@ -84,7 +86,7 @@ class W8A16LinearCPU(torch.autograd.Function):
84
  grad_output = grad_output.contiguous().view(-1, weight.size(0))
85
  grad_input = grad_output.mm(weight)
86
  grad_weight = grad_output.t().mm(inp)
87
- return grad_input.view(ctx.inp_shape), grad_weight.view(ctx.weight_shape), None
88
 
89
 
90
  default_cpu_kernel_code_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "quantization_kernels.c")
 
12
  from typing import List
13
  from functools import partial
14
 
15
+ logger = logging.get_logger(__name__)
16
+
17
  try:
18
  from cpm_kernels.kernels.base import LazyKernelCModule, KernelFunction, round_up
19
 
 
70
  @staticmethod
71
  def forward(ctx, inp: torch.Tensor, quant_w: torch.Tensor, scale_w: torch.Tensor, weight_bit_width, quantization_cache=None):
72
  ctx.inp_shape = inp.size()
 
73
  ctx.weight_bit_width = weight_bit_width
74
  out_features = quant_w.size(0)
75
  inp = inp.contiguous().view(-1, inp.size(-1))
76
  weight = extract_weight_to_float(quant_w, scale_w, weight_bit_width, quantization_cache=quantization_cache)
77
+ ctx.weight_shape = weight.size()
78
  output = inp.mm(weight.t())
79
  ctx.save_for_backward(inp, quant_w, scale_w)
80
  return output.view(*(ctx.inp_shape[:-1] + (out_features,)))
 
86
  grad_output = grad_output.contiguous().view(-1, weight.size(0))
87
  grad_input = grad_output.mm(weight)
88
  grad_weight = grad_output.t().mm(inp)
89
+ return grad_input.view(ctx.inp_shape), grad_weight.view(ctx.weight_shape), None, None
90
 
91
 
92
  default_cpu_kernel_code_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "quantization_kernels.c")