File size: 11,582 Bytes
8fe54fa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ac792e3
 
 
 
8fe54fa
 
 
 
ac792e3
8fe54fa
 
ac792e3
8fe54fa
ac792e3
8fe54fa
 
 
 
 
 
ac792e3
8fe54fa
ac792e3
8fe54fa
 
 
 
ac792e3
8fe54fa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ac792e3
8fe54fa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ac792e3
8fe54fa
 
ac792e3
 
8fe54fa
 
 
ac792e3
8fe54fa
ac792e3
8fe54fa
 
ac792e3
 
8fe54fa
 
 
 
 
ac792e3
 
8fe54fa
 
 
 
ac792e3
8fe54fa
ac792e3
8fe54fa
 
 
 
 
 
ac792e3
8fe54fa
ac792e3
 
 
 
8fe54fa
ac792e3
8fe54fa
 
ac792e3
8fe54fa
ac792e3
 
 
 
 
 
8fe54fa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ac792e3
8fe54fa
 
 
 
ac792e3
8fe54fa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ac792e3
 
8fe54fa
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325

import contextlib
import logging
import math
from typing import List, Optional

import torch
import transformers
from torch import nn

LOGGER = logging.getLogger(__name__)

QUANT_LAYERS = [nn.Linear, nn.Conv2d, transformers.Conv1D]

def is_transformer_conv1d(layer):
    """Return True when ``layer`` is an instance of ``transformers.Conv1D``."""
    conv1d_cls = transformers.Conv1D
    return isinstance(layer, conv1d_cls)


# The two helpers below assume per-channel (one scale per weight row) symmetric weight quantization.
def get_weight_scale(weight, weight_bit_width):
    """Compute one symmetric quantization scale per row of ``weight``.

    Args:
        weight: tensor whose last dimension is reduced with an abs-max.
        weight_bit_width: number of bits of the signed integer grid.

    Returns:
        A half-precision tensor holding ``abs-max / (2 ** (bits - 1) - 1)``
        for every slice along the last dimension.
    """
    max_level = 2 ** (weight_bit_width - 1) - 1
    row_abs_max = weight.abs().max(dim=-1).values
    return (row_abs_max / max_level).half()

def fake_quantize_weight(weight, weight_scale):
    """Round ``weight`` onto the grid defined by ``weight_scale`` and scale it back.

    Args:
        weight: 2-D tensor of weights, one row per scale entry.
        weight_scale: 1-D tensor with one scale per row of ``weight``.

    Returns:
        A tensor shaped like ``weight`` whose entries are integer multiples of
        the corresponding row scale. No clamping is applied.
    """
    per_row_scale = weight_scale.unsqueeze(1)
    levels = torch.round(weight / per_row_scale)
    return levels * per_row_scale


