File size: 5,507 Bytes
6a62ffb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging

from fairseq.modules.quantization import pq, quantization_options, scalar
from omegaconf import DictConfig


logger = logging.getLogger(__name__)


def quantize_model_scalar(model, model_cfg: DictConfig):
    """Apply scalar quant-noise to ``model`` when the config enables it.

    Args:
        model: the model to quantize; edited in place by
            ``scalar.quantize_model_``.
        model_cfg: config object read for ``quant_noise_scalar``. A missing or
            falsy value is treated as 0, which leaves the model untouched.

    Returns:
        The same ``model`` object that was passed in.
    """
    noise_p = getattr(model_cfg, "quant_noise_scalar", 0) or 0
    # Written as ``not (p > 0)`` rather than ``p <= 0`` so a NaN probability
    # skips quantization exactly like a zero one does.
    if not noise_p > 0:
        return model
    # quantize_model_ edits the model in place; fixed 8-bit, update every 1000 steps
    scalar.quantize_model_(model, p=noise_p, bits=8, update_step=1000)
    return model


class Quantizer(object):
    """Drive iterative product quantization (PQ) over the course of training.

    The entries of ``layers_to_quantize`` are quantized one at a time:
    ``step()`` quantizes the entry at index ``quantization_step`` via
    ``pq.quantize_model_`` and then asks the trainer to reinitialize. Steps are
    triggered either by an epoch schedule (``begin_epoch``) or an update
    schedule (``step_update``). The training budget is split evenly between
    iterations.
    """

    def __init__(self, config_path, max_epoch, max_update):
        """Parse the PQ config and derive the quantization schedule.

        Args:
            config_path: path to a YAML quantization config. When falsy, the
                result of ``quantization_options.parse_config_yaml({})`` is
                used instead.
            max_epoch: total training epochs (``--max-epoch``). ``None`` or a
                value <= 0 means "not set".
            max_update: total training updates (``--max-update``). ``None`` or
                a value <= 0 means "not set".

        Raises:
            ImportError: if PyYAML is not installed.
            AssertionError: if a set budget is not evenly divisible by
                ``len(layers_to_quantize)``, if ``layers_to_quantize`` is empty
                while a budget is set, or if not exactly one of
                ``max_epoch``/``max_update`` is set.
        """
        try:
            import yaml
        except ImportError as err:
            # The PyPI distribution that provides ``import yaml`` is "pyyaml".
            raise ImportError("Please install yaml with: pip install pyyaml") from err

        # parse config
        if config_path:
            with open(config_path) as config_file:
                config = quantization_options.parse_config_yaml(
                    yaml.safe_load(config_file)
                )
        else:
            config = quantization_options.parse_config_yaml({})

        self.n_centroids_config = config["n_centroids"]
        self.block_sizes_config = config["block_sizes"]
        self.layers_to_quantize = config["layers_to_quantize"]

        # We assume that training will run for a fixed number of epochs
        # (or updates) and that we should train for equal durations
        # between iterations of PQ.
        num_iterations = len(self.layers_to_quantize)
        self.epoch_schedule = self._iteration_schedule(
            max_epoch, num_iterations, "--max-epoch"
        )
        self.update_schedule = self._iteration_schedule(
            max_update, num_iterations, "--max-update"
        )
        # XOR rejects both "neither budget set" and "both budgets set".
        assert (self.epoch_schedule is not None) ^ (
            self.update_schedule is not None
        ), (
            "for iterative PQ, exactly one of --max-update and --max-epoch "
            "must be specified (got --max-epoch={}, --max-update={})".format(
                max_epoch, max_update
            )
        )

        # 0 is a special value for quantization step, which will force
        # the first call to begin_epoch() to call step()
        self.quantization_step = 0

    @staticmethod
    def _iteration_schedule(total, num_iterations, flag):
        """Return the number of epochs/updates allotted to each PQ iteration.

        Args:
            total: overall training budget for ``flag``. ``None`` or a value
                <= 0 means the budget is not set.
            num_iterations: number of PQ iterations (``len(layers_to_quantize)``).
            flag: command-line flag name, used only in error messages.

        Returns:
            ``total // num_iterations``, or ``None`` if ``total`` is not set.

        Raises:
            AssertionError: if ``num_iterations`` is not positive, or if
                ``total`` is not divisible by it.
        """
        if total is None or total <= 0:
            return None
        # Fail with a readable message instead of ZeroDivisionError below.
        assert num_iterations > 0, (
            "for iterative PQ, layers_to_quantize must be non-empty when "
            "{} (={}) is set".format(flag, total)
        )
        assert total % num_iterations == 0, (
            "for iterative PQ, {} (={}) must be evenly divisible by "
            "len(layers_to_quantize) (={})".format(flag, total, num_iterations)
        )
        return total // num_iterations

    def set_trainer(self, trainer):
        """Attach the trainer whose model ``step()`` will quantize.

        Also builds a ``pq.SizeTracker`` over the trainer's model. ``step()``
        passes it to ``pq.quantize_model_`` and logs it.
        """
        self.trainer = trainer
        self.size_tracker = pq.SizeTracker(self.trainer.get_model())

    def step(self):
        """Move to the next stage of quantization.

        Requires ``set_trainer()`` to have been called. Does nothing once
        every entry of ``layers_to_quantize`` has been processed.
        """
        if self.quantization_step >= len(self.layers_to_quantize):
            # Maybe we just finished the last training step or we loaded
            # a checkpoint for an iterative PQ model which previously
            # finished training. Either way, don't quantize again.
            return

        logger.info(
            "quantizing model (step={}; layers_to_quantize[step]={})".format(
                self.quantization_step, self.layers_to_quantize[self.quantization_step]
            )
        )
        quantized_layers = pq.quantize_model_(
            self.trainer.get_model(),
            self.size_tracker,
            self.layers_to_quantize,
            self.block_sizes_config,
            self.n_centroids_config,
            step=self.quantization_step,
        )
        logger.info("quantized layers: {}".format(quantized_layers))
        logger.info(self.size_tracker)

        self.quantization_step += 1

        # reinitialize the Trainer since model parameters have changed
        self.trainer.reinitialize()

    def begin_epoch(self, epoch):
        """Called at the beginning of each epoch (epochs start at 1).

        Steps when the epoch schedule says so, and always on the very first
        call (``quantization_step == 0``), even under an update-based
        schedule.
        """
        if (
            (
                self.epoch_schedule is not None
                and epoch > 0
                and (epoch - 1) % self.epoch_schedule == 0
            )
            # we always step once in the beginning, even if using
            # update-based quantization
            or self.quantization_step == 0
        ):
            self.step()

    def step_update(self, num_updates):
        """Called at the end of each step.

        Steps every ``update_schedule`` updates when an update-based
        schedule is in use.
        """
        if (
            self.update_schedule is not None
            and num_updates > 0
            and num_updates % self.update_schedule == 0
        ):
            self.step()

    def state_dict(self):
        """Return the quantizer's config and progress for checkpointing."""
        return {
            "n_centroids_config": self.n_centroids_config,
            "block_sizes_config": self.block_sizes_config,
            "layers_to_quantize": self.layers_to_quantize,
            "epoch_schedule": self.epoch_schedule,
            "update_schedule": self.update_schedule,
            "quantization_step": self.quantization_step,
        }

    def load_state_dict(self, state_dict):
        """Restore state previously produced by ``state_dict()``."""
        self.n_centroids_config = state_dict["n_centroids_config"]
        self.block_sizes_config = state_dict["block_sizes_config"]
        self.layers_to_quantize = state_dict["layers_to_quantize"]
        self.epoch_schedule = state_dict["epoch_schedule"]
        self.update_schedule = state_dict["update_schedule"]
        self.quantization_step = state_dict["quantization_step"]