File size: 12,737 Bytes
29964ce |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 |
import math
import numpy as np
import torch
import transformers
from auto_gptq.nn_modules.triton_utils.mixin import TritonModuleMixin
from torch import nn
def weight_quant(weight, num_bits=1):
    """Fake-quantise ``weight`` to ternary values using a per-tensor absmean scale.

    Every element is scaled by ``s = 1 / mean(|weight|)`` (the mean clamped to at
    least 1e-5), rounded, clamped to {-1, 0, +1}, then divided by ``s`` again.
    The computation happens in float32 and the result is cast back to
    ``weight.dtype``. ``num_bits`` is accepted but not used.
    """
    original_dtype = weight.dtype
    w32 = weight.float()
    mean_abs = w32.abs().mean().clamp(min=1e-5)
    inv_scale = 1 / mean_abs
    ternary = torch.clamp(torch.round(w32 * inv_scale), -1, 1)
    return (ternary / inv_scale).to(original_dtype)
def weight_quant2(weight, num_bits=1):
    """Ternarise ``weight`` and return the codes and their scale separately.

    Uses the same per-tensor absmean scheme as ``weight_quant``:
    ``s = 1 / mean(|weight|)`` (mean clamped to at least 1e-5). Returns

    * the rounded codes in {-1, 0, +1}, cast to ``weight.dtype``;
    * the scale ``1 / s`` as a float32 0-dim tensor.

    ``num_bits`` is accepted but not used.
    """
    in_dtype = weight.dtype
    w32 = weight.float()
    gamma = w32.abs().mean().clamp(min=1e-5)
    s = 1 / gamma
    codes = torch.round(w32 * s).clamp(-1, 1)
    return codes.to(in_dtype), 1 / s
def activation_quant(x, num_bits=8):
    """Fake-quantise ``x`` to signed ``num_bits`` integers with a per-row absmax scale.

    Each slice along the last dimension gets its own scale
    ``(2**(num_bits-1) - 1) / max(|x|)`` (max clamped to at least 1e-5). Values are
    rounded, clamped to ``[-2**(num_bits-1), 2**(num_bits-1) - 1]`` and rescaled.
    The work is done in float32 and the result is cast back to ``x.dtype``.
    """
    in_dtype = x.dtype
    xf = x.float()
    half_range = 2 ** (num_bits - 1)
    lower, upper = -half_range, half_range - 1
    row_absmax = xf.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5)
    scale = upper / row_absmax
    q = torch.clamp(torch.round(xf * scale), lower, upper)
    return (q / scale).to(in_dtype)
def optimized_linear(quant_input, weight, scale):
    """Multiply ``quant_input`` by a ternary ``weight`` and apply ``scale``.

    Weight entries equal to +1 add the corresponding input and entries equal to
    -1 subtract it. Every other value, including 0, contributes nothing.
    ``weight`` is the right-hand matmul operand (``quant_input @ weight``).
    NOTE(review): presumably shaped (in_features, out_features), unlike
    ``nn.Linear.weight``; confirm with callers.

    NOTE(review): this still runs two dense matmuls, so it does not save work
    compared with a single ``quant_input @ weight``.
    """
    as_input_dtype = quant_input.dtype
    # Exploit the ternary structure: add where the weight is +1, subtract where it is -1.
    plus_part = quant_input @ (weight == 1).to(as_input_dtype)
    minus_part = quant_input @ (weight == -1).to(as_input_dtype)
    result = plus_part - minus_part
    # Apply the scale in place on the freshly allocated result.
    result *= scale
    return result