class GPTQLayerWrapper:
    """Collect GPTQ calibration statistics for one layer and quantize its weight.

    ``record_h`` folds the inputs seen by the layer into the matrix ``H``;
    ``quant_weight`` then replaces the layer's weight with a fake-quantized
    copy in which the rounding error of each column has been propagated to the
    columns that are quantized after it.
    """

    def __init__(self, layer_name, layer, weight_bit_width):
        """Set up an empty statistics matrix for ``layer``.

        Args:
            layer_name: name used in log messages.
            layer: the module whose ``weight`` will be quantized.
            weight_bit_width: bit width passed to ``get_weight_scale``.
        """
        super().__init__()
        self.layer_name = layer_name
        self.layer = layer
        self.device = layer.weight.device
        # H is indexed by the columns of the 2-D matrix that quant_weight works
        # on, so the column count must follow the same flatten / transpose
        # rules: Conv2d weights are flattened from dim 1 on, Conv1D weights are
        # transposed, everything else is used as stored.
        if isinstance(layer, nn.Conv2d):
            columns = layer.weight[0].numel()
        elif is_transformer_conv1d(layer):
            columns = layer.weight.shape[0]
        else:
            columns = layer.weight.shape[1]
        self.columns = columns
        self.H = torch.zeros((columns, columns), device=self.device)
        self.nsamples = 0
        self.is_record = True
        self.weight_bit_width = weight_bit_width
        self.weight_scale = None

    def record_h(self, x):
        """Fold one batch of layer inputs ``x`` into ``self.H``.

        Does nothing while ``self.is_record`` is False. ``self.nsamples`` grows
        by the size of the first dimension of ``x`` (a 2-D input counts as a
        batch of one).
        """
        if not self.is_record:
            return

        # No copy is needed: every operation below produces a new tensor or a
        # view and never writes into ``x``.
        x = x.detach()
        if len(x.shape) == 2:
            x = x.unsqueeze(0)
        batch = x.shape[0]

        if isinstance(self.layer, nn.Conv2d):
            # Turn the image into one column per sliding window so that its
            # rows line up with the flattened convolution weight.
            unfold = nn.Unfold(
                self.layer.kernel_size,
                dilation=self.layer.dilation,
                padding=self.layer.padding,
                stride=self.layer.stride
            )
            x = unfold(x)
            x = x.permute([1, 0, 2])
            x = x.flatten(1)
        elif isinstance(self.layer, nn.Linear) or is_transformer_conv1d(self.layer):
            if len(x.shape) == 3:
                x = x.reshape((-1, x.shape[-1]))
            x = x.t()

        # Running update: rescale the old sum to the new sample count, then
        # add 2 / nsamples * x @ x.T for the new batch.
        self.H *= self.nsamples / (self.nsamples + batch)
        self.nsamples += batch
        x = math.sqrt(2 / self.nsamples) * x.float()
        self.H += x.matmul(x.t())

    def quant_weight(self, blocksize=128, percdamp=.01, groupsize=-1):
        """Quantize the wrapped layer's weight in place using the recorded ``H``.

        Args:
            blocksize: number of columns processed together before the
                accumulated error is pushed to the remaining columns.
            percdamp: fraction of the mean diagonal of ``H`` added to its
                diagonal before the factorization.
            groupsize: must be ``-1``; group quantization is not implemented.

        Raises:
            RuntimeError: if ``groupsize`` is not ``-1``.

        On success the layer's ``weight`` is replaced by a non-trainable
        parameter holding the fake-quantized values and ``self.H`` is deleted.
        If the factorization fails or yields NaNs, a warning is logged and the
        weight is left untouched; ``self.weight_scale`` is set in both cases.
        """
        if groupsize != -1:
            raise RuntimeError("Group quantization of gptq quantizer is not supported for now")
        weight = self.layer.weight.data.clone()
        if isinstance(self.layer, nn.Conv2d):
            weight = weight.flatten(1)
        if is_transformer_conv1d(self.layer):
            weight = weight.t()
        weight = weight.float()

        weight_scale = get_weight_scale(weight, self.weight_bit_width)
        # todo: use buffer to store scale
        self.weight_scale = weight_scale
        # Largest magnitude representable on the signed grid for each row.
        # Error compensation can push a weight beyond the range the scale was
        # derived from, so quantized values are clamped to this bound.
        q_limit = weight_scale.float() * (2 ** (self.weight_bit_width - 1) - 1)

        H = self.H
        # Columns that never received any input: make H invertible there and
        # drop the corresponding weights.
        dead = torch.diag(H) == 0
        H[dead, dead] = 1
        weight[:, dead] = 0

        Q = torch.zeros_like(weight)

        damp = percdamp * torch.mean(torch.diag(H))
        diag = torch.arange(self.columns, device=self.device)
        H[diag, diag] += damp
        try:
            # Upper Cholesky factor of the inverse of the damped H.
            H = torch.linalg.cholesky(H)
            H = torch.cholesky_inverse(H)
            H = torch.linalg.cholesky(H, upper=True)
        except Exception:
            LOGGER.warning(f"Warning:  cannot do compression on layer {self.layer_name} because of inverse error")
            return

        if H.isnan().any():
            LOGGER.warning(f"Warning:  cannot do compression on layer {self.layer_name} because of inverse error")
            return

        hinv = H

        for i1 in range(0, self.columns, blocksize):
            i2 = min(i1 + blocksize, self.columns)
            count = i2 - i1

            w1 = weight[:, i1:i2].clone()
            q1 = torch.zeros_like(w1)
            total_err = torch.zeros_like(w1)
            hinv1 = hinv[i1:i2, i1:i2]

            for i in range(count):
                w = w1[:, i]
                d = hinv1[i, i]

                q = fake_quantize_weight(w.unsqueeze(1), weight_scale).flatten()
                q = torch.max(torch.min(q, q_limit), -q_limit)

                q1[:, i] = q
                err = (w - q) / d
                # Spread this column's error over the not-yet-quantized
                # columns of the current block.
                w1[:, i:] -= err.unsqueeze(1).matmul(hinv1[i, i:].unsqueeze(0))
                total_err[:, i] = err

            Q[:, i1:i2] = q1

            # Spread the block's accumulated error over all later blocks.
            weight[:, i2:] -= total_err.matmul(hinv[i1:i2, i2:])

        if torch.cuda.is_available():
            torch.cuda.synchronize()

        if is_transformer_conv1d(self.layer):
            Q = Q.t()
        shape = self.layer.weight.shape
        dtype = self.layer.weight.data.dtype
        del self.layer.weight
        setattr(self.layer, "weight", nn.Parameter(Q.reshape(shape).to(dtype), requires_grad=False))
        del self.H


