codegeex2-6b / quantization.py
Stanislas's picture
Initial commit
711c7bf
raw history blame
No virus
17.3 kB
from torch.nn import Linear
from torch.nn.parameter import Parameter
import bz2
import torch
import base64
import ctypes
from transformers.utils import logging
from typing import List
from functools import partial
logger = logging.get_logger(__name__)
try:
    from cpm_kernels.kernels.base import LazyKernelCModule, KernelFunction, round_up

    class Kernel:
        """Bundle of kernel functions built from one piece of kernel code.

        Every name in ``function_names`` becomes an attribute of the instance
        holding a ``KernelFunction`` bound to the shared ``LazyKernelCModule``.
        """

        def __init__(self, code: bytes, function_names: List[str]):
            """Wrap ``code`` in a module object and expose ``function_names``.

            Args:
                code: Kernel code handed to ``LazyKernelCModule``.
                function_names: Names of the functions to expose as attributes.
            """
            self.code = code
            self._function_names = function_names
            self._cmodule = LazyKernelCModule(self.code)
            # After this loop ``self.<name>`` is the KernelFunction for <name>.
            for fn_name in self._function_names:
                setattr(self, fn_name, KernelFunction(self._cmodule, fn_name))

    # The value is base64-decoded and then bz2-decompressed just below.
    # NOTE(review): the literal is empty in this copy of the file; presumably
    # the encoded kernel payload was stripped -- confirm against upstream.
    quantization_code = ""

    kernels = Kernel(
        bz2.decompress(base64.b64decode(quantization_code)),
        [
            "int4WeightCompression",
            "int4WeightExtractionFloat",
            "int4WeightExtractionHalf",
            "int4WeightExtractionBFloat16",
            "int8WeightExtractionFloat",
            "int8WeightExtractionHalf",
            "int8WeightExtractionBFloat16",
        ],
    )
except Exception as err:
    # Best effort: any failure (import, decode, module construction) leaves
    # ``kernels`` as None and is only reported as a warning.
    kernels = None
    logger.warning("Failed to load cpm_kernels:" + str(err))
class W8A16Linear(torch.autograd.Function):
    """Autograd function computing ``inp @ W.T`` where ``W`` is stored quantized.

    The weight is passed as an integer tensor plus a per-row scale and is
    expanded with ``extract_weight_to_half`` on every forward and backward call.
    """

    @staticmethod
    def forward(ctx, inp: torch.Tensor, quant_w: torch.Tensor, scale_w: torch.Tensor, weight_bit_width):
        """Multiply ``inp`` by the transposed, dequantized weight.

        Args:
            ctx: Autograd context; receives the input shape, the bit width,
                the dequantized weight shape and the tensors needed by backward.
            inp: Input tensor; all leading dimensions are flattened into rows.
            quant_w: Quantized weight; its first dimension is the output size.
            scale_w: Scales forwarded to ``extract_weight_to_half``.
            weight_bit_width: Bit width forwarded to ``extract_weight_to_half``.

        Returns:
            Tensor shaped like ``inp`` with the last dimension replaced by
            ``quant_w.size(0)``.
        """
        original_shape = inp.size()
        ctx.inp_shape = original_shape
        ctx.weight_bit_width = weight_bit_width
        n_out = quant_w.size(0)
        flat_inp = inp.contiguous().view(-1, inp.size(-1))
        dequantized = extract_weight_to_half(quant_w, scale_w, weight_bit_width)
        ctx.weight_shape = dequantized.size()
        product = flat_inp.mm(dequantized.t())
        # The flattened input (not the original view) is what backward needs.
        ctx.save_for_backward(flat_inp, quant_w, scale_w)
        return product.view(*original_shape[:-1], n_out)

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        """Return gradients for ``inp`` and the weight; ``None`` for the rest.

        The weight is dequantized again from the saved quantized tensor and
        scale instead of being kept alive between forward and backward.
        """
        flat_inp, quant_w, scale_w = ctx.saved_tensors
        dequantized = extract_weight_to_half(quant_w, scale_w, ctx.weight_bit_width)
        flat_grad = grad_output.contiguous().view(-1, dequantized.size(0))
        grad_wrt_input = flat_grad.mm(dequantized)
        grad_wrt_weight = flat_grad.t().mm(flat_inp)
        return (
            grad_wrt_input.view(ctx.inp_shape),
            grad_wrt_weight.view(ctx.weight_shape),
            None,
            None,
        )
