"""Pre-activation (v2) ResNet building blocks in PyTorch, with weight-standardized convolutions."""
from collections import OrderedDict
from os.path import join as pjoin

import torch
import torch.nn as nn
import torch.nn.functional as F
def weight_standardize(w, dim, eps):
    """Standardize ``w`` to zero mean and unit std over ``dim``.

    Args:
        w: weight tensor.
        dim: dimension or list of dimensions to compute statistics over.
        eps: small constant added to the std for numerical stability.

    Returns:
        Tensor with the same shape as ``w``.
    """
    # keepdim=True so the statistics broadcast back against ``w`` for ANY
    # choice of ``dim``. Without it, reducing over trailing dims either
    # raises or silently mis-broadcasts; for the original leading-dim
    # usage the result is numerically unchanged.
    w = w - torch.mean(w, dim=dim, keepdim=True)
    # std is shift-invariant, so computing it on the centered tensor equals
    # computing it on the original.
    w = w / (torch.std(w, dim=dim, keepdim=True) + eps)
    return w
def np2th(weights, conv=False):
    """Convert a numpy array to a torch tensor.

    When ``conv`` is True the array is treated as a conv kernel stored in
    HWIO layout (as in JAX/Flax checkpoints) and permuted to PyTorch's
    OIHW layout first.
    """
    arr = weights.transpose([3, 2, 0, 1]) if conv else weights
    return torch.from_numpy(arr)
class StdConv2d(nn.Conv2d):
    """Conv2d with Weight Standardization (Qiao et al., 2019).

    The kernel is standardized to zero mean / unit std over each output
    channel's fan-in before convolving.

    Bug fix: the previous code reduced over dims ``[0, 1, 2]`` — a literal
    copy of the JAX axes, where kernels are HWIO and those axes ARE the
    fan-in. PyTorch kernels are OIHW (see ``np2th``), so the correct
    per-output-channel axes are ``[1, 2, 3]``, and ``keepdim=True`` is
    required for the ``[O, 1, 1, 1]`` statistics to broadcast against the
    ``[O, I, H, W]`` weight.
    """

    def forward(self, x):
        w = self.weight
        # Per-output-channel statistics. torch.std stays Bessel-corrected
        # (its default), as in the original port.
        mean = torch.mean(w, dim=[1, 2, 3], keepdim=True)
        std = torch.std(w, dim=[1, 2, 3], keepdim=True)
        w = (w - mean) / (std + 1e-5)
        return F.conv2d(x, w, self.bias, self.stride, self.padding,
                        self.dilation, self.groups)
def conv3x3(in_channels, out_channels, stride=1, groups=1, bias=False):
    """3x3 weight-standardized convolution with padding 1 ('same' at stride 1)."""
    return StdConv2d(in_channels, out_channels, kernel_size=3, stride=stride,
                     padding=1, bias=bias, groups=groups)
def conv1x1(in_channels, out_channels, stride=1, bias=False):
    """1x1 weight-standardized convolution (pointwise, no padding)."""
    return StdConv2d(in_channels, out_channels, kernel_size=1, stride=stride,
                     padding=0, bias=bias)
class PreActBottleneck(nn.Module):
    """Pre-activation (v2) bottleneck block.

    Each conv is preceded (in the forward pass, via norm-after-conv
    attribute pairing) by GroupNorm + ReLU; the projection shortcut, when
    present, is also normalized.
    """

    def __init__(self,
                 in_channels,
                 out_channels=None,
                 mid_channels=None,
                 stride=1):
        super().__init__()
        # Defaults: keep the channel count, bottleneck to a quarter of it.
        out_channels = out_channels or in_channels
        mid_channels = mid_channels or out_channels // 4

        self.gn1 = nn.GroupNorm(32, mid_channels, eps=1e-6)
        self.conv1 = conv1x1(in_channels, mid_channels, bias=False)
        self.gn2 = nn.GroupNorm(32, mid_channels, eps=1e-6)
        # The stride sits on the 3x3 conv here (the original JAX code
        # places it on conv1 instead).
        self.conv2 = conv3x3(mid_channels, mid_channels, stride, bias=False)
        self.gn3 = nn.GroupNorm(32, out_channels, eps=1e-6)
        self.conv3 = conv1x1(mid_channels, out_channels, bias=False)
        self.relu = nn.ReLU(inplace=True)

        needs_projection = stride != 1 or in_channels != out_channels
        if needs_projection:
            # Projection shortcut, normalized as well (pre-activation paper).
            # NOTE: gn_proj uses out_channels groups, i.e. one group per
            # channel, unlike the 32-group norms above.
            self.downsample = conv1x1(in_channels, out_channels, stride,
                                      bias=False)
            self.gn_proj = nn.GroupNorm(out_channels, out_channels)

    def forward(self, x):
        # Shortcut branch: identity, or normalized projection when shapes differ.
        shortcut = x
        if hasattr(self, 'downsample'):
            shortcut = self.gn_proj(self.downsample(x))

        # Main branch: three conv/norm stages; the last one is not
        # activated before the residual add.
        out = self.relu(self.gn1(self.conv1(x)))
        out = self.relu(self.gn2(self.conv2(out)))
        out = self.gn3(self.conv3(out))
        return self.relu(shortcut + out)
class ResNetV2(nn.Module):
    """Implementation of Pre-activation (v2) ResNet mode."""

    def __init__(self, block_units, width_factor):
        super().__init__()
        width = int(64 * width_factor)
        self.width = width
        # Total spatial reduction: root conv (x2), max-pool (x2), and the
        # stride-2 first units of blocks 2 and 3 (x2 each).
        self.downsample = 16

        self.root = nn.Sequential(
            OrderedDict([
                ('conv', StdConv2d(3, width, kernel_size=7, stride=2,
                                   bias=False, padding=3)),
                ('gn', nn.GroupNorm(32, width, eps=1e-6)),
                ('relu', nn.ReLU(inplace=True)),
                ('pool', nn.MaxPool2d(kernel_size=3, stride=2, padding=0)),
            ]))

        # Build the three stages in a loop instead of three copy-pasted
        # OrderedDicts. Module names ('block1'..'block3', 'unit1'..'unitN')
        # are identical to the unrolled version, so state_dict keys match.
        stages = []
        in_ch = width
        depths = (block_units[0], block_units[1], block_units[2])
        for idx, num_units in enumerate(depths, start=1):
            mid_ch = width * (2 ** (idx - 1))
            out_ch = 4 * mid_ch
            # Only the first stage keeps stride 1; later stages halve the
            # resolution in their first unit.
            first_stride = 1 if idx == 1 else 2
            units = [('unit1',
                      PreActBottleneck(in_channels=in_ch,
                                       out_channels=out_ch,
                                       mid_channels=mid_ch,
                                       stride=first_stride))]
            units += [(f'unit{i:d}',
                       PreActBottleneck(in_channels=out_ch,
                                        out_channels=out_ch,
                                        mid_channels=mid_ch))
                      for i in range(2, num_units + 1)]
            stages.append((f'block{idx:d}', nn.Sequential(OrderedDict(units))))
            in_ch = out_ch
        self.body = nn.Sequential(OrderedDict(stages))

    def forward(self, x):
        return self.body(self.root(x))
def resnet50():
    """Build the ResNetV2 backbone with stage depths (3, 4, 9) at width factor 1."""
    return ResNetV2(block_units=(3, 4, 9), width_factor=1)