BigMaoGoGoGo committed on
Commit
8fe54fa
·
1 Parent(s): eb55ff0

add gptq quantization

Browse files
Files changed (3) hide show
  1. gptq_quantization.py +332 -0
  2. modeling_chatglm.py +18 -5
  3. quantization.py +17 -3
gptq_quantization.py ADDED
@@ -0,0 +1,332 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import contextlib
3
+ import logging
4
+ import math
5
+ from typing import List, Optional
6
+
7
+ import torch
8
+ import transformers
9
+ from torch import nn
10
+
11
+ LOGGER = logging.getLogger(__name__)
12
+
13
+ QUANT_LAYERS = [nn.Linear, nn.Conv2d, transformers.Conv1D]
14
+
15
def is_transformer_conv1d(layer):
    """Return True when ``layer`` is an instance of ``transformers.Conv1D``."""
    conv1d_type = transformers.Conv1D
    return isinstance(layer, conv1d_type)
17
+
18
+
19
+ # These two functions only work on per-channel symmetric quantization for weight
20
def get_weight_scale(weight, weight_bit_width):
    """Return one symmetric quantization scale per slice along the last dim.

    The scale is the largest absolute value found along the last dimension
    divided by the largest positive level of a signed ``weight_bit_width``-bit
    integer, converted to half precision.
    """
    max_level = (2 ** (weight_bit_width - 1)) - 1
    abs_max = weight.abs().max(dim=-1).values
    scale = abs_max / max_level
    return scale.half()
23
+
24
def fake_quantize_weight(weight, weight_scale):
    """Round ``weight`` to the nearest multiple of its per-row scale.

    ``weight_scale`` holds one entry per row of ``weight``; it is broadcast
    across the columns. The result stays in floating point (the rounded
    level multiplied back by the scale).
    """
    per_row_scale = weight_scale[:, None]
    levels = torch.round(weight / per_row_scale)
    return levels * per_row_scale