def compress_int4_weight(weight: torch.Tensor):  # (n, m)
    """Run the ``int4WeightCompression`` kernel on a 2-D weight tensor.

    Args:
        weight: Tensor of shape ``(n, m)`` with ``m`` even.

    Returns:
        A new int8 CUDA tensor of shape ``(n, m // 2)`` filled by the kernel
        (presumably two 4-bit values per byte -- confirm against the kernel).

    Raises:
        AssertionError: If ``m`` is odd.
    """
    with torch.cuda.device(weight.device):
        rows = weight.size(0)
        cols = weight.size(1)
        assert cols % 2 == 0
        packed_cols = cols // 2
        packed = torch.empty(rows, packed_cols, dtype=torch.int8, device="cuda")
        cuda_stream = torch.cuda.current_stream()

        # One block per row; block width comes from round_up(packed_cols, 32),
        # capped at 1024.
        grid = (rows, 1, 1)
        block = (min(round_up(packed_cols, 32), 1024), 1, 1)
        kernel_args = [
            ctypes.c_void_p(weight.data_ptr()),
            ctypes.c_void_p(packed.data_ptr()),
            ctypes.c_int32(rows),
            ctypes.c_int32(packed_cols),
        ]
        # The literal 0 is the third launch argument; presumably the dynamic
        # shared-memory size -- confirm against cpm_kernels.
        kernels.int4WeightCompression(grid, block, 0, cuda_stream, kernel_args)
        return packed
def extract_weight_to_half(weight: torch.Tensor, scale_list: torch.Tensor, source_bit_width: int):
    """Dequantize an int8-stored weight into the dtype of ``scale_list``.

    For 8-bit sources the result is computed directly with tensor ops as
    ``weight * scale_list[:, None]``. For 4-bit sources a CUDA kernel
    (``int4WeightExtractionHalf`` or ``int4WeightExtractionBFloat16``) fills a
    freshly allocated output with twice as many columns as ``weight``.

    Args:
        weight: 2-D ``torch.int8`` tensor holding the quantized values.
        scale_list: ``torch.half`` or ``torch.bfloat16`` tensor indexed per
            row of ``weight``; it also selects the output dtype.
        source_bit_width: 8 or 4.

    Returns:
        Tensor of dtype ``scale_list.dtype``; shape ``(n, m)`` for 8-bit and
        ``(n, 2 * m)`` for 4-bit, where ``(n, m)`` is the shape of ``weight``.

    Raises:
        AssertionError: On an unsupported dtype or bit width.
    """
    # Validation is raised explicitly rather than via ``assert`` so that it
    # still runs under ``python -O``. Previously an unsupported bit width fell
    # through there and failed later with an unbound ``func``. The exception
    # type is kept as AssertionError for existing callers.
    if scale_list.dtype not in (torch.half, torch.bfloat16):
        raise AssertionError(f"scale_list must be float16 or bfloat16, got {scale_list.dtype}")
    if weight.dtype != torch.int8:
        raise AssertionError(f"weight must be int8, got {weight.dtype}")

    if source_bit_width == 8:
        # Plain tensor path: no custom kernel involved.
        return weight.to(scale_list.dtype) * scale_list[:, None]
    if source_bit_width != 4:
        raise AssertionError("Unsupported bit-width")

    if scale_list.dtype == torch.half:
        func = kernels.int4WeightExtractionHalf
    else:
        func = kernels.int4WeightExtractionBFloat16

    with torch.cuda.device(weight.device):
        n, m = weight.size(0), weight.size(1)
        # Each int8 column expands into 8 // source_bit_width output columns.
        fields_per_int8 = 8 // source_bit_width
        out = torch.empty(n, m * fields_per_int8, dtype=scale_list.dtype, device="cuda")
        stream = torch.cuda.current_stream()

        # One block per row; block width comes from round_up(m, 32), capped at 1024.
        gridDim = (n, 1, 1)
        blockDim = (min(round_up(m, 32), 1024), 1, 1)

        func(
            gridDim,
            blockDim,
            0,
            stream,
            [
                ctypes.c_void_p(weight.data_ptr()),
                ctypes.c_void_p(scale_list.data_ptr()),
                ctypes.c_void_p(out.data_ptr()),
                ctypes.c_int32(n),
                ctypes.c_int32(m),
            ],
        )
        return out