class BitLinear(nn.Linear):
    """``nn.Linear`` with ternarised weights and fake-quantised inputs.

    RMSNorm is placed outside BitLinear.

    On the first forward pass (or an explicit :meth:`quantize` call), ``weight``
    is overwritten in place by its ternary codes {-1, 0, +1} from
    ``weight_quant2``. The matching per-tensor scale is kept in ``quant_scale``.
    Inputs are fake-quantised to ``input_bits`` with a straight-through estimator.
    """

    def __init__(self, *kargs, weight_bits=1, input_bits=8, **kwargs):
        super().__init__(*kargs, **kwargs)
        # Passed to weight_quant2 in quantize().
        self.weight_bits = weight_bits
        # Passed to activation_quant in forward().
        self.input_bits = input_bits
        self.quant_initialized = False
        self.quant_scale = None

    def quantize(self):
        """Ternarise ``self.weight`` in place (only once) and store its scale.

        Runs without autograd tracking. Otherwise ``quant_scale`` would keep a
        graph back to ``self.weight``; the first ``backward()`` frees that graph,
        and every later training step fails with "Trying to backward through
        the graph a second time".

        NOTE(review): ``quant_initialized`` and ``quant_scale`` are plain
        attributes, not buffers, so they are not saved in ``state_dict()``. A
        fresh module loaded from a checkpoint taken after quantisation will
        re-quantise the already-ternary weights on its first forward. The scale
        is then recomputed from the ternary codes and the original one is lost.
        Confirm whether callers rely on this.
        """
        if self.quant_initialized:
            return
        with torch.no_grad():
            quant_weight, quant_scale = weight_quant2(self.weight, self.weight_bits)
            # Store the exact ternary codes. A straight-through expression
            # ``w + (q - w).detach()`` has no effect once written to ``.data``.
            self.weight.data = quant_weight
        self.quant_scale = quant_scale
        self.quant_initialized = True

    def forward(self, input):
        """Return ``linear(fake_quant(input), ternary_weight) * quant_scale (+ bias)``."""
        if not self.quant_initialized:
            self.quantize()
        # Straight-through estimator: the forward pass sees quantised
        # activations, and the backward pass treats quantisation as identity.
        quant_input = input + (activation_quant(input, self.input_bits) - input).detach()
        out = nn.functional.linear(quant_input, self.weight) * self.quant_scale
        if self.bias is not None:
            out += self.bias.view(1, -1).expand_as(out)
        return out
# Original code from https://github.com/AutoGPTQ/AutoGPTQ
# MIT License
try:
    from auto_gptq.nn_modules.triton_utils.kernels import (
        QuantLinearFunction,
        QuantLinearInferenceOnlyFunction,
        quant_matmul_248,
        quant_matmul_inference_only_248,
        transpose_quant_matmul_248,
    )
except ImportError as e:
    # Keep the failure so the placeholders below can report it when the triton
    # backend is actually used. ``e`` itself is unbound once this clause ends.
    triton_import_exception = e

    def error_raiser_triton(*args, **kwargs):
        """Stand-in for the triton matmul kernels; always raises ValueError."""
        raise ValueError(
            f"Trying to use the triton backend, but could not import triton dependencies with the following error: {triton_import_exception}"
        )

    class FakeTriton:
        """Stand-in for the triton autograd Functions; any attribute access raises ImportError."""

        def __getattr__(self, name):
            raise ImportError(
                f"Trying to use the triton backend, but could not import triton dependencies with the following error: {triton_import_exception}"
            )

    quant_matmul_248 = error_raiser_triton
    transpose_quant_matmul_248 = error_raiser_triton
    quant_matmul_inference_only_248 = error_raiser_triton
    # Bind instances, not the class. ``__getattr__`` is only consulted for
    # instance attribute lookup, so ``FakeTriton.apply`` would raise a bare
    # AttributeError instead of the informative ImportError above.
    QuantLinearFunction = FakeTriton()
    QuantLinearInferenceOnlyFunction = FakeTriton()
