File size: 17,250 Bytes
6911e5d
 
 
 
 
 
 
 
 
 
85a27e0
 
6911e5d
 
 
 
 
0a88ce4
6911e5d
 
0a88ce4
6911e5d
 
 
 
85a27e0
 
 
 
 
 
6911e5d
 
 
 
 
 
85a27e0
6911e5d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85a27e0
 
 
 
 
 
 
6911e5d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85a27e0
 
 
 
 
 
 
6911e5d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85a27e0
6911e5d
85a27e0
 
 
 
 
 
 
 
6911e5d
85a27e0
 
 
 
6911e5d
85a27e0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6911e5d
 
 
 
 
85a27e0
 
6911e5d
 
 
 
 
85a27e0
 
 
6911e5d
 
 
85a27e0
 
 
 
6911e5d
 
 
 
 
 
0a88ce4
 
 
 
 
6911e5d
0a88ce4
 
 
6911e5d
 
0a88ce4
 
6911e5d
0a88ce4
 
 
 
 
6911e5d
 
 
 
 
 
 
 
 
 
 
 
 
85a27e0
 
6911e5d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85a27e0
 
 
 
 
 
 
 
 
6911e5d
 
 
85a27e0
 
 
 
 
 
 
 
 
6911e5d
 
 
 
85a27e0
6911e5d
 
 
 
 
 
0a88ce4
85a27e0
6911e5d
0a88ce4
 
 
 
 
6911e5d
0a88ce4
85a27e0
 
 
 
 
 
 
0a88ce4
 
 
 
 
 
 
 
 
 
 
85a27e0
 
0a88ce4
 
 
 
 
 
 
 
 
 
 
 
 
85a27e0
0a88ce4
 
85a27e0
 
 
 
 
 
0a88ce4
 
 
 
 
 
 
 
6911e5d
 
0a88ce4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85a27e0
0a88ce4
 
 
 
 
 
 
 
 
 
 
 
 
 
85a27e0
 
0a88ce4
6911e5d
85a27e0
 
0a88ce4
6911e5d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85a27e0
 
 
 
 
 
 
6911e5d
 
 
 
 
 
 
 
 
 
 
 
85a27e0
6911e5d
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
import sys

# Oldest interpreter this script supports (the dataclasses module needs 3.7).
MIN_PYTHON_VERSION = (3, 7)

# Compare only (major, minor) so micro/release fields never matter.
if sys.version_info[:2] < MIN_PYTHON_VERSION:
    raise ImportError("This script requires Python 3.7 or higher!")

import argparse
import os
from dataclasses import dataclass, field
from typing import Dict, Tuple
from enum import Enum, auto

import numpy as np
import onnx
from onnx import helper

# Quantization bit width -> signed NumPy integer type used to store the weights.
BITS_TO_NUMPY_TYPE = {
    8: np.int8,
    16: np.int16,
}

# Node types whose stored (initializer / Constant) inputs get block-quantized.
SUPPORTED_OPS = {
    "Conv",
    "Gemm",
    "MatMul",
}

# Opset the model is converted to before quantizing; block-wise DequantizeLinear
# (the ``block_size`` attribute) requires opset 21.
ONNX_OPSET = 21


class WeightCategory(Enum):
    """Where the tensor data behind a node input is stored in the ONNX graph."""

    # Stored in graph.initializer.
    INITIALIZER = 1
    # Produced by a Constant node's tensor "value" attribute.
    CONSTANT = 2
    # Neither of the above (e.g. an external graph input); not quantized.
    NONE = 3


@dataclass
class BlockQuantizeConfig:
    """Settings for a single block-quantization run.

    Field order is significant: instances may be constructed positionally.
    """

    # Path of the .onnx model to read.
    input_model_path: str
    # Path where the quantized .onnx model is written.
    output_model_path: str
    # Upper bound on the block length; the block size actually used is the
    # largest divisor of the blocked dimension that does not exceed it.
    block_size: int
    # Integer width of the stored weights (must be a key of BITS_TO_NUMPY_TYPE).
    bits: int
    # Whether to print the per-weight quantization errors.
    verbose: bool


