import torch
import torch.ao.nn.intrinsic
import torch.ao.nn.intrinsic.qat
import torch.nn.functional as F
import torch.ao.nn.quantized as nnq

from torch.nn.utils import fuse_conv_bn_weights

_reverse_repeat_padding = nnq.modules.conv._reverse_repeat_padding
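# Note: the helper above converts a conv-style padding tuple into the reversed,
# twice-repeated layout that ``F.pad`` expects. A quick illustration (made-up
# values, for exposition only):
#
#     _reverse_repeat_padding((1, 2))  # Conv2d padding (pH, pW) -> (2, 2, 1, 1)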
class ConvReLU1d(nnq.Conv1d):
    r"""
    A ConvReLU1d module is a fused module of Conv1d and ReLU

    We adopt the same interface as :class:`torch.ao.nn.quantized.Conv1d`.

    Attributes:
        Same as torch.ao.nn.quantized.Conv1d

    """
    _FLOAT_MODULE = torch.ao.nn.intrinsic.ConvReLU1d

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True,
                 padding_mode='zeros', device=None, dtype=None):
        super().__init__(
            in_channels, out_channels, kernel_size, stride=stride,
            padding=padding, dilation=dilation, groups=groups, bias=bias,
            padding_mode=padding_mode, device=device, dtype=dtype)

    def forward(self, input):
        # Temporarily using len(shape) instead of ndim due to JIT issue
        # https://github.com/pytorch/pytorch/issues/23890
        if len(input.shape) != 3:
            raise ValueError("Input shape must be `(N, C, L)`!")
        if self.padding_mode != 'zeros':
            # Padding in Conv1d is stored as (p, p); F.pad only needs (p,)
            _reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding[:1])
            input = F.pad(input, _reversed_padding_repeated_twice,
                          mode=self.padding_mode)
        return torch.ops.quantized.conv1d_relu(
            input, self._packed_params, self.scale, self.zero_point)

    def _get_name(self):
        return 'QuantizedConvReLU1d'

    @classmethod
    def from_float(cls, mod):
        if type(mod) == torch.ao.nn.intrinsic.qat.ConvBnReLU1d:
            # Fold the BatchNorm statistics into the conv weights before quantizing
            mod.weight, mod.bias = fuse_conv_bn_weights(
                mod.weight, mod.bias, mod.bn.running_mean, mod.bn.running_var,
                mod.bn.eps, mod.bn.weight, mod.bn.bias)
        return super().from_float(mod)

    @classmethod
    def from_reference(cls, ref_qconv, output_scale, output_zero_point):
        assert type(ref_qconv) != torch.ao.nn.intrinsic.ConvBnReLU1d, \
            "BatchNorm1d should be fused into Conv1d before converting to reference module"
        return super().from_reference(ref_qconv[0], output_scale, output_zero_point)
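# A minimal usage sketch for the ConvReLU1d module above (not part of the
# original file; shapes and qparams are arbitrary). The module consumes
# quantized tensors, so the float input is quantized first:
#
#     m = ConvReLU1d(16, 33, kernel_size=3, stride=2)
#     x = torch.randn(20, 16, 100)
#     qx = torch.quantize_per_tensor(x, scale=1.0, zero_point=0, dtype=torch.quint8)
#     out = m(qx)  # quint8 output with ReLU fused into the conv kernel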
class ConvReLU2d(nnq.Conv2d):
    r"""
    A ConvReLU2d module is a fused module of Conv2d and ReLU

    We adopt the same interface as :class:`torch.ao.nn.quantized.Conv2d`.

    Attributes:
        Same as torch.ao.nn.quantized.Conv2d

    """
    _FLOAT_MODULE = torch.ao.nn.intrinsic.ConvReLU2d

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True,
                 padding_mode='zeros', device=None, dtype=None):
        super().__init__(
            in_channels, out_channels, kernel_size, stride=stride,
            padding=padding, dilation=dilation, groups=groups, bias=bias,
            padding_mode=padding_mode, device=device, dtype=dtype)

    def forward(self, input):
        # Temporarily using len(shape) instead of ndim due to JIT issue
        # https://github.com/pytorch/pytorch/issues/23890
        if len(input.shape) != 4:
            raise ValueError("Input shape must be `(N, C, H, W)`!")
        if self.padding_mode != 'zeros':
            _reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding)
            input = F.pad(input, _reversed_padding_repeated_twice,
                          mode=self.padding_mode)
        return torch.ops.quantized.conv2d_relu(
            input, self._packed_params, self.scale, self.zero_point)

    def _get_name(self):
        return 'QuantizedConvReLU2d'

    @classmethod
    def from_float(cls, mod):
        if type(mod) == torch.ao.nn.intrinsic.qat.ConvBnReLU2d:
            # Fold the BatchNorm statistics into the conv weights before quantizing
            mod.weight, mod.bias = fuse_conv_bn_weights(
                mod.weight, mod.bias, mod.bn.running_mean, mod.bn.running_var,
                mod.bn.eps, mod.bn.weight, mod.bn.bias)
        return super().from_float(mod)

    @classmethod
    def from_reference(cls, ref_qconv, output_scale, output_zero_point):
        assert type(ref_qconv) != torch.ao.nn.intrinsic.ConvBnReLU2d, \
            "BatchNorm2d should be fused into Conv2d before converting to reference module"
        return super().from_reference(ref_qconv[0], output_scale, output_zero_point)
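# A sketch of how the QuantizedConvReLU2d above typically comes to exist via
# eager-mode post-training quantization (the model and shapes below are made up):
#
#     import torch.ao.quantization as tq
#
#     class M(torch.nn.Module):
#         def __init__(self):
#             super().__init__()
#             self.quant = tq.QuantStub()
#             self.conv = torch.nn.Conv2d(3, 16, 3)
#             self.relu = torch.nn.ReLU()
#             self.dequant = tq.DeQuantStub()
#         def forward(self, x):
#             return self.dequant(self.relu(self.conv(self.quant(x))))
#
#     m = M().eval()
#     m = tq.fuse_modules(m, [['conv', 'relu']])  # conv becomes torch.ao.nn.intrinsic.ConvReLU2d
#     m.qconfig = tq.get_default_qconfig('fbgemm')
#     tq.prepare(m, inplace=True)
#     m(torch.randn(1, 3, 32, 32))                # calibrate observers
#     tq.convert(m, inplace=True)                 # conv becomes QuantizedConvReLU2d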
class ConvReLU3d(nnq.Conv3d):
    r"""
    A ConvReLU3d module is a fused module of Conv3d and ReLU

    We adopt the same interface as :class:`torch.ao.nn.quantized.Conv3d`.

    Attributes:
        Same as torch.ao.nn.quantized.Conv3d

    """
    _FLOAT_MODULE = torch.ao.nn.intrinsic.ConvReLU3d

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True,
                 padding_mode='zeros', device=None, dtype=None):
        assert padding_mode != 'reflect', "Conv3d does not support reflection padding"
        super().__init__(
            in_channels, out_channels, kernel_size, stride=stride,
            padding=padding, dilation=dilation, groups=groups, bias=bias,
            padding_mode=padding_mode, device=device, dtype=dtype)

    def forward(self, input):
        # Temporarily using len(shape) instead of ndim due to JIT issue
        # https://github.com/pytorch/pytorch/issues/23890
        if len(input.shape) != 5:
            raise ValueError("Input shape must be `(N, C, D, H, W)`!")
        if self.padding_mode != 'zeros':
            _reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding)
            input = F.pad(input, _reversed_padding_repeated_twice,
                          mode=self.padding_mode)
        return torch.ops.quantized.conv3d_relu(
            input, self._packed_params, self.scale, self.zero_point)

    def _get_name(self):
        return 'QuantizedConvReLU3d'

    @classmethod
    def from_float(cls, mod):
        if type(mod) == torch.ao.nn.intrinsic.qat.ConvBnReLU3d:
            # Fold the BatchNorm statistics into the conv weights before quantizing
            mod.weight, mod.bias = fuse_conv_bn_weights(
                mod.weight, mod.bias, mod.bn.running_mean, mod.bn.running_var,
                mod.bn.eps, mod.bn.weight, mod.bn.bias)
        return super().from_float(mod)

    @classmethod
    def from_reference(cls, ref_qconv, output_scale, output_zero_point):
        assert type(ref_qconv) != torch.ao.nn.intrinsic.ConvBnReLU3d, \
            "BatchNorm3d should be fused into Conv3d before converting to reference module"
        return super().from_reference(ref_qconv[0], output_scale, output_zero_point)
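# For reference, the BatchNorm folding performed by ``fuse_conv_bn_weights`` in
# the ``from_float`` methods above follows the standard identity (a sketch, not
# the exact implementation):
#
#     s = bn.weight / torch.sqrt(bn.running_var + bn.eps)
#     w_fused = conv.weight * s.reshape([-1] + [1] * (conv.weight.dim() - 1))
#     b_fused = (conv.bias - bn.running_mean) * s + bn.bias
#
# so at inference the quantized op only ever runs a single fused conv + ReLU.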