28
+
29
+
30
class GPTQLayerWrapper:
    """Accumulate input statistics for one layer and quantize its weight.

    ``record_h`` folds the inputs seen by the wrapped layer into the running
    matrix ``H``. ``quant_weight`` then quantizes the layer weight one column
    at a time, spreading each column's rounding error over the remaining
    columns using the Cholesky factor of the inverse of ``H``.

    Only per-row symmetric quantization (one scale per row of the 2-D weight
    view) is implemented; group quantization is rejected.
    """

    def __init__(self, layer_name, layer, weight_bit_width):
        super().__init__()
        self.layer_name = layer_name
        self.layer = layer
        self.device = layer.weight.device
        # H is indexed by the columns of the 2-D weight view that
        # quant_weight works on, so its size must come from that same view.
        # Reading layer.weight.shape[1] directly drops the kernel dimensions
        # of a Conv2d and picks the wrong axis for transformers.Conv1D, whose
        # weight is transposed before quantization.
        columns = self._weight_as_matrix(layer.weight).shape[1]
        self.columns = columns
        self.H = torch.zeros((columns, columns), device=self.device)
        self.nsamples = 0
        self.is_record = True
        self.weight_bit_width = weight_bit_width
        self.weight_scale = None

    def _weight_as_matrix(self, weight):
        """Return the 2-D view of ``weight`` used for quantization."""
        if isinstance(self.layer, nn.Conv2d):
            weight = weight.flatten(1)
        if is_transformer_conv1d(self.layer):
            weight = weight.t()
        return weight

    def record_h(self, x):
        """Fold one batch of layer inputs ``x`` into the running matrix ``H``.

        Does nothing while ``is_record`` is False.
        """
        if self.is_record:
            x = x.detach().clone()
            if len(x.shape) == 2:
                x = x.unsqueeze(0)
            batch = x.shape[0]
            if isinstance(self.layer, nn.Linear) or is_transformer_conv1d(self.layer):
                if len(x.shape) == 3:
                    x = x.reshape((-1, x.shape[-1]))
                x = x.t()

            if isinstance(self.layer, nn.Conv2d):
                # Unfold the input into patches so that each column matches
                # one column of the flattened convolution weight.
                unfold = nn.Unfold(
                    self.layer.kernel_size,
                    dilation=self.layer.dilation,
                    padding=self.layer.padding,
                    stride=self.layer.stride
                )
                x = unfold(x)
                x = x.permute([1, 0, 2])
                x = x.flatten(1)

            # Running average: rescale the old sum before adding the new
            # batch so that H stays normalised by the total sample count.
            self.H *= self.nsamples / (self.nsamples + batch)
            self.nsamples += batch
            x = math.sqrt(2 / self.nsamples) * x.float()
            self.H += x.matmul(x.t())

    def quant_weight(self, blocksize=128, percdamp=.01, groupsize=-1):
        """Quantize the wrapped layer's weight in place.

        Args:
            blocksize: number of columns processed together before the
                accumulated error is pushed to the remaining columns.
            percdamp: fraction of the mean diagonal of ``H`` added to the
                diagonal before factorisation.
            groupsize: must be ``-1``; group quantization is not supported.

        Raises:
            RuntimeError: if ``groupsize`` is not ``-1``.

        On success the layer weight is replaced by the fake-quantized weight
        (original dtype, ``requires_grad=False``), ``weight_scale`` holds the
        per-row scale and ``H`` is deleted. If ``H`` cannot be factorised the
        weight is left untouched and a warning is logged.
        """
        if groupsize != -1:
            raise RuntimeError("Group quantization of gptq quantizer is not supported for now")
        weight = self._weight_as_matrix(self.layer.weight.data.clone()).float()

        weight_scale = get_weight_scale(weight, self.weight_bit_width)
        # todo: use buffer to store scale
        self.weight_scale = weight_scale
        scale = weight_scale.float()
        max_level = 2 ** (self.weight_bit_width - 1) - 1

        H = self.H
        # Columns that never received any input carry no information; pin
        # their diagonal to 1 and zero the matching weights.
        dead = torch.diag(H) == 0
        H[dead, dead] = 1
        weight[:, dead] = 0

        Q = torch.zeros_like(weight)

        damp = percdamp * torch.mean(torch.diag(H))
        diag = torch.arange(self.columns, device=self.device)
        H[diag, diag] += damp
        try:
            H = torch.linalg.cholesky(H)
            H = torch.cholesky_inverse(H)
            H = torch.linalg.cholesky(H, upper=True)
        except RuntimeError:
            LOGGER.warning(f"Warning: cannot do compression on layer {self.layer_name} because of inverse error")
            return

        if H.isnan().any():
            LOGGER.warning(f"Warning: cannot do compression on layer {self.layer_name} because of inverse error")
            return

        hinv = H

        for i1 in range(0, self.columns, blocksize):
            i2 = min(i1 + blocksize, self.columns)
            count = i2 - i1

            w1 = weight[:, i1:i2].clone()
            q1 = torch.zeros_like(w1)
            total_err = torch.zeros_like(w1)
            hinv1 = hinv[i1:i2, i1:i2]

            for i in range(count):
                w = w1[:, i]
                d = hinv1[i, i]

                # Error compensation can push a weight beyond the range the
                # scale was computed for; clamp the integer level so it still
                # fits in weight_bit_width bits when it is packed later.
                q = torch.round(w / scale).clamp(-max_level, max_level) * scale

                q1[:, i] = q
                err = (w - q) / d
                w1[:, i:] -= err.unsqueeze(1).matmul(hinv1[i, i:].unsqueeze(0))
                total_err[:, i] = err

            Q[:, i1:i2] = q1

            weight[:, i2:] -= total_err.matmul(hinv[i1:i2, i2:])

        if torch.cuda.is_available():
            torch.cuda.synchronize()

        if is_transformer_conv1d(self.layer):
            Q = Q.t()
        self.layer.weight = nn.Parameter(Q.reshape(self.layer.weight.shape).to(self.layer.weight.data.dtype), requires_grad=False)

        del self.H
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def release_gpu_memory(self):
        """Drop the matrix ``H`` if it is still held."""
        if hasattr(self, "H"):
            del self.H