@dataclass
class BlockQuantizeResult:
    """Everything produced when block-quantizing one weight tensor.

    Every field has a default so an empty result can be created; array fields
    default to fresh zero-length float64 arrays.
    """

    # Integer weights in the layout used for blocking (1-D or flattened 2-D).
    quantized_weights: np.ndarray = field(default_factory=lambda: np.empty(0))
    # One scale per block.
    scales: np.ndarray = field(default_factory=lambda: np.empty(0))
    # One zero point per block.
    zero_point: np.ndarray = field(default_factory=lambda: np.empty(0))
    # Effective number of elements per block along ``axis``.
    block_size: int = 1
    # Axis along which the blocks are laid out.
    axis: int = 1
    # Shape of the weight before it was flattened for blocking.
    original_shape: Tuple = field(default_factory=tuple)
    # Relative-norm reconstruction error (BlockQuantizer stores a NumPy scalar).
    quantization_error: np.ndarray = field(default_factory=lambda: np.empty(0))


def closest_divisor(number: int, divisor: int) -> int:
    """Return the largest ``d`` in ``[1, divisor]`` that divides ``number`` evenly.

    Falls back to 1 when no candidate exists (e.g. ``divisor <= 0``).
    """
    candidates = range(divisor, 0, -1)
    return next((d for d in candidates if number % d == 0), 1)


def block_dequantize_tensor(
    x: np.ndarray, block_axis: int, scale: np.ndarray, zero_point: np.ndarray
) -> np.ndarray:
    """Dequantize ``x`` with one scale / zero point per block along ``block_axis``.

    Each entry of ``scale`` / ``zero_point`` covers
    ``x.shape[block_axis] // scale.shape[block_axis]`` consecutive elements of ``x``
    along ``block_axis``. The other dimensions are expected to line up with ``x``.
    Computes ``(x - zero_point) * scale`` with ``x`` and ``zero_point`` cast to float32.

    Returns:
        The dequantized tensor, same shape as ``x``.
    """
    block_len = x.shape[block_axis] // scale.shape[block_axis]

    def _per_element(per_block: np.ndarray) -> np.ndarray:
        # Stretch a per-block tensor so it aligns with x element for element.
        return np.repeat(per_block, repeats=block_len, axis=block_axis)

    centered = x.astype(np.float32) - _per_element(zero_point).astype(np.float32)
    return centered * _per_element(scale)


def block_quantize_tensor(
    x: np.ndarray,
    block_axis: int,
    scale: np.ndarray,
    zero_point: np.ndarray,
    n_bits: int,
) -> np.ndarray:
    """Quantize ``x`` with one scale / zero point per block along ``block_axis``.

    Each entry of ``scale`` / ``zero_point`` covers
    ``x.shape[block_axis] // scale.shape[block_axis]`` consecutive elements.
    Values are mapped as ``rint(x / scale + zero_point)`` (NumPy's round-half-to-even),
    saturated to the range of the ``n_bits`` storage type and cast to it.

    Returns:
        Integer tensor with the same shape as ``x`` and dtype
        ``BITS_TO_NUMPY_TYPE[n_bits]``.

    Raises:
        KeyError: if ``n_bits`` is not a key of ``BITS_TO_NUMPY_TYPE``.
    """
    block_len = x.shape[block_axis] // scale.shape[block_axis]

    scale_full = np.repeat(scale, repeats=block_len, axis=block_axis)
    zero_point_full = np.repeat(zero_point, repeats=block_len, axis=block_axis)

    storage_type = BITS_TO_NUMPY_TYPE[n_bits]
    limits = np.iinfo(storage_type)

    rounded = np.rint(x / scale_full + zero_point_full)
    return np.clip(rounded, limits.min, limits.max).astype(storage_type)


