|
"""PyTorch ResNet |
|
|
|
This started as a copy of https://github.com/pytorch/vision 'resnet.py' (BSD-3-Clause) with |
|
additional dropout and dynamic global avg/max pool. |
|
|
|
ResNeXt, SE-ResNeXt, SENet, and MXNet Gluon stem/downsample variants, tiered stems added by Ross Wightman |
|
|
|
Copyright 2019, Ross Wightman |
|
""" |
|
import math |
|
from functools import partial |
|
|
|
import numpy as np |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F

# Helper utilities from timm; exact import locations vary a bit across timm versions
# (older releases expose them under timm.models.layers / timm.models.helpers).
from timm.layers import create_classifier
from timm.models._manipulate import checkpoint_seq
from timm.models._registry import model_entrypoint

|
def get_padding(kernel_size, stride, dilation=1): |
|
padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2 |
|
return padding |
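
# For example: get_padding(3, 1) == 1 and get_padding(7, 2) == 3.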
|
|
|
|
|
class softball(nn.Module):
    """Smooth rescaling that keeps activations inside the L2 ball of squared radius `radius2`."""

    def __init__(self, radius2=None, inplace=True):
        super(softball, self).__init__()
        # `inplace` is accepted for API compatibility with other activations but is unused.
        # If no squared radius is given, it is set lazily from the channel count in forward().
        self.radius2 = radius2
|
|
|
def forward(self, x): |
|
if self.radius2 is None: |
|
self.radius2 = x.size()[1] |
|
norm = torch.sqrt(1 + (x*x).sum(1, keepdim=True) / self.radius2) |
|
return x / norm |
|
|
|
class hardball(nn.Module): |
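    """Hard projection of activations onto the channel-wise L2 ball of radius
    sqrt(radius2); the radius defaults to sqrt(num_channels)."""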
|
def __init__(self, radius2=None): |
|
super(hardball, self).__init__() |
|
self.radius = np.sqrt(radius2) if radius2 is not None else None |
|
|
|
def forward(self, x): |
|
norm = torch.sqrt((x*x).sum(1, keepdim=True)) |
|
if self.radius is None: |
|
self.radius = np.sqrt(x.size()[1]) |
|
return torch.where(norm > self.radius, self.radius * x / norm, x) |
|
|
|
|
|
class ConvBN(nn.Module): |
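    """A Conv2d/BatchNorm2d pair; fuse_bn() folds the BN statistics into the conv
    weights so eval-mode inference runs a single fused convolution."""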
|
def __init__(self, conv, bn): |
|
super(ConvBN, self).__init__() |
|
self.conv = conv |
|
self.bn = bn |
|
self.fused_weight = None |
|
self.fused_bias = None |
|
|
|
def forward(self, x): |
|
if self.training: |
|
x = self.conv(x) |
|
x = self.bn(x) |
|
else: |
|
if self.fused_weight is not None and self.fused_bias is not None: |
|
x = F.conv2d(x, self.fused_weight, self.fused_bias, |
|
self.conv.stride, self.conv.padding, |
|
self.conv.dilation, self.conv.groups) |
|
else: |
|
x = self.conv(x) |
|
x = self.bn(x) |
|
return x |
|
|
|
def fuse_bn(self): |
|
if self.training: |
|
raise RuntimeError("Call fuse_bn only in eval mode") |
|
|
|
|
|
        # Fold the BN statistics into the conv weights (assumes the conv has no bias,
        # which holds for every ConvBN instance in this file):
        #   w_fused = w * gamma / sqrt(running_var + eps)
        #   b_fused = beta - gamma * running_mean / sqrt(running_var + eps)
        w = self.conv.weight
        mean = self.bn.running_mean
        std = torch.sqrt(self.bn.running_var + self.bn.eps)
        gamma = self.bn.weight
        beta = self.bn.bias

        self.fused_weight = w * (gamma / std).reshape(-1, 1, 1, 1)
        self.fused_bias = beta - (gamma * mean / std)
|
|
|
|
|
class QLBlock(nn.Module): |
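    """Residual block used by QLNet. A 1x1 ConvBN expands the input to 2*width
    channels, even/odd channel pairs are multiplied elementwise (gating back down
    to width), a depthwise 3x3 conv and a 1x1 ConvBN project to the output width,
    and the skip-added result is bounded by a hardball projection. The skip path
    is a 1x1 ConvBN when the block changes resolution or width, identity otherwise.
    For example, with inplanes=64 and a downsample skip the channel flow is
    64 -> 1024 -> 512 (gating) -> 512 (depthwise 3x3) -> 128, with the skip
    mapping 64 -> 128.
    """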
|
expansion = 1 |
|
|
|
def __init__( |
|
self, |
|
inplanes, |
|
planes, |
|
stride=1, |
|
downsample=None, |
|
cardinality=1, |
|
base_width=64, |
|
reduce_first=1, |
|
dilation=1, |
|
first_dilation=None, |
|
act_layer=nn.ReLU, |
|
norm_layer=nn.BatchNorm2d, |
|
): |
|
super(QLBlock, self).__init__() |
|
|
|
        # Expansion factor for the hidden width; smaller input widths get a larger expansion.
        self.k = 8 if inplanes <= 128 else 4 if inplanes <= 256 else 2
        width = inplanes * self.k
        # The block doubles the channel count only when a downsampling/projection skip is requested.
        outplanes = inplanes if downsample is None else inplanes * 2
        first_dilation = first_dilation or dilation
|
|
|
self.conv1 = ConvBN( |
|
nn.Conv2d(inplanes, width*2, kernel_size=1, stride=1, |
|
dilation=first_dilation, groups=1, bias=False), |
|
norm_layer(width*2)) |
|
|
|
|
|
|
|
self.conv2 = nn.Conv2d(width, width, kernel_size=3, stride=stride, |
|
padding=1, dilation=first_dilation, groups=width, bias=False) |
|
self.bn2 = norm_layer(width) |
|
|
|
self.conv3 = ConvBN( |
|
nn.Conv2d(width, outplanes, kernel_size=1, groups=1, bias=False), |
|
norm_layer(outplanes)) |
|
|
|
self.skip = ConvBN( |
|
nn.Conv2d(inplanes, outplanes, kernel_size=1, stride=stride, |
|
dilation=first_dilation, groups=1, bias=False), |
|
norm_layer(outplanes)) if downsample is not None else nn.Identity() |
|
|
|
self.act3 = hardball(radius2=outplanes) |
|
|
|
def zero_init_last(self): |
|
if getattr(self.conv3.bn, 'weight', None) is not None: |
|
nn.init.zeros_(self.conv3.bn.weight) |
|
|
|
def conv_forward(self, x): |
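        # NOTE: alternative depthwise path that tiles conv2's shared 3x3 kernel across
        # the input channels; forward() below uses self.conv2 directly and does not call this.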
|
conv = self.conv2 |
|
C = x.size(1) // self.k |
|
kernel = conv.weight.repeat(C, 1, 1, 1) |
|
bias = conv.bias.repeat(C) if conv.bias is not None else None |
|
return F.conv2d(x, kernel, bias, conv.stride, |
|
conv.padding, conv.dilation, x.size(1)) |
|
|
|
def forward(self, x): |
|
x0 = self.skip(x) |
|
x = self.conv1(x) |
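        # Multiply even/odd channel pairs: a multiplicative gating that halves the
        # 2*width channels produced by conv1 back down to width.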
|
x = x[:, ::2, :, :] * x[:, 1::2, :, :] |
|
|
|
x = self.conv2(x) |
|
x = self.bn2(x) |
|
x = self.conv3(x) |
|
x += x0 |
|
if self.act3 is not None: |
|
x = self.act3(x) |
|
return x |
|
|
|
def make_blocks( |
|
block_fn, |
|
channels, |
|
block_repeats, |
|
inplanes, |
|
reduce_first=1, |
|
output_stride=32, |
|
down_kernel_size=1, |
|
avg_down=False, |
|
**kwargs, |
|
): |
|
stages = [] |
|
feature_info = [] |
|
net_num_blocks = sum(block_repeats) |
|
net_block_idx = 0 |
|
net_stride = 4 |
|
dilation = prev_dilation = 1 |
|
for stage_idx, (planes, num_blocks) in enumerate(zip(channels, block_repeats)): |
|
stage_name = f'layer{stage_idx + 1}' |
|
stride = 1 if stage_idx == 0 else 2 |
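        # Once the accumulated stride reaches output_stride, convert further striding
        # into dilation instead of reducing resolution.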
|
if net_stride >= output_stride: |
|
dilation *= stride |
|
stride = 1 |
|
else: |
|
net_stride *= stride |
|
|
|
downsample = None |
|
if stride != 1 or inplanes != planes * block_fn.expansion: |
|
downsample = True |
|
|
|
block_kwargs = dict(reduce_first=reduce_first, dilation=dilation, **kwargs) |
|
blocks = [] |
|
for block_idx in range(num_blocks): |
|
downsample = downsample if block_idx == 0 else None |
|
stride = stride if block_idx == 0 else 1 |
|
blocks.append(block_fn( |
|
inplanes, planes, stride, downsample, first_dilation=prev_dilation, |
|
**block_kwargs)) |
|
prev_dilation = dilation |
|
inplanes = planes * block_fn.expansion |
|
net_block_idx += 1 |
|
|
|
stages.append((stage_name, nn.Sequential(*blocks))) |
|
feature_info.append(dict(num_chs=inplanes, reduction=net_stride, module=stage_name)) |
|
|
|
return stages, feature_info |
|
|
|
|
|
class QLNet(nn.Module): |
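    """ResNet-style backbone assembled from QLBlock stages (stem, four stages,
    then global pooling and a linear classifier built via timm's create_classifier)."""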
|
|
|
|
|
def __init__( |
|
self, |
|
block=QLBlock, |
|
layers=[3,4,12,3], |
|
num_classes=1000, |
|
in_chans=3, |
|
output_stride=32, |
|
global_pool='avg', |
|
cardinality=1, |
|
base_width=64, |
|
stem_width=32, |
|
stem_type='', |
|
replace_stem_pool=False, |
|
block_reduce_first=1, |
|
down_kernel_size=1, |
|
avg_down=False, |
|
act_layer=nn.ReLU, |
|
norm_layer=nn.BatchNorm2d, |
|
zero_init_last=True, |
|
block_args=None, |
|
): |
|
""" |
|
Args: |
|
            block (nn.Module): class for the residual block (default QLBlock).
|
layers (List[int]) : number of layers in each block |
|
num_classes (int): number of classification classes (default 1000) |
|
in_chans (int): number of input (color) channels. (default 3) |
|
output_stride (int): output stride of the network, 32, 16, or 8. (default 32) |
|
global_pool (str): Global pooling type. One of 'avg', 'max', 'avgmax', 'catavgmax' (default 'avg') |
|
cardinality (int): number of convolution groups for 3x3 conv in Bottleneck. (default 1) |
|
base_width (int): bottleneck channels factor. `planes * base_width / 64 * cardinality` (default 64) |
|
            stem_width (int): number of channels in stem convolutions (default 32)
|
stem_type (str): The type of stem (default ''): |
|
* '', default - a single 7x7 conv with a width of stem_width |
|
* 'deep' - three 3x3 convolution layers of widths stem_width, stem_width, stem_width * 2 |
|
* 'deep_tiered' - three 3x3 conv layers of widths stem_width//4 * 3, stem_width, stem_width * 2 |
|
replace_stem_pool (bool): replace stem max-pooling layer with a 3x3 stride-2 convolution |
|
            block_reduce_first (int): Reduction factor for first convolution output width of residual blocks,
                1 for all archs except SENets, where it is 2 (default 1)
            down_kernel_size (int): kernel size of residual block downsample path,
                1x1 for most archs, 3x3 for SENets (default 1)
|
avg_down (bool): use avg pooling for projection skip connection between stages/downsample (default False) |
|
act_layer (str, nn.Module): activation layer |
|
norm_layer (str, nn.Module): normalization layer |
|
zero_init_last (bool): zero-init the last weight in residual path (usually last BN affine weight) |
|
block_args (dict): Extra kwargs to pass through to block module |
|
""" |
|
super(QLNet, self).__init__() |
|
block_args = block_args or dict() |
|
assert output_stride in (8, 16, 32) |
|
self.num_classes = num_classes |
|
self.grad_checkpointing = False |
|
|
|
|
|
deep_stem = 'deep' in stem_type |
|
inplanes = stem_width * 2 if deep_stem else 64 |
|
if deep_stem: |
|
stem_chs = (stem_width, stem_width) |
|
if 'tiered' in stem_type: |
|
stem_chs = (3 * (stem_width // 4), stem_width) |
|
self.conv1 = nn.Sequential(*[ |
|
nn.Conv2d(in_chans, stem_chs[0], 3, stride=2, padding=1, bias=False), |
|
norm_layer(stem_chs[0]), |
|
act_layer(inplace=True), |
|
nn.Conv2d(stem_chs[0], stem_chs[1], 3, stride=1, padding=1, bias=False), |
|
norm_layer(stem_chs[1]), |
|
act_layer(inplace=True), |
|
nn.Conv2d(stem_chs[1], inplanes, 3, stride=1, padding=1, bias=False)]) |
|
else: |
|
self.conv1 = nn.Conv2d(in_chans, inplanes, kernel_size=7, stride=2, padding=3, bias=False) |
|
self.bn1 = norm_layer(inplanes) |
|
|
|
        self.feature_info = [dict(num_chs=inplanes, reduction=2, module='bn1')]
|
|
|
|
|
if replace_stem_pool: |
|
self.maxpool = nn.Sequential(*filter(None, [ |
|
nn.Conv2d(inplanes, inplanes, 3, stride=2, padding=1, bias=False), |
|
norm_layer(inplanes), |
|
act_layer(inplace=True) |
|
])) |
|
else: |
|
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) |
|
|
|
|
|
channels = [64, 128, 256, 512] |
|
stage_modules, stage_feature_info = make_blocks( |
|
block, |
|
channels, |
|
layers, |
|
inplanes, |
|
cardinality=cardinality, |
|
base_width=base_width, |
|
output_stride=output_stride, |
|
reduce_first=block_reduce_first, |
|
avg_down=avg_down, |
|
down_kernel_size=down_kernel_size, |
|
act_layer=act_layer, |
|
norm_layer=norm_layer, |
|
**block_args, |
|
) |
|
for stage in stage_modules: |
|
self.add_module(*stage) |
|
        self.feature_info.extend(stage_feature_info)

        # Head (pooling and classifier)
|
self.num_features = 512 * block.expansion |
|
self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool) |
|
|
|
self.init_weights(zero_init_last=zero_init_last) |
|
|
|
@staticmethod |
|
    def from_pretrained(model_name: str, load_weights=True, **kwargs) -> 'QLNet':
        entry_fn = model_entrypoint(model_name, 'resnet')
        return entry_fn(pretrained=load_weights, **kwargs)
|
|
|
@torch.jit.ignore |
|
def init_weights(self, zero_init_last=True): |
|
for n, m in self.named_modules(): |
|
if isinstance(m, nn.Conv2d): |
|
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='linear') |
|
|
|
if zero_init_last: |
|
for m in self.modules(): |
|
if hasattr(m, 'zero_init_last'): |
|
m.zero_init_last() |
|
|
|
@torch.jit.ignore |
|
def group_matcher(self, coarse=False): |
|
matcher = dict(stem=r'^conv1|bn1|maxpool', blocks=r'^layer(\d+)' if coarse else r'^layer(\d+)\.(\d+)') |
|
return matcher |
|
|
|
@torch.jit.ignore |
|
def set_grad_checkpointing(self, enable=True): |
|
self.grad_checkpointing = enable |
|
|
|
@torch.jit.ignore |
|
def get_classifier(self, name_only=False): |
|
return 'fc' if name_only else self.fc |
|
|
|
    def reset_classifier(self, num_classes, global_pool='avg'):
        self.num_classes = num_classes
        self.global_pool, self.fc = create_classifier(
            self.num_features, self.num_classes, pool_type=global_pool)
|
|
|
def forward_features(self, x): |
|
x = self.conv1(x) |
|
x = self.bn1(x) |
|
|
|
x = self.maxpool(x) |
|
|
|
if self.grad_checkpointing and not torch.jit.is_scripting(): |
|
x = checkpoint_seq([self.layer1, self.layer2, self.layer3, self.layer4], x, flatten=True) |
|
else: |
|
x = self.layer1(x) |
|
x = self.layer2(x) |
|
x = self.layer3(x) |
|
x = self.layer4(x) |
|
return x |
|
|
|
def forward_head(self, x, pre_logits: bool = False): |
|
x = self.global_pool(x) |
|
return x if pre_logits else self.fc(x) |
|
|
|
def forward(self, x): |
|
x = self.forward_features(x) |
|
|
|
x = self.forward_head(x) |
|
return x |
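

# Minimal usage sketch (an illustrative addition, not part of the original module):
# build a default QLNet, fold the ConvBN pairs for inference, and run a dummy input.
if __name__ == '__main__':
    model = QLNet(num_classes=1000)
    model.eval()
    # Fuse each Conv2d/BatchNorm pair so the eval-mode forward uses the fused weights.
    for m in model.modules():
        if isinstance(m, ConvBN):
            m.fuse_bn()
    with torch.no_grad():
        out = model(torch.randn(1, 3, 224, 224))
    print(out.shape)  # torch.Size([1, 1000])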
|