class GPTQBlockWrapper:
    """Group the GPTQ layer wrappers that belong to one block of the network.

    On construction every sub-module of ``block`` whose type is listed in
    ``QUANT_LAYERS`` gets a ``GPTQLayerWrapper`` and a forward pre-hook that
    feeds the first positional input of the layer to that wrapper.
    """

    def __init__(self, block_name: str, block: nn.Module, weight_bit_width=8):
        self.layer_wrappers = {}
        self.hook_handles = []
        # position of this block in the execution order of the whole network
        self.order = 0
        self.block_name = block_name

        quantizable = tuple(QUANT_LAYERS)
        for sub_name, module in block.named_modules():
            if not isinstance(module, quantizable):
                continue
            # An empty sub_name means the block itself is the quantizable layer.
            qualified_name = f"{block_name}.{sub_name}" if sub_name else f"{block_name}"
            self.layer_wrappers[qualified_name] = GPTQLayerWrapper(qualified_name, module, weight_bit_width)
            hook = self._make_record_hook(qualified_name)
            self.hook_handles.append(module.register_forward_pre_hook(hook))

    def _make_record_hook(self, layer_name):
        """Build a forward pre-hook that records inputs for ``layer_name``."""
        def record_hook(_, inputs):
            self.layer_wrappers[layer_name].record_h(inputs[0])
        return record_hook

    def quant_block(self):
        """Quantize every wrapped layer, then remove the recording hooks."""
        for layer_wrapper in self.layer_wrappers.values():
            layer_wrapper.quant_weight()

        for handle in self.hook_handles:
            handle.remove()

    def set_order(self, idx):
        """Store the execution order index of this block."""
        self.order = idx

    def get_order(self):
        """Return the execution order index of this block."""
        return self.order

    def enable(self):
        """Turn input recording on for every wrapped layer."""
        for layer_wrapper in self.layer_wrappers.values():
            layer_wrapper.is_record = True

    def disable(self):
        """Turn input recording off for every wrapped layer."""
        for layer_wrapper in self.layer_wrappers.values():
            layer_wrapper.is_record = False