def create_dequantize_node(
    node_name,
    quantized_weights,
    scales,
    zero_point,
    dequantized_weights,
    block_size,
    axis,
) -> onnx.NodeProto:
    """Build a block-wise ``DequantizeLinear`` node.

    Args:
        node_name: Name given to the node.
        quantized_weights: Name of the integer-weight input.
        scales: Name of the per-block scale input.
        zero_point: Name of the per-block zero-point input.
        dequantized_weights: Name of the node's single output.
        block_size: Value of the ``block_size`` attribute.
        axis: Value of the ``axis`` attribute.

    Returns:
        The constructed NodeProto (not yet added to any graph).
    """
    dequantize = helper.make_node(
        "DequantizeLinear",
        [quantized_weights, scales, zero_point],
        [dequantized_weights],
        name=node_name,
    )
    # Append the attributes explicitly, block_size first, then axis.
    dequantize.attribute.extend(
        [
            helper.make_attribute(attr_name, attr_value)
            for attr_name, attr_value in (("block_size", block_size), ("axis", axis))
        ]
    )
    return dequantize


def create_reshape_node(
    node_name, dequantized_weights, shape_tensor, reshaped_weights_name
) -> onnx.NodeProto:
    """Build a ``Reshape`` node that reshapes ``dequantized_weights`` to ``shape_tensor``.

    Args:
        node_name: Name given to the node.
        dequantized_weights: Name of the data input.
        shape_tensor: Name of the target-shape input.
        reshaped_weights_name: Name of the node's single output.
    """
    reshape_inputs = [dequantized_weights, shape_tensor]
    reshape_outputs = [reshaped_weights_name]
    node = helper.make_node("Reshape", reshape_inputs, reshape_outputs, name=node_name)
    return node


