# HEAT/models/stacked_hg.py
"""
Hourglass network inserted in the pre-activated Resnet
Use lr=0.01 for current version
(c) Nan Xue (HAWP)
(c) Yichao Zhou (LCNN)
(c) YANG, Wei
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
__all__ = ["HourglassNet", "build_hg"]
class Bottleneck2D(nn.Module):
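    """Pre-activation bottleneck (BN -> ReLU -> conv, three times); output channels = planes * expansion."""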
expansion = 2
def __init__(self, inplanes, planes, stride=1, downsample=None):
super(Bottleneck2D, self).__init__()
self.bn1 = nn.BatchNorm2d(inplanes)
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1)
self.bn2 = nn.BatchNorm2d(planes)
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1)
self.bn3 = nn.BatchNorm2d(planes)
self.conv3 = nn.Conv2d(planes, planes * 2, kernel_size=1)
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
self.stride = stride
def forward(self, x):
residual = x
out = self.bn1(x)
out = self.relu(out)
out = self.conv1(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn3(out)
out = self.relu(out)
out = self.conv3(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
return out
class Hourglass(nn.Module):
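    """A single recursive hourglass: `depth` levels of pooling with residual skip
    connections at every scale, mirrored by upsampling on the way back up."""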
def __init__(self, block, num_blocks, planes, depth):
super(Hourglass, self).__init__()
self.depth = depth
self.block = block
self.hg = self._make_hour_glass(block, num_blocks, planes, depth)
def _make_residual(self, block, num_blocks, planes):
layers = []
        for _ in range(num_blocks):
layers.append(block(planes * block.expansion, planes))
return nn.Sequential(*layers)
def _make_hour_glass(self, block, num_blocks, planes, depth):
hg = []
for i in range(depth):
res = []
for j in range(3):
res.append(self._make_residual(block, num_blocks, planes))
if i == 0:
res.append(self._make_residual(block, num_blocks, planes))
hg.append(nn.ModuleList(res))
return nn.ModuleList(hg)
def _hour_glass_forward(self, n, x):
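        # Recursive evaluation at level n: the "up" branch keeps the current
        # resolution, the "low" branch is pooled by 2, processed (recursively,
        # or by the extra innermost block when n == 1), then upsampled and
        # added back to the "up" branch.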
up1 = self.hg[n - 1][0](x)
low1 = F.max_pool2d(x, 2, stride=2)
low1 = self.hg[n - 1][1](low1)
if n > 1:
low2 = self._hour_glass_forward(n - 1, low1)
else:
low2 = self.hg[n - 1][3](low1)
low3 = self.hg[n - 1][2](low2)
up2 = F.interpolate(low3, scale_factor=2)
out = up1 + up2
return out
def forward(self, x):
return self._hour_glass_forward(self.depth, x)
class HourglassNet(nn.Module):
"""Hourglass model from Newell et al ECCV 2016"""
def __init__(self, inplanes, num_feats, block, head, depth, num_stacks, num_blocks, num_classes):
super(HourglassNet, self).__init__()
self.inplanes = inplanes
self.num_feats = num_feats
self.num_stacks = num_stacks
self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3)
self.bn1 = nn.BatchNorm2d(self.inplanes)
self.relu = nn.ReLU(inplace=True)
self.layer1 = self._make_residual(block, self.inplanes, 1)
self.layer2 = self._make_residual(block, self.inplanes, 1)
self.layer3 = self._make_residual(block, self.num_feats, 1)
self.maxpool = nn.MaxPool2d(2, stride=2)
# build hourglass modules
ch = self.num_feats * block.expansion
# vpts = []
hg, res, fc, score, fc_, score_ = [], [], [], [], [], []
for i in range(num_stacks):
hg.append(Hourglass(block, num_blocks, self.num_feats, depth))
res.append(self._make_residual(block, self.num_feats, num_blocks))
fc.append(self._make_fc(ch, ch))
score.append(head(ch, num_classes))
# vpts.append(VptsHead(ch))
# vpts.append(nn.Linear(ch, 9))
# score.append(nn.Conv2d(ch, num_classes, kernel_size=1))
# score[i].bias.data[0] += 4.6
# score[i].bias.data[2] += 4.6
if i < num_stacks - 1:
fc_.append(nn.Conv2d(ch, ch, kernel_size=1))
score_.append(nn.Conv2d(num_classes, ch, kernel_size=1))
self.hg = nn.ModuleList(hg)
self.res = nn.ModuleList(res)
self.fc = nn.ModuleList(fc)
self.score = nn.ModuleList(score)
# self.vpts = nn.ModuleList(vpts)
self.fc_ = nn.ModuleList(fc_)
self.score_ = nn.ModuleList(score_)
def _make_residual(self, block, planes, blocks, stride=1):
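        # Stack `blocks` bottleneck units; a 1x1 conv projects the residual
        # whenever the stride or the channel count changes.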
downsample = None
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
nn.Conv2d(
self.inplanes,
planes * block.expansion,
kernel_size=1,
stride=stride,
)
)
layers = []
layers.append(block(self.inplanes, planes, stride, downsample))
self.inplanes = planes * block.expansion
for i in range(1, blocks):
layers.append(block(self.inplanes, planes))
return nn.Sequential(*layers)
def _make_fc(self, inplanes, outplanes):
bn = nn.BatchNorm2d(inplanes)
conv = nn.Conv2d(inplanes, outplanes, kernel_size=1)
return nn.Sequential(conv, bn, self.relu)
def forward(self, x):
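        # Stem: 7x7 stride-2 conv, residual blocks and a max-pool bring the
        # input to 1/4 resolution. Each stack then produces its own score map
        # (intermediate supervision); every stack but the last also feeds its
        # features back into the input of the next one.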
out = []
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.layer1(x)
x = self.maxpool(x)
x = self.layer2(x)
x = self.layer3(x)
for i in range(self.num_stacks):
y = self.hg[i](x)
y = self.res[i](y)
y = self.fc[i](y)
score = self.score[i](y)
out.append(score)
if i < self.num_stacks - 1:
fc_ = self.fc_[i](y)
score_ = self.score_[i](score)
x = x + fc_ + score_
return out[::-1], y
    def train(self, mode=True):
        # Override train() so that every BatchNorm layer stays in eval mode
        # (frozen running statistics) even while the rest of the network trains.
        nn.Module.train(self, mode)
        if mode:
            def set_bn_eval(m):
                classname = m.__class__.__name__
                if 'BatchNorm' in classname:
                    m.eval()

            self.apply(set_bn_eval)
class MultitaskHead(nn.Module):
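    """One small conv head (3x3 conv -> ReLU -> 1x1 conv) per entry of the
    flattened head_size list; the head outputs are concatenated along channels."""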
def __init__(self, input_channels, num_class, head_size):
super(MultitaskHead, self).__init__()
        m = input_channels // 4
heads = []
for output_channels in sum(head_size, []):
heads.append(
nn.Sequential(
nn.Conv2d(input_channels, m, kernel_size=3, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(m, output_channels, kernel_size=1),
)
)
self.heads = nn.ModuleList(heads)
assert num_class == sum(sum(head_size, []))
def forward(self, x):
return torch.cat([head(x) for head in self.heads], dim=1)
def build_hg():
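    # Configuration used here: a 2-stack, depth-4 hourglass with 128 hourglass
    # planes (256 feature channels after bottleneck expansion) and a multitask
    # head producing two 2-channel score maps (4 classes in total).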
    inplanes = 64
    num_feats = 256 // 2
    depth = 4
    num_stacks = 2
    num_blocks = 1
    head_size = [[2], [2]]
    out_feature_channels = 256
    num_class = sum(sum(head_size, []))
    model = HourglassNet(
        block=Bottleneck2D,
        inplanes=inplanes,
        num_feats=num_feats,
        depth=depth,
        head=lambda c_in, c_out: MultitaskHead(c_in, c_out, head_size=head_size),
        num_stacks=num_stacks,
        num_blocks=num_blocks,
        num_classes=num_class,
    )
    model.out_feature_channels = out_feature_channels
    return model
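

if __name__ == "__main__":
    # Quick smoke test (illustrative sketch only; the 1x3x256x256 input is an
    # assumed size, the network itself is fully convolutional).
    model = build_hg()
    model.eval()
    with torch.no_grad():
        scores, feats = model(torch.randn(1, 3, 256, 256))
    # The stem downsamples by 4, so every output is 64x64 for this input.
    print([tuple(s.shape) for s in scores])  # two stacks -> [(1, 4, 64, 64), (1, 4, 64, 64)]
    print(tuple(feats.shape))                # (1, 256, 64, 64), final hourglass features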