150
+
151
+
152
class GPTQBlockWrapper:
    """Group the GPTQ layer wrappers of every quantizable layer in a block.

    On construction each sub-module of ``module`` whose type is listed in
    ``QUANT_LAYERS`` gets a ``GPTQLayerWrapper`` and a forward pre-hook that
    feeds the layer's first positional input to that wrapper.
    """

    def __init__(self, module_name: str, module: nn.Module, weight_bit_width=8):
        self.layer_wrappers = {}
        self.hook_handles = []
        # module order in the whole network
        self.order = 0
        self.module_name = module_name

        quantizable_types = tuple(QUANT_LAYERS)
        for sub_name, sub_layer in module.named_modules():
            if not isinstance(sub_layer, quantizable_types):
                continue
            # The block itself is reported with an empty name.
            qualified_name = f"{module_name}.{sub_name}" if sub_name else f"{module_name}"
            self.layer_wrappers[qualified_name] = GPTQLayerWrapper(qualified_name, sub_layer, weight_bit_width)
            hook = self._make_record_hook(qualified_name)
            self.hook_handles.append(sub_layer.register_forward_pre_hook(hook))

    def _make_record_hook(self, layer_name):
        """Build a forward pre-hook that records inputs for ``layer_name``."""
        def record_hook(_, x):
            self.layer_wrappers[layer_name].record_h(x[0])
        return record_hook

    def quant_module(self):
        """Quantize every wrapped layer, then remove the recording hooks."""
        for layer_wrapper in self.layer_wrappers.values():
            layer_wrapper.quant_weight()

        for handle in self.hook_handles:
            handle.remove()

    def set_order(self, idx):
        """Store the position of this block in the network."""
        self.order = idx

    def get_order(self):
        """Return the stored position of this block in the network."""
        return self.order

    def enable(self):
        """Turn input recording on for every wrapped layer."""
        for layer_wrapper in self.layer_wrappers.values():
            layer_wrapper.is_record = True

    def disable(self):
        """Turn input recording off for every wrapped layer."""
        for layer_wrapper in self.layer_wrappers.values():
            layer_wrapper.is_record = False

    def release_gpu_memory(self):
        """Ask every wrapped layer to drop its recorded statistics."""
        for layer_wrapper in self.layer_wrappers.values():
            layer_wrapper.release_gpu_memory()
