# AttentionMaps / resnet.py
# Source: TexR6, commit d7b0f75 ("initial commit").
import torch
import torch.nn as nn
import torch.nn.functional as F
from os.path import join as pjoin
from collections import OrderedDict
def weight_standardize(w, dim, eps):
    """Subtract the mean and divide by the standard deviation over ``dim``.

    The statistics are computed with ``keepdim=True`` so they broadcast back
    onto ``w`` along the non-reduced axes, wherever those axes sit. Without
    ``keepdim`` the reduced tensor would be aligned with the *trailing* axes
    of ``w``, which is only correct when the reduced axes are the leading ones.

    The standard deviation is the population (biased) one, and ``eps`` is
    added to the standard deviation (not to the variance).

    Args:
        w: Tensor to standardize.
        dim: Axis or list of axes to reduce over.
        eps: Small constant added to the standard deviation to avoid division
            by zero.

    Returns:
        A tensor with the same shape as ``w``.
    """
    w = w - torch.mean(w, dim=dim, keepdim=True)
    # Population std of the centred tensor: sqrt(mean(w ** 2)).
    std = torch.sqrt(torch.mean(w * w, dim=dim, keepdim=True))
    return w / (std + eps)


def np2th(weights, conv=False):
    """Convert a NumPy array to a torch tensor, optionally reordering a kernel.

    Args:
        weights: NumPy array.
        conv: If True, ``weights`` is treated as an HWIO (height, width, in,
            out) convolution kernel and transposed to PyTorch's OIHW layout
            before conversion.

    Returns:
        ``torch.from_numpy`` of the (possibly transposed) array.
    """
    if conv:
        weights = weights.transpose([3, 2, 0, 1])
    return torch.from_numpy(weights)


class StdConv2d(nn.Conv2d):
    """``nn.Conv2d`` whose weight is standardized on every forward pass.

    PyTorch stores conv weights as OIHW. Each output filter (dim 0) is
    standardized to zero mean and unit standard deviation over its
    (in_channels, kH, kW) fan-in before the convolution is applied. The stored
    ``self.weight`` parameter itself is not modified.
    """

    def forward(self, x):
        # Reduce over I, H, W so every output channel gets its own statistics.
        # (In an HWIO layout the equivalent reduction is over axes [0, 1, 2].)
        w = weight_standardize(self.weight, [1, 2, 3], 1e-5)
        return F.conv2d(x, w, self.bias, self.stride, self.padding,
                        self.dilation, self.groups)
def conv3x3(in_channels, out_channels, stride=1, groups=1, bias=False):
    """Return a weight-standardized 3x3 convolution with 1 pixel of padding.

    With ``stride=1`` the spatial size of the input is preserved.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        stride: Convolution stride.
        groups: Number of channel groups.
        bias: Whether the layer has a learnable bias.
    """
    layer_kwargs = {
        'kernel_size': 3,
        'stride': stride,
        'padding': 1,
        'bias': bias,
        'groups': groups,
    }
    return StdConv2d(in_channels, out_channels, **layer_kwargs)
def conv1x1(in_channels, out_channels, stride=1, bias=False):
    """Return a weight-standardized 1x1 (pointwise) convolution, unpadded.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        stride: Convolution stride.
        bias: Whether the layer has a learnable bias.
    """
    layer_kwargs = {
        'kernel_size': 1,
        'stride': stride,
        'padding': 0,
        'bias': bias,
    }
    return StdConv2d(in_channels, out_channels, **layer_kwargs)