class BlockQuantizer:
    """Block-quantize the weights consumed by Conv/Gemm/MatMul nodes of an ONNX model.

    Every weight that feeds an op in ``SUPPORTED_OPS`` and is stored in the model
    (as an initializer or a ``Constant`` node) is replaced by integer weights plus
    per-block scales and zero points. A ``DequantizeLinear`` node decodes them,
    followed by a ``Reshape`` node for weights with more than two dimensions.
    """

    def __init__(self, conf: BlockQuantizeConfig) -> None:
        """Validate ``conf``, load the input model and index its stored weights.

        Raises:
            ValueError: if ``conf`` fails ``validate_conf``.
        """
        self.conf = conf
        self.validate_conf()

        self.model = onnx.load(conf.input_model_path)

        # Look up the default-domain opset explicitly: opset_import[0] may belong
        # to a custom domain (e.g. "com.microsoft").
        if self._default_opset_version() != ONNX_OPSET:
            self.model = onnx.version_converter.convert_version(self.model, ONNX_OPSET)

        self.graph = self.model.graph
        self.initializers_map = {
            init.name: init for init in self.model.graph.initializer
        }
        # Output name -> TensorProto of each Constant node. The attribute name
        # (typo included) is kept for compatibility. Constants without a tensor
        # "value" attribute (value_float, value_ints, ...) are not indexed. Their
        # outputs are treated as non-weights instead of raising StopIteration here.
        self.costants_map = {}
        for node in self.model.graph.node:
            if node.op_type != "Constant":
                continue
            value_attr = next(
                (attr for attr in node.attribute if attr.name == "value"), None
            )
            if value_attr is not None:
                self.costants_map[node.output[0]] = value_attr.t

    def _default_opset_version(self) -> int:
        """Return the opset version declared for the default ONNX domain, or -1."""
        for opset in self.model.opset_import:
            # The default domain is written as either "" or "ai.onnx".
            if opset.domain in ("", "ai.onnx"):
                return opset.version
        return -1

    def validate_conf(self):
        """Check the configuration before any file is loaded.

        Raises:
            ValueError: if the input file does not exist, either path lacks a
                ``.onnx`` extension, ``block_size`` is not positive, or ``bits`` is
                not a key of ``BITS_TO_NUMPY_TYPE``.
        """
        if not os.path.isfile(self.conf.input_model_path):
            raise ValueError(
                f"Input model path '{self.conf.input_model_path}' does not exist or is not a file."
            )

        if not self.conf.input_model_path.lower().endswith(".onnx"):
            raise ValueError(
                f"Input model path '{self.conf.input_model_path}' must have a .onnx extension."
            )

        if not self.conf.output_model_path.lower().endswith(".onnx"):
            raise ValueError(
                f"Output model path '{self.conf.output_model_path}' must have a .onnx extension."
            )

        if self.conf.block_size <= 0:
            raise ValueError("Block size must be a positive integer.")

        if self.conf.bits not in BITS_TO_NUMPY_TYPE:
            allowed_values = ", ".join([str(k) for k in BITS_TO_NUMPY_TYPE.keys()])
            raise ValueError(
                f"Bits must be one of the following values: [{allowed_values}]."
            )

    def get_weight_category(self, name: str) -> WeightCategory:
        """Classify where the tensor called ``name`` is stored in the loaded model."""
        if name in self.initializers_map:
            return WeightCategory.INITIALIZER
        if name in self.costants_map:
            return WeightCategory.CONSTANT
        return WeightCategory.NONE

    def get_weight_tensor(self, name: str, category: WeightCategory) -> np.ndarray:
        """Return the stored tensor ``name`` as a NumPy array.

        Raises:
            AssertionError: if ``category`` is ``WeightCategory.NONE``.
        """
        if category == WeightCategory.INITIALIZER:
            return onnx.numpy_helper.to_array(self.initializers_map[name])
        elif category == WeightCategory.CONSTANT:
            return onnx.numpy_helper.to_array(self.costants_map[name])
        else:
            raise AssertionError("Invalid weight category")

    def remove_fp32_weights(self, name: str, category: WeightCategory):
        """Delete the original (unquantized) weight ``name`` from the graph.

        Returns:
            For ``CONSTANT`` weights, the index the removed ``Constant`` node had in
            ``graph.node``. Callers walking the node list need it to compensate for
            the shift. Returns ``None`` for ``INITIALIZER`` weights.

        Raises:
            AssertionError: if ``category`` is ``WeightCategory.NONE``.
        """
        if category == WeightCategory.INITIALIZER:
            self.graph.initializer.remove(
                next(init for init in self.graph.initializer if init.name == name)
            )
            return None
        if category == WeightCategory.CONSTANT:
            const_idx, const_node = next(
                (idx, node)
                for idx, node in enumerate(self.graph.node)
                if node.op_type == "Constant" and node.output[0] == name
            )
            self.graph.node.remove(const_node)
            return const_idx
        raise AssertionError("Invalid weight category")

    def compute_scale_zeropoint(
        self, b_min: np.ndarray, b_max: np.ndarray
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Compute per-block asymmetric (affine) quantization parameters.

        The range ``[min(b_min, 0), max(b_max, 0)]`` is mapped onto the full range
        of the ``conf.bits`` integer type.

        Returns:
            ``(scales, zeropoints)`` where ``scale = range / (qmax - qmin)`` and
            ``zero_point = rint(qmin - min / scale)`` cast to the storage type.
            Blocks whose extended range is empty get scale 1.0 and zero point 0.
        """
        assert (
            b_min <= b_max
        ).all(), "minimum must not be greater than maximum when computing scale and zero point"

        # zero must be present in the range, this enforces qmin <= zero_point <= qmax
        b_min = np.minimum(b_min, np.zeros_like(b_min, dtype=b_min.dtype))
        b_max = np.maximum(b_max, np.zeros_like(b_max, dtype=b_max.dtype))

        type_info = np.iinfo(BITS_TO_NUMPY_TYPE[self.conf.bits])
        qmin = type_info.min
        qmax = type_info.max

        dq = qmax - qmin

        # The 1.0 fallback keeps the division in the zero-point formula finite for
        # degenerate (all-zero) blocks.
        scales = np.where(b_max != b_min, (b_max - b_min) / dq, 1.0)

        zeropoints = np.where(b_max != b_min, np.rint(qmin - b_min / scales), 0.0)
        zeropoints = zeropoints.astype(BITS_TO_NUMPY_TYPE[self.conf.bits])

        return (scales, zeropoints)

    def block_quantize(self, weight: np.ndarray) -> BlockQuantizeResult:
        """Block-quantize one weight tensor and measure the reconstruction error.

        Tensors with more than one dimension are flattened to
        ``(shape[0], -1)`` and blocked along axis 1. 1-D tensors are blocked
        along axis 0. The block size is the largest divisor of that axis length
        not exceeding ``conf.block_size``.

        Returns:
            A ``BlockQuantizeResult`` holding the quantized weights (in the
            flattened layout), per-block scales and zero points, the effective
            block size and axis, the original shape, and the relative-norm error
            ``||W - dequant(quant(W))|| / (||W|| + 1e-10)``.
        """
        original_shape = weight.shape

        if weight.ndim > 1:
            weight = weight.reshape((weight.shape[0], -1))
            quantization_axis = 1
        else:
            quantization_axis = 0

        block_size = closest_divisor(
            weight.shape[quantization_axis], self.conf.block_size
        )

        assert (
            weight.shape[quantization_axis] % block_size == 0
        ), f"weight shape ({weight.shape[quantization_axis]}) must be divisible by block size ({block_size})"

        # Split the quantization axis into (num_blocks, block_size) so per-block
        # statistics can be reduced over the last dimension.
        new_shape = list(weight.shape[: quantization_axis + 1]) + [-1]
        new_shape[quantization_axis] = new_shape[quantization_axis] // block_size

        blocked_weight = weight.reshape(new_shape)

        blocked_max = np.max(blocked_weight, -1)
        blocked_min = np.min(blocked_weight, -1)

        scales, zeropoints = self.compute_scale_zeropoint(blocked_min, blocked_max)

        quantized_weight = block_quantize_tensor(
            weight, quantization_axis, scales, zeropoints, self.conf.bits
        )
        reconstructed_mat = block_dequantize_tensor(
            quantized_weight, quantization_axis, scales, zeropoints
        )

        # Relative Norm
        qerror = np.linalg.norm(reconstructed_mat - weight) / (np.linalg.norm(weight) + 1e-10)

        return BlockQuantizeResult(
            quantized_weight,
            scales,
            zeropoints,
            block_size,
            quantization_axis,
            original_shape,
            qerror,
        )

    def get_model_size(self, model_path: str) -> float:
        """Return the size of the file at ``model_path`` in KB (1 KB = 1024 bytes)."""
        size_kb = os.path.getsize(model_path) / 1024
        return size_kb

    def display_summary(self, sqe: Dict[str, float]):
        """Print per-weight errors (verbose mode) and a size/error summary table.

        Args:
            sqe: Relative-norm quantization error keyed by original weight name.
        """
        errors = list(sqe.values())
        if errors:
            mean_error = sum(errors) / len(errors)
        else:
            mean_error = 0
            print(
                "Warning: No weights have been quantized, likely due to unsupported layers."
            )
        original_model_size = self.get_model_size(self.conf.input_model_path)
        quantized_model_size = self.get_model_size(self.conf.output_model_path)

        # Skip the listing when nothing was quantized: max() over no keys raises.
        if self.conf.verbose and sqe:
            sorted_sqe = sorted(sqe.items(), key=lambda item: item[1], reverse=True)
            longest_key_len = max(len(key) for key in sqe)

            print("Quantization error (Relative Norm) sorted in descending order:")

            for key, value in sorted_sqe:
                print(f"{key:<{longest_key_len}} : {value}")

        # Wide enough for the longest label so all values start in one column.
        label_width = 31
        print("Done! Results saved in", self.conf.output_model_path)
        print("\nSummary of Results:\n")
        print(f"{'Metric':<{label_width}} {'Value':<10}")
        print(f"{'-'*40}")
        print(f"{'Relative Norm Error':<{label_width}} {mean_error:.6f}")
        print(f"{'Original Model Size (KB)':<{label_width}} {original_model_size:,.2f}")
        print(
            f"{'Block-Quantized Model Size (KB)':<{label_width}} {quantized_model_size:,.2f}"
        )

    def _insert_dequantize_chain(
        self, input_name: str, res: BlockQuantizeResult
    ) -> Tuple[str, int]:
        """Add the quantized tensors and decoding nodes that replace ``input_name``.

        Adds the scale, zero-point, quantized-weight (and, when needed, target
        shape) initializers. Inserts ``DequantizeLinear`` (and ``Reshape``) at the
        front of ``graph.node`` so they precede every consumer, and records their
        output types in ``graph.value_info``.

        Returns:
            ``(replacement_name, n_inserted)``: the tensor consumers should read
            instead of ``input_name``, and the number of nodes inserted at the front.
        """
        quantized_weights_name = f"{input_name}_quantized"
        quantized_node_name = f"{input_name}_quantized_node"
        dequantized_weights_name = f"{input_name}_dequantized"
        scales_name = f"{input_name}_scales"
        zero_point_name = f"{input_name}_zero_point"
        shape_node_name = f"{input_name}_shape_node"
        shape_name = f"{input_name}_shape"
        reshaped_weights_name = f"{input_name}_reshaped"

        # block_quantize flattens weights with more than two dims to 2-D, so they
        # must be reshaped back after dequantization.
        reshape_needed = len(res.original_shape) > 2

        # DequantizeLinear's output element type follows the scale's type, so
        # derive it instead of hard-coding FLOAT (wrong for e.g. float16 weights).
        output_elem_type = helper.np_dtype_to_tensor_dtype(res.scales.dtype)

        initializers = [
            onnx.numpy_helper.from_array(res.scales, name=scales_name),
            onnx.numpy_helper.from_array(res.zero_point, name=zero_point_name),
        ]
        if reshape_needed:
            # Reshape requires an int64 shape, but np.array's default integer type
            # is platform dependent (int32 on Windows with NumPy < 2).
            initializers.append(
                onnx.numpy_helper.from_array(
                    np.array(res.original_shape, dtype=np.int64), name=shape_name
                )
            )
        initializers.append(
            onnx.numpy_helper.from_array(
                res.quantized_weights, name=quantized_weights_name
            )
        )
        self.graph.initializer.extend(initializers)

        dequantize_node = create_dequantize_node(
            quantized_node_name,
            quantized_weights_name,
            scales_name,
            zero_point_name,
            dequantized_weights_name,
            res.block_size,
            res.axis,
        )
        dequantized_weights_info = helper.make_tensor_value_info(
            dequantized_weights_name,
            output_elem_type,
            res.quantized_weights.shape,
        )

        if not reshape_needed:
            self.graph.node.insert(0, dequantize_node)
            self.graph.value_info.insert(0, dequantized_weights_info)
            return dequantized_weights_name, 1

        reshape_node = create_reshape_node(
            shape_node_name,
            dequantized_weights_name,
            shape_name,
            reshaped_weights_name,
        )
        reshaped_weights_info = helper.make_tensor_value_info(
            reshaped_weights_name,
            output_elem_type,
            res.original_shape,
        )
        # Insert in reverse so the final order is DequantizeLinear, Reshape, ...
        self.graph.node.insert(0, reshape_node)
        self.graph.node.insert(0, dequantize_node)
        self.graph.value_info.insert(0, reshaped_weights_info)
        self.graph.value_info.insert(0, dequantized_weights_info)
        return reshaped_weights_name, 2

    def run(self):
        """Quantize all eligible weights, then check, save and summarize the model.

        Walks ``graph.node`` in order. For each op in ``SUPPORTED_OPS``, every input
        backed by an initializer or ``Constant`` tensor is block-quantized and
        rewired to the output of a new dequantization chain. Inputs are skipped
        when they are not stored in the model, have fewer elements than
        ``conf.block_size``, or only admit a block size of 1. A weight shared by
        several nodes is quantized once and every consumer is rewired to it.
        """
        print("Quantizing the model...")

        # Original weight name -> tensor name that now replaces it for consumers.
        replacements: Dict[str, str] = {}
        sqe: Dict[str, float] = {}

        node_idx = 0

        while node_idx < len(self.model.graph.node):
            node = self.model.graph.node[node_idx]

            if node.op_type in SUPPORTED_OPS:
                for input_idx, input_name in enumerate(node.input):
                    # Parameter sharing: already quantized for an earlier consumer,
                    # so rewire without reloading or re-quantizing the tensor.
                    if input_name in replacements:
                        node.input[input_idx] = replacements[input_name]
                        continue

                    weight_category = self.get_weight_category(input_name)

                    # Skip quantization if weights are taken as external input
                    if weight_category == WeightCategory.NONE:
                        continue

                    weight = self.get_weight_tensor(input_name, weight_category)

                    # Skip quantization if weights don't contain enough elements
                    # to create at least 1 block
                    if weight.size < self.conf.block_size:
                        continue

                    block_quantize_res = self.block_quantize(weight)

                    # Skip quantization if it wouldn't reduce the model size
                    if block_quantize_res.block_size == 1:
                        continue

                    removed_idx = self.remove_fp32_weights(input_name, weight_category)
                    if removed_idx is not None and removed_idx < node_idx:
                        # The removed Constant node preceded ``node``, shifting it
                        # left by one; without this the following node is skipped.
                        node_idx -= 1

                    replacement_name, n_inserted = self._insert_dequantize_chain(
                        input_name, block_quantize_res
                    )
                    # Keep node_idx pointing at ``node`` after the front insertions
                    # (preserves topological order without revisiting new nodes).
                    node_idx += n_inserted

                    node.input[input_idx] = replacement_name
                    replacements[input_name] = replacement_name
                    sqe[input_name] = block_quantize_res.quantization_error

            node_idx += 1

        onnx.checker.check_model(self.model, full_check=True)
        onnx.save(self.model, self.conf.output_model_path)

        self.display_summary(sqe)


def setup_args(argv=None) -> argparse.Namespace:
    """Parse the command-line options of the block-quantization tool.

    Args:
        argv: Argument strings to parse (without the program name). ``None``
            (the default) makes argparse read ``sys.argv[1:]``, as before.

    Returns:
        Namespace with ``input_model``, ``block_size``, ``bits``,
        ``output_model`` and ``verbose``.
    """
    parser = argparse.ArgumentParser(description="Blockwise quantization tool")

    parser.add_argument(
        "-i",
        "--input_model",
        type=str,
        help="The path of onnx model to quantize",
        required=True,
    )
    parser.add_argument(
        "-bs",
        "--block_size",
        type=int,
        help="The maximum size of quantization block",
        required=True,
    )
    parser.add_argument(
        "-b",
        "--bits",
        type=int,
        help="Quantization bits",
        # Derived from the supported storage types so the CLI cannot drift from
        # the values BlockQuantizer.validate_conf accepts.
        choices=sorted(BITS_TO_NUMPY_TYPE),
        default=8,
        required=False,
    )
    parser.add_argument(
        "-o",
        "--output_model",
        type=str,
        help="The output model path",
        default="block_quantized_model.onnx",
        required=False,
    )
    parser.add_argument(
        "-v",
        "--verbose",
        action="store_true",
        help="Enable verbose output",
        required=False,
    )

    return parser.parse_args(argv)


if __name__ == "__main__":
    # Command-line entry point: parse flags, build the config, run the quantizer.
    cli_args = setup_args()

    quantization_config = BlockQuantizeConfig(
        input_model_path=cli_args.input_model,
        output_model_path=cli_args.output_model,
        block_size=cli_args.block_size,
        bits=cli_args.bits,
        verbose=cli_args.verbose,
    )

    BlockQuantizer(quantization_config).run()