sczhou's picture
init code
320e465
"""
resnet.py - A modified ResNet structure
We append extra channels to the first conv by some network surgery
"""
from collections import OrderedDict
import math
import torch
import torch.nn as nn
from torch.utils import model_zoo
def load_weights_add_extra_dim(target, source_state, extra_dim=1):
    """Load ``source_state`` into ``target``, widening tensors whose shape differs.

    For every key of ``target.state_dict()`` that also exists in
    ``source_state`` (keys containing ``'num_batches_tracked'`` are skipped),
    the source tensor is copied over. When the target tensor has a different
    shape, the source tensor is extended along dimension 1 with ``extra_dim``
    additional channels before loading.

    Note: the additional channels are filled by ``nn.init.orthogonal_``; they
    are NOT left at zero.

    Args:
        target: module that receives the weights via ``load_state_dict``.
        source_state: mapping from parameter name to tensor (a state dict).
        extra_dim: number of channels appended along dimension 1 to every
            mismatching tensor.
    """
    adapted = OrderedDict()
    for name, current in target.state_dict().items():
        # Ignore the batch-tracking counters and anything the checkpoint lacks.
        if 'num_batches_tracked' in name or name not in source_state:
            continue

        weight = source_state[name]
        if current.shape != weight.shape:
            # Mismatching tensors are unpacked as 4-D; only the first and the
            # last two sizes are taken from the target tensor.
            out_ch, _, size_a, size_b = current.shape
            extra = torch.zeros((out_ch, extra_dim, size_a, size_b), device=weight.device)
            nn.init.orthogonal_(extra)
            weight = torch.cat([weight, extra], 1)

        adapted[name] = weight

    target.load_state_dict(adapted)
# Checkpoint download URLs keyed by architecture name; looked up by the
# resnet18 / resnet50 factory functions when pretrained weights are requested.
model_urls = dict(
    resnet18='https://download.pytorch.org/models/resnet18-5c106cde.pth',
    resnet50='https://download.pytorch.org/models/resnet50-19c8e357.pth',
)
def conv3x3(in_planes, out_planes, stride=1, dilation=1):
    """Return a bias-free 3x3 ``nn.Conv2d``.

    The padding is set equal to ``dilation``.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        stride: convolution stride.
        dilation: kernel dilation, also used as the padding.
    """
    conv_kwargs = dict(
        kernel_size=3,
        stride=stride,
        padding=dilation,
        dilation=dilation,
        bias=False,
    )
    return nn.Conv2d(in_planes, out_planes, **conv_kwargs)
class BasicBlock(nn.Module):
    """Residual block made of two 3x3 convolutions, each followed by BatchNorm.

    ``expansion`` is the ratio between the block's output channels and
    ``planes``; it is 1 for this block.
    """

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1):
        """Create the block.

        Args:
            inplanes: number of input channels.
            planes: number of channels produced by both convolutions.
            stride: stride of the first convolution (the second uses 1).
            downsample: optional module applied to the input to form the
                residual branch; the raw input is used when it is ``None``.
            dilation: dilation passed to both convolutions.
        """
        super().__init__()
        # Attribute names double as state-dict keys, so they are kept as is.
        self.conv1 = conv3x3(inplanes, planes, stride=stride, dilation=dilation)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes, stride=1, dilation=dilation)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        # The shortcut is the input itself unless a downsample module was given.
        shortcut = x if self.downsample is None else self.downsample(x)

        out += shortcut
        return self.relu(out)