class QuantizedLinear(torch.nn.Module):
    """Linear layer whose weight is stored as int8 data plus a per-row scale.

    Registered parameters (all with ``requires_grad=False``):
        weight: int8 tensor; for 4-bit widths it is the output of
            ``compress_int4_weight``.
        weight_scale: one scale per weight row.
        bias: optional; the attribute is ``None`` when no bias is given.
    """

    def __init__(self, weight_bit_width: int, weight, bias=None, device="cpu", dtype=None, empty_init=False, *args,
                 **kwargs):
        """Quantize ``weight`` or allocate uninitialized storage of the same layout.

        Args:
            weight_bit_width: Number of bits per quantized value.
            weight: Source weight. Its ``shape`` is read unconditionally, so it
                must be a tensor even when ``empty_init`` is true.
            bias: Optional bias tensor, moved to ``device``.
            device: Device the parameters end up on.
            dtype: Dtype of ``weight_scale`` when ``empty_init`` is true.
            empty_init: Skip quantization and only allocate storage.
            *args, **kwargs: Accepted and ignored.
        """
        super().__init__()
        self.weight_bit_width = weight_bit_width

        shape = weight.shape
        if weight is None or empty_init:
            # Storage only: the column count shrinks by weight_bit_width / 8.
            quantized = torch.empty(shape[0], shape[1] * weight_bit_width // 8, dtype=torch.int8, device=device)
            scale = torch.empty(shape[0], dtype=dtype, device=device)
        else:
            # Symmetric per-row quantization: each row's largest magnitude
            # maps to the largest positive level of the signed range.
            max_level = (2 ** (weight_bit_width - 1)) - 1
            scale = weight.abs().max(dim=-1).values / max_level
            quantized = torch.round(weight / scale[:, None]).to(torch.int8)
            if weight_bit_width == 4:
                quantized = compress_int4_weight(quantized)

        self.weight = Parameter(quantized.to(device), requires_grad=False)
        self.weight_scale = Parameter(scale.to(device), requires_grad=False)
        if bias is not None:
            self.bias = Parameter(bias.to(device), requires_grad=False)
        else:
            self.bias = None

    def forward(self, input):
        """Apply the quantized linear map to ``input`` and add the bias if present."""
        result = W8A16Linear.apply(input, self.weight, self.weight_scale, self.weight_bit_width)
        if self.bias is None:
            return result
        return result + self.bias
def quantize(model, weight_bit_width, empty_init=False, device=None):
    """Replace the four linear layers of every transformer layer with QuantizedLinear.

    For each entry of ``model.layers`` the attributes
    ``self_attention.query_key_value``, ``self_attention.dense``,
    ``mlp.dense_h_to_4h`` and ``mlp.dense_4h_to_h`` are replaced in place.

    Args:
        model: Object exposing an iterable ``layers`` attribute.
        weight_bit_width: Bit width passed to ``QuantizedLinear``.
        empty_init: Passed to ``QuantizedLinear``; when true only storage is
            allocated and the existing weights are not quantized.
        device: Target device for the new layers; defaults to the device of
            the weight being replaced.

    Returns:
        The same ``model`` object, modified in place.
    """
    # (owner attribute on the layer, linear attribute on the owner), in the
    # order in which the replacements are performed.
    targets = (
        ("self_attention", "query_key_value"),
        ("self_attention", "dense"),
        ("mlp", "dense_h_to_4h"),
        ("mlp", "dense_4h_to_h"),
    )
    for layer in model.layers:
        for owner_name, linear_name in targets:
            owner = getattr(layer, owner_name)
            linear = getattr(owner, linear_name)
            replacement = QuantizedLinear(
                weight_bit_width=weight_bit_width,
                # The source weight is moved to the current CUDA device
                # before it is handed to QuantizedLinear.
                weight=linear.weight.to(torch.cuda.current_device()),
                bias=linear.bias,
                dtype=linear.weight.dtype,
                device=linear.weight.device if device is None else device,
                empty_init=empty_init,
            )
            setattr(owner, linear_name, replacement)
    return model