class GPTQuantizer:
    """Drive block-by-block GPTQ calibration of a model.

    Typical use: ``wrap_model`` once, run one forward pass inside
    ``record_order`` to learn the execution order of the blocks, then run the
    calibration data inside ``start_calib_iter(i)`` for every
    ``i < calibration_iters``.
    """

    def __init__(self, block_type: Optional[List[type]] = None):
        """
        Args:
            block_type: module classes that are calibrated as one block. When
                ``None``, ``wrap_model`` falls back to ``QUANT_LAYERS`` so that
                every quantizable layer is treated as its own block.
        """
        self.gptq_block_wrappers = {}
        self.block_type = block_type

    def wrap_model(self, model: nn.Module, weight_bit_width=8):
        """Attach a ``GPTQBlockWrapper`` to every block found in ``model``.

        The module tree is walked depth first; a child matching the block types
        is wrapped as a whole and not descended into. Wrappers are stored in
        ``self.gptq_block_wrappers`` under the dotted path of the block so that
        same-named blocks under different parents do not overwrite each other.
        Blocks without any quantizable layer are not registered.

        Returns:
            ``model`` itself.
        """
        if self.block_type is None:
            block_types = tuple(QUANT_LAYERS)
        else:
            block_types = tuple(self.block_type)

        def wrap_block(m, prefix=""):
            for name, child in m.named_children():
                child_prefix = f"{prefix}.{name}" if prefix else name
                if isinstance(child, block_types):
                    block_wrapper = GPTQBlockWrapper(child_prefix, child, weight_bit_width)
                    if not block_wrapper.layer_wrappers:
                        # Nothing to calibrate; registering it would only add
                        # an empty calibration iteration.
                        LOGGER.debug(f"Skip block {child_prefix}: no quantizable layer found")
                        continue
                    self.gptq_block_wrappers[child_prefix] = block_wrapper
                    LOGGER.debug(f"Calibrate block {child_prefix} as a whole block in GPTQ")
                else:
                    wrap_block(child, child_prefix)

        wrap_block(model)
        return model

    @property
    def calibration_iters(self):
        """Number of registered blocks, i.e. calibration iterations to run."""
        return len(self.gptq_block_wrappers)

    @contextlib.contextmanager
    def record_order(self):
        """Context manager that records the order in which blocks are executed.

        While the context is active, input recording is switched off and a
        forward pre-hook on the first layer of every block notes the first time
        that block runs. On exit the observed orders are stored on the block
        wrappers, the hooks are removed and recording is switched back on.
        Exceptions raised by the body of the ``with`` statement are logged and
        suppressed.
        """
        counter = 0
        record_handles = []
        orders = {}

        def get_record_order_hook(block_name):
            def record_hook(*args, **kwargs):
                nonlocal counter
                if block_name not in orders:
                    orders[block_name] = counter
                    counter += 1
            return record_hook

        try:
            for block_name, block_wrapper in self.gptq_block_wrappers.items():
                layer_wrappers = list(block_wrapper.layer_wrappers.values())
                # disable the record
                for layer_wrapper in layer_wrappers:
                    layer_wrapper.is_record = False

                if not layer_wrappers:
                    # A block without layers has nothing to hook into.
                    continue
                handle = layer_wrappers[0].layer.register_forward_pre_hook(get_record_order_hook(block_name))
                record_handles.append(handle)

            # Only failures of the caller's forward pass are swallowed; a
            # failure during the setup above propagates, because swallowing it
            # before the yield would make contextlib raise a misleading
            # "generator didn't yield" error.
            try:
                yield
            except Exception as e:
                LOGGER.warning(e)
        finally:
            for block_name, order in orders.items():
                self.gptq_block_wrappers[block_name].set_order(order)

            for h in record_handles:
                h.remove()

            for block_wrapper in self.gptq_block_wrappers.values():
                # re-enable the record
                for layer_wrapper in block_wrapper.layer_wrappers.values():
                    layer_wrapper.is_record = True

    @contextlib.contextmanager
    def start_calib_iter(self, i):
        """Context manager for calibration iteration ``i``.

        Recording is enabled for the block(s) whose order equals ``i`` and
        disabled for all others. When the context exits, whether normally or
        through an exception, the last matching block is quantized.

        Raises:
            RuntimeError: if no block has order ``i``; raised before the body
                runs and before any recording state is changed.
        """
        assert i < len(self.gptq_block_wrappers)
        block_wrappers = list(self.gptq_block_wrappers.values())
        matching = [bw for bw in block_wrappers if bw.get_order() == i]
        if not matching:
            raise RuntimeError(
                f"No block with execution order {i} was found; "
                "run a forward pass inside record_order() before calibrating"
            )
        target_block_wrapper = matching[-1]

        for block_wrapper in block_wrappers:
            if block_wrapper.get_order() == i:
                block_wrapper.enable()
            else:
                block_wrapper.disable()
        try:
            yield
        finally:
            target_block_wrapper.quant_block()

    def release_reference(self):
        """Drop the wrappers' references to the wrapped layers.

        Safe to call more than once.
        """
        # delete reference so that `torch.cuda.empty_cache()` can
        # release all the gpu memory cache used during calibration
        for block_wrapper in self.gptq_block_wrappers.values():
            for layer_wrapper in block_wrapper.layer_wrappers.values():
                if hasattr(layer_wrapper, "layer"):
                    del layer_wrapper.layer

        torch.cuda.empty_cache()