class PreActBottleneck(nn.Module):
    """Bottleneck residual unit (1x1 -> 3x3 -> 1x1) built from StdConv2d + GroupNorm.

    Despite the "pre-activation (v2)" name, the ordering implemented here is
    conv -> GroupNorm -> ReLU for the first two convolutions and
    conv -> GroupNorm for the third. The shortcut (identity, or a strided
    1x1 projection followed by GroupNorm) is then added and a final ReLU is
    applied.

    ``nn.GroupNorm(32, C)`` requires ``mid_channels`` and ``out_channels`` to
    be divisible by 32.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels produced by the unit; defaults to ``in_channels``.
        mid_channels: Bottleneck width; defaults to ``out_channels // 4``.
        stride: Stride of the 3x3 convolution and of the projection shortcut.
    """

    def __init__(self,
                 in_channels,
                 out_channels=None,
                 mid_channels=None,
                 stride=1):
        super().__init__()
        out_channels = out_channels or in_channels
        mid_channels = mid_channels or out_channels // 4

        # Submodules are registered in the same order as before, so parameter
        # order and state_dict keys are unchanged.
        self.gn1 = nn.GroupNorm(32, mid_channels, eps=1e-6)
        self.conv1 = conv1x1(in_channels, mid_channels, bias=False)
        self.gn2 = nn.GroupNorm(32, mid_channels, eps=1e-6)
        # The stride sits on the 3x3 conv; the original author notes that the
        # reference code put it on conv1 instead.
        self.conv2 = conv3x3(mid_channels, mid_channels, stride, bias=False)
        self.gn3 = nn.GroupNorm(32, out_channels, eps=1e-6)
        self.conv3 = conv1x1(mid_channels, out_channels, bias=False)
        self.relu = nn.ReLU(inplace=True)

        needs_projection = stride != 1 or in_channels != out_channels
        if needs_projection:
            # Shortcut must match the main path's shape: strided 1x1 conv + GN.
            self.downsample = conv1x1(in_channels, out_channels, stride,
                                      bias=False)
            self.gn_proj = nn.GroupNorm(out_channels, out_channels)

    def forward(self, x):
        # Shortcut path: identity unless a projection was created in __init__.
        shortcut = x
        if hasattr(self, 'downsample'):
            shortcut = self.gn_proj(self.downsample(x))

        # Main path: three conv/GN stages, ReLU after the first two.
        out = self.relu(self.gn1(self.conv1(x)))
        out = self.relu(self.gn2(self.conv2(out)))
        out = self.gn3(self.conv3(out))

        return self.relu(shortcut + out)
class ResNetV2(nn.Module):
    """ResNetV2-style backbone: a StdConv2d/GroupNorm stem plus three stages.

    Structure (module names determine the state_dict keys):

    - ``root``: 7x7 stride-2 ``StdConv2d`` -> ``GroupNorm`` -> ``ReLU`` ->
      3x3 stride-2 max pool.
    - ``body``: ``block1`` .. ``block3``. Each is an ``nn.Sequential`` of
      ``PreActBottleneck`` units named ``unit1`` .. ``unitN``. Stage ``k``
      (0-based) has ``block_units[k]`` units, bottleneck width
      ``width * 2**k`` and output width ``width * 4 * 2**k``. The first unit
      of ``block2`` and ``block3`` uses stride 2.

    Only the first three entries of ``block_units`` are used.

    Args:
        block_units: Number of units per stage (at least three entries).
        width_factor: Multiplier applied to the base width of 64 channels.
    """

    def __init__(self, block_units, width_factor):
        super().__init__()
        width = int(64 * width_factor)
        self.width = width
        # Total output stride of the backbone:
        # stem conv (2) * max pool (2) * block2 (2) * block3 (2) = 16.
        self.downsample = 16

        self.root = nn.Sequential(OrderedDict([
            ('conv', StdConv2d(3, width, kernel_size=7, stride=2, bias=False, padding=3)),
            ('gn', nn.GroupNorm(32, width, eps=1e-6)),
            ('relu', nn.ReLU(inplace=True)),
            ('pool', nn.MaxPool2d(kernel_size=3, stride=2, padding=0)),
        ]))

        # Stages are built in order (block1, block2, block3) so module
        # creation order, and hence parameter order, is unchanged.
        stages = []
        in_channels = width
        for index in range(3):
            mid_channels = width * 2 ** index
            out_channels = mid_channels * 4
            stride = 1 if index == 0 else 2
            stage = self._make_stage(in_channels, out_channels, mid_channels,
                                     block_units[index], stride)
            stages.append((f'block{index + 1}', stage))
            in_channels = out_channels
        self.body = nn.Sequential(OrderedDict(stages))

    @staticmethod
    def _make_stage(in_channels, out_channels, mid_channels, num_units, stride):
        """Build one stage: a (possibly strided) ``unit1`` followed by
        ``num_units - 1`` shape-preserving units."""
        units = [('unit1', PreActBottleneck(in_channels=in_channels,
                                            out_channels=out_channels,
                                            mid_channels=mid_channels,
                                            stride=stride))]
        units += [(f'unit{i:d}', PreActBottleneck(in_channels=out_channels,
                                                  out_channels=out_channels,
                                                  mid_channels=mid_channels))
                  for i in range(2, num_units + 1)]
        return nn.Sequential(OrderedDict(units))

    def forward(self, x):
        x = self.root(x)
        x = self.body(x)
        return x
def resnet50():
    """Build the three-stage ``ResNetV2`` exposed here as "resnet50".

    Uses (3, 4, 9) bottleneck units for ``block1``..``block3`` and a width
    factor of 1 (base width 64).
    """
    units_per_stage = (3, 4, 9)
    return ResNetV2(units_per_stage, 1)