196
+
197
+
198
class GPTQuantizer:
    """Drive block-by-block GPTQ calibration and quantization of a model.

    Typical use: ``wrap_model`` to attach wrappers, ``record_order`` around
    one forward pass to learn the execution order of the blocks, then one
    ``start_calib_iter(i)`` per block around the calibration forward passes.
    """

    def __init__(self, block_type: Optional[List[type]] = None):
        # Maps the dotted path of each wrapped block to its GPTQBlockWrapper.
        self.gptq_block_wrappers = {}
        self.block_type = block_type

    def wrap_model(self, model: nn.Module, weight_bit_width=8):
        """Wrap every sub-module of ``model`` whose type is in ``block_type``.

        The module tree is walked top-down; a matching module is wrapped as
        one block and not descended into. Returns ``model`` itself.
        """
        # block_type defaults to None; treat that as "no block types" rather
        # than failing in tuple(None).
        block_types = tuple(self.block_type or ())

        def wrap_block(m, prefix=""):
            for name, child in m.named_children():
                child_prefix = f"{prefix}.{name}" if prefix else name
                if isinstance(child, block_types):
                    # Key by the full dotted path: the local name alone
                    # (e.g. "0") is not unique across different parents and
                    # would silently overwrite an earlier wrapper.
                    self.gptq_block_wrappers[child_prefix] = GPTQBlockWrapper(child_prefix, child, weight_bit_width)
                    LOGGER.debug(f"Calibrate module {child_prefix} as a whole block in GPTQ")
                else:
                    wrap_block(child, child_prefix)

        wrap_block(model)
        return model

    def quantize(self, model: nn.Module):
        """Quantize every wrapped block and return ``model``."""
        for module_wrapper in self.gptq_block_wrappers.values():
            module_wrapper.quant_module()

        return model

    @property
    def calibration_iters(self):
        """Number of wrapped blocks, i.e. calibration iterations to run."""
        return len(self.gptq_block_wrappers)

    @contextlib.contextmanager
    def record_order(self):
        """Context manager that records the execution order of the blocks.

        While active, input recording is switched off and a forward pre-hook
        on the first wrapped layer of each block notes the order in which
        blocks are first reached. On exit the orders are stored on the block
        wrappers, the hooks are removed and recording is switched back on.

        An exception raised inside the ``with`` body is logged as a warning
        and suppressed.
        """
        counter = 0
        record_handles = []
        orders = {}
        try:
            def get_record_order_hook(module_name):
                def record_hook(*args, **kwargs):
                    nonlocal counter
                    if module_name not in orders:
                        orders[module_name] = counter
                        counter += 1
                return record_hook

            for module_name, module_wrapper in self.gptq_block_wrappers.items():
                # disable the record
                for layer_wrapper in module_wrapper.layer_wrappers.values():
                    layer_wrapper.is_record = False

                # A block without any quantizable layer has nothing to hook.
                first_layer_wrapper = next(iter(module_wrapper.layer_wrappers.values()), None)
                if first_layer_wrapper is None:
                    continue
                handle = first_layer_wrapper.layer.register_forward_pre_hook(get_record_order_hook(module_name))
                record_handles.append(handle)
            yield
        except Exception as e:
            LOGGER.warning(e)
        finally:
            for module_name, order in orders.items():
                self.gptq_block_wrappers[module_name].set_order(order)

            for h in record_handles:
                h.remove()

            for module_wrapper in self.gptq_block_wrappers.values():
                # re-enable the record
                for layer_wrapper in module_wrapper.layer_wrappers.values():
                    layer_wrapper.is_record = True

    @contextlib.contextmanager
    def start_calib_iter(self, i):
        """Context manager for the ``i``-th calibration iteration.

        Enables recording on the block(s) whose order equals ``i`` and
        disables it on all others. On exit, including when the body raised,
        the enabled block(s) are quantized.

        Raises:
            ValueError: if ``i`` is not smaller than the number of blocks.
        """
        if i >= len(self.gptq_block_wrappers):
            raise ValueError(
                f"Calibration iteration {i} is out of range for "
                f"{len(self.gptq_block_wrappers)} wrapped blocks")
        # Several blocks can share an order (e.g. when order recording was
        # incomplete) and there may be none at all; quantize exactly the
        # blocks that were recording instead of only the last match.
        target_module_wrappers = []
        try:
            for module_wrapper in self.gptq_block_wrappers.values():
                if module_wrapper.get_order() == i:
                    module_wrapper.enable()
                    target_module_wrappers.append(module_wrapper)
                else:
                    module_wrapper.disable()
            yield
        finally:
            for module_wrapper in target_module_wrappers:
                module_wrapper.quant_module()

    def release_gpu_memory(self):
        """Release the statistics held by every block and empty the CUDA cache."""
        for block_wrapper in self.gptq_block_wrappers.values():
            block_wrapper.release_gpu_memory()

        if torch.cuda.is_available():
            torch.cuda.empty_cache()
285
+
286
+
287
def locate_parent(root: nn.Module, full_path: str):
    """Resolve a dotted attribute path to ``(parent object, last name)``.

    Every component except the last is looked up with ``getattr`` starting
    from ``root``; the last component is returned unresolved.
    """
    parent_path, sep, leaf = full_path.rpartition('.')
    node = root
    if sep:
        for attr in parent_path.split('.'):
            node = getattr(node, attr)
    return node, leaf