class QuantizedBitLinear(nn.Module, TritonModuleMixin):
    """Linear layer running a packed ternary BitLinear weight through AutoGPTQ triton kernels.

    :meth:`pack` converts a quantised :class:`BitLinear` into AutoGPTQ's
    ``qweight`` / ``qzeros`` / ``scales`` / ``g_idx`` buffer layout:

    * every group scale is 1 and every zero point is 1, so a ternary weight
      ``w`` is stored as the unsigned code ``w + 1``;
    * BitLinear's per-tensor scale is kept in the ``scale`` buffer and is
      multiplied into the kernel output in :meth:`forward`.
    """

    QUANT_TYPE = "triton"

    def __init__(
        self,
        infeatures,
        outfeatures,
        bias,
        weight_bits=1,
        input_bits=8,
        quant_bits=2,
        group_size=128,
        trainable=False,
        **kwargs,
    ):
        """Allocate zero-filled packed buffers for an ``infeatures -> outfeatures`` layer.

        Args:
            infeatures: input feature count; must be divisible by 32.
            outfeatures: output feature count; must be divisible by 32.
            bias: whether to allocate a float16 ``bias`` buffer.
            weight_bits: stored on the module; not otherwise used by this class.
            input_bits: bit width for activation fake-quantisation in ``forward``.
            quant_bits: storage bit width per packed weight code (2, 4 or 8).
            group_size: input features per scale/zero group; -1 means a single group.
            trainable: use ``QuantLinearFunction`` instead of
                ``QuantLinearInferenceOnlyFunction`` in ``forward``.
            **kwargs: accepted and ignored.

        Raises:
            NotImplementedError: unsupported ``quant_bits``, or feature counts
                not divisible by 32.
        """
        super().__init__()
        if quant_bits not in [2, 4, 8]:
            raise NotImplementedError("Only 2,4,8 bits are supported.")
        if infeatures % 32 != 0 or outfeatures % 32 != 0:
            raise NotImplementedError(
                "in_feature and out_feature must be divisible by 32."
            )
        self.infeatures = infeatures
        self.outfeatures = outfeatures
        self.weight_bits = weight_bits
        self.input_bits = input_bits
        self.quant_bits = quant_bits
        self.group_size = group_size if group_size != -1 else infeatures
        # Largest unsigned code representable in quant_bits bits.
        self.maxq = 2**self.quant_bits - 1
        num_groups = math.ceil(infeatures / self.group_size)
        # 32 // quant_bits codes share one int32 word along the input dimension.
        self.register_buffer(
            "qweight",
            torch.zeros(
                (infeatures // 32 * self.quant_bits, outfeatures), dtype=torch.int32
            ),
        )
        # Zero points packed along the output dimension, one row per group.
        self.register_buffer(
            "qzeros",
            torch.zeros(
                (num_groups, outfeatures // 32 * self.quant_bits),
                dtype=torch.int32,
            ),
        )
        self.register_buffer(
            "scales",
            torch.zeros((num_groups, outfeatures), dtype=torch.float16),
        )
        # Group index of every input feature.
        self.register_buffer(
            "g_idx",
            torch.tensor(
                [i // self.group_size for i in range(infeatures)], dtype=torch.int32
            ),
        )
        if bias:
            self.register_buffer(
                "bias", torch.zeros((outfeatures), dtype=torch.float16)
            )
        else:
            self.bias = None
        # BitLinear's per-tensor weight scale, filled in by pack().
        self.register_buffer("scale", torch.tensor(1.0, dtype=torch.float16))
        self.trainable = trainable

    def post_init(self):
        """No-op hook kept for interface compatibility.

        NOTE(review): presumably invoked by AutoGPTQ loading code; confirm.
        """
        pass

    @staticmethod
    def _pack_bits(values, bits, axis):
        """Pack groups of ``32 // bits`` consecutive codes along ``axis`` into uint32 words.

        Code ``k`` of each group occupies bits ``[bits * k, bits * (k + 1))`` of
        its word. ``values`` must already lie in ``[0, 2**bits - 1]`` and the
        length of ``axis`` must be a multiple of ``32 // bits``.
        """
        per_word = 32 // bits
        values = np.asarray(values, dtype=np.uint32)
        length = values.shape[axis]
        grouped = values.reshape(
            values.shape[:axis] + (length // per_word, per_word) + values.shape[axis + 1:]
        )
        shift_shape = [1] * grouped.ndim
        shift_shape[axis + 1] = per_word
        shifts = (np.arange(per_word, dtype=np.uint32) * np.uint32(bits)).reshape(shift_shape)
        packed = np.bitwise_or.reduce(grouped << shifts, axis=axis + 1)
        return np.ascontiguousarray(packed)

    def pack(self, bitlinear: BitLinear):
        """Pack the ternary weight of ``bitlinear`` into this module's buffers.

        Fills ``scale`` (BitLinear's per-tensor scale), ``scales`` (all ones),
        ``bias`` (when ``bitlinear`` has one), ``qweight`` and ``qzeros``. It then
        moves this module to ``bitlinear.weight``'s device.

        ``bitlinear`` is quantised first if it has not been yet. Its weights are
        copied to the CPU for packing; the module itself is not moved.

        Raises:
            NotImplementedError: ``quant_bits`` is not 2, 4 or 8.
        """
        if self.quant_bits not in (2, 4, 8):
            raise NotImplementedError("Only 2,4,8 bits are supported.")
        device = bitlinear.weight.device
        if not bitlinear.quant_initialized:
            # Same lazy step BitLinear.forward performs on first use.
            bitlinear.quantize()

        W = bitlinear.weight.detach().to("cpu").clone()
        if isinstance(bitlinear, nn.Conv2d):
            W = W.flatten(1)
        if isinstance(bitlinear, transformers.pytorch_utils.Conv1D):
            W = W.t()

        self.scale = (
            torch.as_tensor(bitlinear.quant_scale)
            .detach()
            .to(device="cpu", dtype=torch.float16)
            .clone()
        )

        # Identity group quantisation: scale 1 and zero point 1 everywhere.
        num_groups = math.ceil(self.infeatures / self.group_size)
        scales = torch.ones(num_groups, self.outfeatures)
        zeros = torch.ones(num_groups, self.outfeatures)
        scale_zeros = zeros * scales
        self.scales = scales.clone().half()
        if bitlinear.bias is not None:
            self.bias = (
                bitlinear.bias.detach().to(device="cpu", dtype=torch.float16).clone()
            )

        # Gather each input column's group zero/scale at once -> (outfeatures, infeatures).
        g_idx = self.g_idx.to(device="cpu", dtype=torch.long)
        col_zeros = scale_zeros[g_idx].t()
        col_scales = self.scales[g_idx].t()
        # Clamp so that an out-of-range code cannot spill into neighbouring
        # bit fields when the codes are OR-packed below.
        intweight = (
            torch.round((W + col_zeros) / col_scales).clamp(0, self.maxq).to(torch.int)
        )
        intweight = intweight.t().contiguous().numpy().astype(np.uint32)
        self.qweight = torch.from_numpy(
            self._pack_bits(intweight, self.quant_bits, axis=0).astype(np.int32)
        )

        # qzeros stores ``zero - 1`` per group, packed along the output dimension.
        stored_zeros = (zeros - 1).numpy().astype(np.uint32)
        self.qzeros = torch.from_numpy(
            self._pack_bits(stored_zeros, self.quant_bits, axis=1).astype(np.int32)
        )
        self.to(device)

    def forward(self, x):
        """Fake-quantise ``x``, run the packed triton matmul, then apply ``scale`` and ``bias``.

        Returns a float16 tensor of shape ``x.shape[:-1] + (outfeatures,)``.
        """
        # Straight-through estimator: quantised values forward, identity gradient.
        x = x + (activation_quant(x, self.input_bits) - x).detach()
        out_shape = x.shape[:-1] + (self.outfeatures,)
        quant_linear_fn = (
            QuantLinearFunction if self.trainable else QuantLinearInferenceOnlyFunction
        )
        out = quant_linear_fn.apply(
            x.reshape(-1, x.shape[-1]),
            self.qweight,
            self.scales,
            self.qzeros,
            self.g_idx,
            self.quant_bits,
            self.maxq,
        )
        # The kernel only sees unit group scales; re-apply BitLinear's per-tensor scale.
        out *= self.scale
        out = out.half().reshape(out_shape)
        out = out + self.bias if self.bias is not None else out
        return out

    @classmethod
    def warmup(cls, model, transpose=False, seqlen=2048):
        """Pre-tune the quantized kernel.

        For every distinct ``(infeatures, outfeatures)`` pair among the ``cls``
        modules in ``model``, runs the kernels on random float16 inputs with
        1, 2, 4, ... up to ``2 ** ceil(log2(seqlen))`` rows on ``model.device``.
        """
        from tqdm import tqdm

        kn_values = {}
        for module in model.modules():
            if not isinstance(module, cls):
                continue
            key = (module.infeatures, module.outfeatures)
            if key not in kn_values:
                kn_values[key] = (
                    module.qweight,
                    module.scales,
                    module.qzeros,
                    module.g_idx,
                    # The bit width lives in ``quant_bits``; there is no ``bits`` attribute.
                    module.quant_bits,
                    module.maxq,
                )
        with torch.no_grad():
            for exponent in tqdm(range(0, math.ceil(math.log2(seqlen)) + 1)):
                rows = 2**exponent
                for (k, n), (qweight, scales, qzeros, g_idx, bits, maxq) in kn_values.items():
                    if transpose:
                        a = torch.randn(rows, k, dtype=torch.float16, device=model.device)
                        quant_matmul_248(a, qweight, scales, qzeros, g_idx, bits, maxq)
                        a = torch.randn(rows, n, dtype=torch.float16, device=model.device)
                        transpose_quant_matmul_248(
                            a, qweight, scales, qzeros, g_idx, bits, maxq
                        )
                    else:
                        a = torch.randn(rows, k, dtype=torch.float16, device=model.device)
                        quant_matmul_inference_only_248(
                            a, qweight, scales, qzeros, g_idx, bits, maxq
                        )
|