class Bottleneck(nn.Module):
    """Residual bottleneck block: 1x1 -> 3x3 -> 1x1 convolutions with BatchNorm.

    ``expansion`` is the ratio between the block's output channels and
    ``planes``. The output width is derived from this attribute, so a
    subclass that overrides ``expansion`` gets a consistently sized last
    convolution and BatchNorm.
    """

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1):
        """Create the block.

        Args:
            inplanes: number of input channels.
            planes: width of the two inner convolutions; the block outputs
                ``planes * expansion`` channels.
            stride: stride of the 3x3 convolution.
            downsample: optional module applied to the input to form the
                residual branch; the raw input is used when it is ``None``.
            dilation: dilation (and padding) of the 3x3 convolution.
        """
        super().__init__()
        # Use the class attribute instead of a literal 4 so the width stays in
        # sync with whatever code sizes the residual branch from `expansion`.
        out_planes = planes * self.expansion

        # Attribute names double as state-dict keys, so they are kept as is.
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(
            planes,
            planes,
            kernel_size=3,
            stride=stride,
            dilation=dilation,
            padding=dilation,
            bias=False,
        )
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        # The shortcut is the input itself unless a downsample module was given.
        shortcut = x if self.downsample is None else self.downsample(x)

        out += shortcut
        return self.relu(out)
class ResNet(nn.Module):
    """ResNet trunk: a 7x7 stem followed by four residual stages.

    The class only builds the layers (``conv1``, ``bn1``, ``relu``,
    ``maxpool``, ``layer1`` .. ``layer4``); it defines no ``forward`` of its
    own. The stem accepts ``3 + extra_dim`` input channels.
    """

    def __init__(self, block, layers=(3, 4, 23, 3), extra_dim=0):
        """Build the stem and the four stages, then initialise the weights.

        Args:
            block: residual block class; must expose an ``expansion``
                attribute and accept
                ``(inplanes, planes, stride, downsample, dilation=...)``.
            layers: number of blocks in each of the four stages.
            extra_dim: number of input channels added on top of the 3 base
                channels of the first convolution.
        """
        super().__init__()
        # Running channel count, updated by _make_layer after every stage.
        self.inplanes = 64

        self.conv1 = nn.Conv2d(3 + extra_dim, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)

        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                # Normal init with std = sqrt(2 / (kernel area * out_channels)).
                kernel_h, kernel_w = module.kernel_size
                fan_out = kernel_h * kernel_w * module.out_channels
                module.weight.data.normal_(0, math.sqrt(2. / fan_out))
            elif isinstance(module, nn.BatchNorm2d):
                module.weight.data.fill_(1)
                module.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1, dilation=1):
        """Build one stage of ``blocks`` residual blocks as an ``nn.Sequential``.

        The first block receives ``stride`` and, when the stride is not 1 or
        the channel count changes, a 1x1 conv + BatchNorm ``downsample``
        module. ``dilation`` is forwarded only to the blocks after the first.
        ``self.inplanes`` is updated to the stage's output channel count.
        """
        out_planes = planes * block.expansion

        downsample = None
        if stride != 1 or self.inplanes != out_planes:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )

        stage = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = out_planes
        for _ in range(1, blocks):
            stage.append(block(self.inplanes, planes, dilation=dilation))

        return nn.Sequential(*stage)
def resnet18(pretrained=True, extra_dim=0):
    """Build a ``ResNet`` of ``BasicBlock`` units with stage depths [2, 2, 2, 2].

    Args:
        pretrained: when true, download the checkpoint listed under
            ``model_urls['resnet18']`` and load it into the model.
        extra_dim: number of extra input channels of the first convolution;
            also passed to the weight loader.

    Returns:
        The constructed model.
    """
    net = ResNet(BasicBlock, [2, 2, 2, 2], extra_dim)
    if not pretrained:
        return net

    checkpoint = model_zoo.load_url(model_urls['resnet18'])
    load_weights_add_extra_dim(net, checkpoint, extra_dim)
    return net
def resnet50(pretrained=True, extra_dim=0):
    """Build a ``ResNet`` of ``Bottleneck`` units with stage depths [3, 4, 6, 3].

    Args:
        pretrained: when true, download the checkpoint listed under
            ``model_urls['resnet50']`` and load it into the model.
        extra_dim: number of extra input channels of the first convolution;
            also passed to the weight loader.

    Returns:
        The constructed model.
    """
    net = ResNet(Bottleneck, [3, 4, 6, 3], extra_dim)
    if not pretrained:
        return net

    checkpoint = model_zoo.load_url(model_urls['resnet50'])
    load_weights_add_extra_dim(net, checkpoint, extra_dim)
    return net