293
+
294
+
295
@torch.no_grad()
def gptq_quantize(model, tokenizer, weight_bit_width, calib_data):
    """Quantize ``model`` in place with GPTQ using ``calib_data`` prompts.

    Every ``GLMBlock`` is wrapped, the block execution order is recorded with
    one ``model.chat`` call on the first prompt, and each block is then
    calibrated and quantized in turn by running ``model.chat`` over all
    prompts. Finally each wrapped layer is replaced in its parent module by a
    ``QuantizedLinear`` built from the quantized weight and the scale that
    was used.

    Args:
        model: model exposing ``chat(tokenizer, prompt, history=[])``.
        tokenizer: tokenizer forwarded to ``model.chat``.
        weight_bit_width: bit width forwarded to the wrappers and to
            ``QuantizedLinear``.
        calib_data: non-empty sequence of prompts.

    Raises:
        ValueError: if ``calib_data`` is ``None`` or empty.
    """
    # Fail early with a clear message instead of an IndexError on
    # calib_data[0] after the model has already been hooked.
    if calib_data is None or len(calib_data) == 0:
        raise ValueError("GPTQ calibration requires at least one prompt in calib_data")

    from .modeling_chatglm import GLMBlock
    from .quantization import QuantizedLinear

    quantizer = GPTQuantizer([GLMBlock])
    calib_model = quantizer.wrap_model(model, weight_bit_width)
    with quantizer.record_order():
        calib_model.chat(tokenizer, calib_data[0], history=[])
    LOGGER.info("Start doing calibration using GPTQ ")
    total_iters = quantizer.calibration_iters
    for i in range(total_iters):
        LOGGER.info(f"Process: {i + 1}/{total_iters}")
        # todo: should add early return to speed up the calibration
        with quantizer.start_calib_iter(i):
            for prompt in calib_data:
                model.chat(tokenizer, prompt, history=[])

    # replace the fp16 linear with quantized linear
    for block_wrapper in quantizer.gptq_block_wrappers.values():
        for layer_name, layer_wrapper in block_wrapper.layer_wrappers.items():
            layer = layer_wrapper.layer
            parent, name_in_parent = locate_parent(model, layer_name)
            quantized_layer = QuantizedLinear(
                weight_bit_width=weight_bit_width,
                weight_tensor=layer.weight,
                bias_tensor=layer.bias,
                weight_scale=layer_wrapper.weight_scale,
                in_features=layer.in_features,
                out_features=layer.out_features,
                bias=True,
                dtype=torch.half,
                device=layer_wrapper.device,
                empty_init=False
            )
            parent.add_module(name_in_parent, quantized_layer)

    # Layers whose quantization was skipped still hold their statistics;
    # release them (this also empties the CUDA cache).
    quantizer.release_gpu_memory()
    return
modeling_chatglm.py CHANGED
@@ -1408,12 +1408,14 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
1408
  break
1409
  yield input_ids
1410
 
1411
- def quantize(self, bits: int, empty_init=False, **kwargs):
 
 
1412
  if bits == 0:
1413
  return
1414
 
1415
- from .quantization import quantize
1416
-
1417
  if self.quantized:
1418
  logger.info("Already quantized.")
1419
  return self
@@ -1421,6 +1423,17 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
1421
  self.quantized = True
1422
 
1423
  self.config.quantization_bit = bits
1424
-
1425
- self.transformer = quantize(self.transformer, bits, empty_init=empty_init, **kwargs)
 
 
 
 
 
 
 
 
 
 
 
1426
  return self
 
1408
  break
1409
  yield input_ids
1410
 
1411
+ def quantize(
1412
+ self, bits: int, empty_init=False, quant_algo_type: str="min_max",
1413
+ calib_data: Optional[List[str]]=None, tokenizer=None, **kwargs):
1414
  if bits == 0:
1415
  return
1416
 
1417
+ from .quantization import quantize, QuantAlgoType
1418
+ from .gptq_quantization import gptq_quantize
1419
  if self.quantized:
1420
  logger.info("Already quantized.")
1421
  return self
 
1423
  self.quantized = True
1424
 
1425
  self.config.quantization_bit = bits
1426
+ quant_algo_type = QuantAlgoType(quant_algo_type)
1427
+ if quant_algo_type == QuantAlgoType.min_max:
1428
+ self.transformer = quantize(
1429
+ self.transformer, bits, empty_init=empty_init, algo_type=quant_algo_type, calib_data=calib_data, tokenizer=tokenizer, **kwargs)
1430
+ elif quant_algo_type == QuantAlgoType.gptq:
1431
+ if calib_data is None or tokenizer is None:
1432
+ raise RuntimeError("If using gptq to quantize the model, "
1433
+ "calibration data (e.g. some string prompts) and tokenizer should be provided")
1434
+ gptq_quantize(
1435
+ self, tokenizer, bits, calib_data
1436
+ )
1437
+ else:
1438
+ raise RuntimeError("Unsupported quantization algorithm type")
1439
  return self
