JustinLin610
update
8437114
raw history blame
No virus
4.01 kB
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn as nn
def quant_noise(module, p, block_size):
"""
Wraps modules and applies quantization noise to the weights for
subsequent quantization with Iterative Product Quantization as
described in "Training with Quantization Noise for Extreme Model Compression"
Args:
- module: nn.Module
- p: amount of Quantization Noise
- block_size: size of the blocks for subsequent quantization with iPQ
Remarks:
- Module weights must have the right sizes wrt the block size
- Only Linear, Embedding and Conv2d modules are supported for the moment
- For more detail on how to quantize by blocks with convolutional weights,
see "And the Bit Goes Down: Revisiting the Quantization of Neural Networks"
- We implement the simplest form of noise here as stated in the paper
which consists in randomly dropping blocks
"""
# if no quantization noise, don't register hook
if p <= 0:
return module
# supported modules
assert isinstance(module, (nn.Linear, nn.Embedding, nn.Conv2d))
# test whether module.weight has the right sizes wrt block_size
is_conv = module.weight.ndim == 4
# 2D matrix
if not is_conv:
assert (
module.weight.size(1) % block_size == 0
), "Input features must be a multiple of block sizes"
# 4D matrix
else:
# 1x1 convolutions
if module.kernel_size == (1, 1):
assert (
module.in_channels % block_size == 0
), "Input channels must be a multiple of block sizes"
# regular convolutions
else:
k = module.kernel_size[0] * module.kernel_size[1]
assert k % block_size == 0, "Kernel size must be a multiple of block size"
def _forward_pre_hook(mod, input):
# no noise for evaluation
if mod.training:
if not is_conv:
# gather weight and sizes
weight = mod.weight
in_features = weight.size(1)
out_features = weight.size(0)
# split weight matrix into blocks and randomly drop selected blocks
mask = torch.zeros(
in_features // block_size * out_features, device=weight.device
)
mask.bernoulli_(p)
mask = mask.repeat_interleave(block_size, -1).view(-1, in_features)
else:
# gather weight and sizes
weight = mod.weight
in_channels = mod.in_channels
out_channels = mod.out_channels
# split weight matrix into blocks and randomly drop selected blocks
if mod.kernel_size == (1, 1):
mask = torch.zeros(
int(in_channels // block_size * out_channels),
device=weight.device,
)
mask.bernoulli_(p)
mask = mask.repeat_interleave(block_size, -1).view(-1, in_channels)
else:
mask = torch.zeros(
weight.size(0), weight.size(1), device=weight.device
)
mask.bernoulli_(p)
mask = (
mask.unsqueeze(2)
.unsqueeze(3)
.repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1])
)
# scale weights and apply mask
mask = mask.to(
torch.bool
) # x.bool() is not currently supported in TorchScript
s = 1 / (1 - p)
mod.weight.data = s * weight.masked_fill(mask, 0)
module.register_forward_pre_hook(_forward_pre_hook)
return module