def locate_parent(root: nn.Module, full_path: str):
    """Resolve the object that owns the last component of a dotted path.

    Args:
        root: object from which attribute lookup starts.
        full_path: dot-separated attribute path, e.g. ``"a.b.c"``.

    Returns:
        A tuple ``(parent, name)`` where ``parent`` is reached by following
        every component but the last with ``getattr`` and ``name`` is the last
        component.
    """
    *ancestors, leaf = full_path.split('.')
    node = root
    for attr in ancestors:
        node = getattr(node, attr)
    return node, leaf


@torch.no_grad()
def gptq_quantize(model, tokenizer, weight_bit_width, calib_data):
    """Calibrate ``model`` with GPTQ and swap its wrapped layers for ``QuantizedLinear``.

    Args:
        model: model exposing ``chat(tokenizer, prompt, history=[])``; it is
            modified in place.
        tokenizer: tokenizer forwarded to ``model.chat``.
        weight_bit_width: bit width used for the weight quantization.
        calib_data: indexable collection of prompts; the first prompt is also
            used to discover the execution order of the blocks.

    Returns:
        None.
    """
    from .modeling_chatglm import GLMBlock
    from .quantization import QuantizedLinear

    quantizer = GPTQuantizer([GLMBlock])
    calib_model = quantizer.wrap_model(model, weight_bit_width)

    # A single forward pass is enough to learn in which order the blocks run.
    first_prompt = calib_data[0]
    with quantizer.record_order():
        calib_model.chat(tokenizer, first_prompt, history=[])

    logging.info("Start doing calibration using GPTQ ")
    for i in range(quantizer.calibration_iters):
        logging.info(f"Process: {i + 1}/{quantizer.calibration_iters}")
        # todo: should add early return to speed up the calibration
        # todo: add cpu offload to reduce the gpu memory requirements.
        with quantizer.start_calib_iter(i):
            for prompt in calib_data:
                model.chat(tokenizer, prompt, history=[])

    # replace the fp16 linear with quantized linear
    for block_wrapper in quantizer.gptq_block_wrappers.values():
        for qualified_name, wrapped in block_wrapper.layer_wrappers.items():
            fp_layer = wrapped.layer
            owner, attr_name = locate_parent(model, qualified_name)
            replacement = QuantizedLinear(
                weight_bit_width=weight_bit_width,
                weight_tensor=fp_layer.weight,
                bias_tensor=fp_layer.bias,
                weight_scale=wrapped.weight_scale,
                in_features=fp_layer.in_features,
                out_features=fp_layer.out_features,
                bias=True,
                dtype=torch.half,
                device=wrapped.device,
                empty_init=False
            )
            owner.add_module(attr_name, replacement)

    # release the memory cache used during calibration
    quantizer.release_reference()