quantization.py CHANGED
@@ -8,7 +8,7 @@ import ctypes
8
  from transformers.utils import logging
9
 
10
  from typing import List
11
- from functools import partial
12
 
13
  logger = logging.get_logger(__name__)
14
 
@@ -41,6 +41,17 @@ except Exception as exception:
41
  logger.warning("Failed to load cpm_kernels:" + str(exception))
42
 
43
 
 
 
 
 
 
 
 
 
 
 
 
44
  class W8A16Linear(torch.autograd.Function):
45
  @staticmethod
46
  def forward(ctx, inp: torch.Tensor, quant_w: torch.Tensor, scale_w: torch.Tensor, weight_bit_width):
@@ -118,7 +129,7 @@ def extract_weight_to_half(weight: torch.Tensor, scale_list: torch.Tensor, sourc
118
 
119
 
120
  class QuantizedLinear(Linear):
121
- def __init__(self, weight_bit_width: int, weight_tensor=None, bias_tensor=None, empty_init=False, *args, **kwargs):
122
  super(QuantizedLinear, self).__init__(*args, **kwargs)
123
  self.weight_bit_width = weight_bit_width
124
 
@@ -131,7 +142,10 @@ class QuantizedLinear(Linear):
131
  )
132
  self.weight_scale = torch.empty(shape[0], dtype=kwargs["dtype"], device=kwargs["device"])
133
  else:
134
- self.weight_scale = (weight_tensor.abs().max(dim=-1).values / ((2 ** (weight_bit_width - 1)) - 1)).half()
 
 
 
135
  self.weight = torch.round(weight_tensor / self.weight_scale[:, None]).to(torch.int8)
136
  if weight_bit_width == 4:
137
  self.weight = compress_int4_weight(self.weight)
 
8
  from transformers.utils import logging
9
 
10
  from typing import List
11
+ from enum import Enum
12
 
13
  logger = logging.get_logger(__name__)
14
 
 
41
  logger.warning("Failed to load cpm_kernels:" + str(exception))
42
 
43
 
44
class QuantAlgoType(Enum):
    """Names of the supported weight quantization algorithms."""

    min_max = 'min_max'
    gptq = 'gptq'

    @classmethod
    def _missing_(cls, value):
        """Reject a value that matches no member, listing the valid ones."""
        supported_types = [member.value for member in cls]
        message = (f"Unsupported quantization algorithm type. Support list: "
                   f"{supported_types}. Got: '{value}'")
        raise ValueError(message)
53
+
54
+
55
  class W8A16Linear(torch.autograd.Function):
56
  @staticmethod
57
  def forward(ctx, inp: torch.Tensor, quant_w: torch.Tensor, scale_w: torch.Tensor, weight_bit_width):
 
129
 
130
 
131
  class QuantizedLinear(Linear):
132
+ def __init__(self, weight_bit_width: int, weight_tensor=None, bias_tensor=None, weight_scale=None, empty_init=False, *args, **kwargs):
133
  super(QuantizedLinear, self).__init__(*args, **kwargs)
134
  self.weight_bit_width = weight_bit_width
135
 
 
142
  )
143
  self.weight_scale = torch.empty(shape[0], dtype=kwargs["dtype"], device=kwargs["device"])
144
  else:
145
+ if weight_scale is None:
146
+ self.weight_scale = (weight_tensor.abs().max(dim=-1).values / ((2 ** (weight_bit_width - 1)) - 1)).half()
147
+ else:
148
+ self.weight_scale = weight_scale
149
  self.weight = torch.round(weight_tensor / self.weight_scale[:, None]).to(torch.int8)
150
  if weight_bit_width == 4:
151
  self.weight = compress_int4_weight(self.weight)