# Page-header text captured along with this file (not Python code):
#   Nadine Rueegg
#   initial commit for barc
#   7629b39
#   raw / history blame / 11.7 kB
# Modified from:
# https://github.com/anibali/pytorch-stacked-hourglass
# https://github.com/bearpaw/pytorch-pose
# Hourglass network inserted in the pre-activated Resnet
# Use lr=0.01 for current version
# (c) YANG, Wei
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.hub import load_state_dict_from_url
# Names exported by a star-import of this module.
__all__ = ['HourglassNet', 'hg']
# Download URLs of pretrained checkpoints, keyed by the architecture name that
# `_hg` uses for the lookup. Only 'hg1', 'hg2' and 'hg8' have entries; there is
# no 'hg4' entry, so requesting pretrained weights for 'hg4' raises KeyError.
model_urls = {
'hg1': 'https://github.com/anibali/pytorch-stacked-hourglass/releases/download/v0.0.0/bearpaw_hg1-ce125879.pth',
'hg2': 'https://github.com/anibali/pytorch-stacked-hourglass/releases/download/v0.0.0/bearpaw_hg2-15e342d9.pth',
'hg8': 'https://github.com/anibali/pytorch-stacked-hourglass/releases/download/v0.0.0/bearpaw_hg8-90e5d470.pth',
}
class Bottleneck(nn.Module):
    """Pre-activation bottleneck residual block.

    Every convolution is preceded by batch normalisation and ReLU
    (BN -> ReLU -> conv). The block maps ``inplanes`` input channels to
    ``planes * expansion`` output channels and adds a skip connection.

    Args:
        inplanes: Number of input channels.
        planes: Channel width of the two inner convolutions.
        stride: Stride of the 3x3 convolution.
        downsample: Optional module applied to the input to build the skip
            branch. When ``None`` the input itself is added to the output, so
            it must already have the output's shape.
    """

    # Ratio between the block's output channels and ``planes``.
    expansion = 2

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        self.bn1 = nn.BatchNorm2d(inplanes)
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=True)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=True)
        self.bn3 = nn.BatchNorm2d(planes)
        # Derive the output width from ``expansion`` rather than a literal 2:
        # the builders in this file size neighbouring layers with
        # ``planes * block.expansion`` and the two must agree.
        self.conv3 = nn.Conv2d(planes, planes * self.expansion,
                               kernel_size=1, bias=True)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return ``conv_branch(x) + skip(x)``."""
        out = self.conv1(self.relu(self.bn1(x)))
        out = self.conv2(self.relu(self.bn2(out)))
        out = self.conv3(self.relu(self.bn3(out)))

        # Skip branch: the raw input, or its projection when one was supplied.
        residual = x if self.downsample is None else self.downsample(x)
        out += residual
        return out
class Hourglass(nn.Module):
    """One recursive hourglass module.

    ``self.hg`` holds ``depth`` levels. Each level is a ``ModuleList`` of
    residual stacks: index 0 is the skip branch at the current resolution,
    index 1 runs after the 2x2 max-pool, index 2 runs before upsampling.
    Level 0 has a fourth stack (index 3) applied at the innermost point of
    the recursion.

    Args:
        block: Residual block class; called as ``block(inplanes, planes)`` and
            expected to expose an ``expansion`` attribute.
        num_blocks: Number of blocks in every residual stack.
        planes: ``planes`` argument handed to each block; every block receives
            ``planes * block.expansion`` as its input width.
        depth: Number of pooling/upsampling levels.
    """

    def __init__(self, block, num_blocks, planes, depth):
        super().__init__()
        self.depth = depth
        self.block = block
        self.hg = self._make_hour_glass(block, num_blocks, planes, depth)

    def _make_residual(self, block, num_blocks, planes):
        """Return ``num_blocks`` blocks chained in an ``nn.Sequential``."""
        return nn.Sequential(
            *[block(planes * block.expansion, planes) for _ in range(num_blocks)]
        )

    def _make_hour_glass(self, block, num_blocks, planes, depth):
        """Build the per-level residual stacks (four at level 0, three elsewhere)."""
        levels = nn.ModuleList()
        for level in range(depth):
            stack_count = 4 if level == 0 else 3
            levels.append(nn.ModuleList(
                [self._make_residual(block, num_blocks, planes)
                 for _ in range(stack_count)]
            ))
        return levels

    def _hour_glass_forward(self, n, x):
        """Evaluate level ``n`` (1-based) and, recursively, all levels below it."""
        stacks = self.hg[n - 1]
        skip = stacks[0](x)

        inner = stacks[1](F.max_pool2d(x, 2, stride=2))
        if n > 1:
            inner = self._hour_glass_forward(n - 1, inner)
        else:
            inner = stacks[3](inner)
        inner = stacks[2](inner)

        # NOTE(review): the addition presumably requires spatial sizes divisible
        # by 2 at every level so both operands match; not checked here.
        return skip + F.interpolate(inner, scale_factor=2)

    def forward(self, x):
        """Run the full hourglass, starting from the outermost level."""
        return self._hour_glass_forward(self.depth, x)
class HourglassNet(nn.Module):
    """Stacked hourglass network (Newell et al., ECCV 2016) with optional
    segmentation heads.

    A stem (stride-2 7x7 convolution, residual stage, 2x2 max-pool, two more
    residual stages) feeds ``num_stacks`` hourglass modules. Every stack emits
    a ``num_classes``-channel score map; all but the last feed their features
    and scores back into the input of the next stack.

    Args:
        block: Residual block class; must expose ``expansion`` and accept
            ``(inplanes, planes, stride, downsample)``.
        num_stacks: Number of hourglass stacks.
        num_blocks: Blocks per residual stack inside and after each hourglass.
        num_classes: Channels of each score map. With ``upsample_seg`` the
            last two channels are split off as segmentation scores.
        upsample_seg: Add a head that upsamples the final segmentation scores
            by 4 and refines them with features computed from the input image.
        add_partseg: Add a separate hourglass head, fed with the input of the
            last stack, that predicts ``num_partseg`` channels.
        num_partseg: Output channels of the part-segmentation head; used as
            ``out_channels`` of a convolution, so it must be set when
            ``add_partseg`` is true.
    """

    def __init__(self, block, num_stacks=2, num_blocks=4, num_classes=16, upsample_seg=False, add_partseg=False, num_partseg=None):
        super().__init__()
        self.inplanes = 64
        self.num_feats = 128
        self.num_stacks = num_stacks

        # Stem. ``_make_residual`` updates ``self.inplanes`` as a side effect,
        # so the order of the calls below determines every channel width.
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
                               bias=True)
        self.bn1 = nn.BatchNorm2d(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.layer1 = self._make_residual(block, self.inplanes, 1)
        self.layer2 = self._make_residual(block, self.inplanes, 1)
        self.layer3 = self._make_residual(block, self.num_feats, 1)
        self.maxpool = nn.MaxPool2d(2, stride=2)
        self.upsample_seg = upsample_seg
        self.add_partseg = add_partseg

        # Hourglass stacks.
        ch = self.num_feats * block.expansion
        hg, res, fc, score, fc_, score_ = [], [], [], [], [], []
        for i in range(num_stacks):
            hg.append(Hourglass(block, num_blocks, self.num_feats, 4))
            res.append(self._make_residual(block, self.num_feats, num_blocks))
            fc.append(self._make_fc(ch, ch))
            score.append(nn.Conv2d(ch, num_classes, kernel_size=1, bias=True))
            if i < num_stacks - 1:
                # Projections that merge features and scores back into the
                # input of the following stack.
                fc_.append(nn.Conv2d(ch, ch, kernel_size=1, bias=True))
                score_.append(nn.Conv2d(num_classes, ch, kernel_size=1, bias=True))
        self.hg = nn.ModuleList(hg)
        self.res = nn.ModuleList(res)
        self.fc = nn.ModuleList(fc)
        self.score = nn.ModuleList(score)
        self.fc_ = nn.ModuleList(fc_)
        self.score_ = nn.ModuleList(score_)

        if self.add_partseg:
            # Part-segmentation head: its own hourglass plus a 4x upsampler.
            self.hg_ps = Hourglass(block, num_blocks, self.num_feats, 4)
            self.res_ps = self._make_residual(block, self.num_feats, num_blocks)
            self.fc_ps = self._make_fc(ch, ch)
            self.score_ps = nn.Conv2d(ch, num_partseg, kernel_size=1, bias=True)
            self.ups_upsampling_ps = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True)

        if self.upsample_seg:
            # Segmentation refinement head, applied at 4x the score resolution.
            self.ups_upsampling = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True)
            self.ups_conv0 = nn.Conv2d(3, 32, kernel_size=7, stride=1, padding=3,
                                       bias=True)
            self.ups_bn1 = nn.BatchNorm2d(32)
            self.ups_conv1 = nn.Conv2d(32, 16, kernel_size=7, stride=1, padding=3,
                                       bias=True)
            # ups_bn2 is registered but not applied in the forward pass; it is
            # kept so the set of parameters (and checkpoint keys) is unchanged.
            self.ups_bn2 = nn.BatchNorm2d(16 + 2)
            self.ups_conv2 = nn.Conv2d(16 + 2, 16, kernel_size=5, stride=1, padding=2,
                                       bias=True)
            self.ups_bn3 = nn.BatchNorm2d(16)
            self.ups_conv3 = nn.Conv2d(16, 2, kernel_size=5, stride=1, padding=2,
                                       bias=True)

    def _make_residual(self, block, planes, blocks, stride=1):
        """Chain residual blocks and advance ``self.inplanes`` to their output width.

        The first block gets a 1x1 projection on its skip branch whenever the
        stride or the channel count changes. At least one block is always
        created; ``blocks`` only controls how many follow it.
        """
        out_planes = planes * block.expansion
        downsample = None
        if stride != 1 or self.inplanes != out_planes:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, out_planes,
                          kernel_size=1, stride=stride, bias=True),
            )
        layers = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = out_planes
        layers.extend(block(self.inplanes, planes) for _ in range(1, blocks))
        return nn.Sequential(*layers)

    def _make_fc(self, inplanes, outplanes):
        """Return a 1x1 conv -> BN -> ReLU stage (the ReLU is the shared ``self.relu``)."""
        bn = nn.BatchNorm2d(inplanes)
        conv = nn.Conv2d(inplanes, outplanes, kernel_size=1, bias=True)
        return nn.Sequential(
            conv,
            bn,
            self.relu,
        )

    def _refine_segmentation(self, x_in, seg_score):
        """Upsample ``seg_score`` by 4 and add a correction predicted from it
        together with features of the input image ``x_in``."""
        seg_upsampled = self.ups_upsampling(seg_score)

        # Features of the raw input image at full resolution.
        img_feat = self.ups_conv0(x_in)
        img_feat = self.ups_bn1(img_feat)
        img_feat = self.relu(img_feat)
        img_feat = self.ups_conv1(img_feat)

        # Concatenate along channels; ups_bn2 is deliberately not applied here.
        merged = torch.cat((seg_upsampled, img_feat), 1)
        merged = self.relu(merged)
        merged = self.ups_conv2(merged)
        merged = self.ups_bn3(merged)
        merged = self.relu(merged)
        correction = self.ups_conv3(merged)
        return seg_upsampled + correction

    def forward(self, x_in):
        """Run the network on an image batch.

        Returns:
            Without ``upsample_seg``: a list with one score map per stack
            (the part-segmentation output, if enabled, is not returned in
            this mode).

            With ``upsample_seg``: a dict with

            * ``'out_list_kp'``: per-stack scores without the last two channels,
            * ``'out_list_seg'``: per-stack scores of the last two channels,
            * ``'seg_final'``: refined, 4x upsampled segmentation scores,
            * ``'out_list_partseg'`` and ``'partseg_final'``: raw and 4x
              upsampled part-segmentation scores (only with ``add_partseg``).
        """
        out = []
        out_seg = []
        out_partseg = []

        x = self.conv1(x_in)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.layer1(x)
        x = self.maxpool(x)
        x = self.layer2(x)
        x = self.layer3(x)

        last_stack = self.num_stacks - 1
        for i in range(self.num_stacks):
            if i == last_stack and self.add_partseg:
                # The part-segmentation head shares the last stack's input.
                y_ps = self.hg_ps(x)
                y_ps = self.res_ps(y_ps)
                y_ps = self.fc_ps(y_ps)
                score_ps = self.score_ps(y_ps)
                out_partseg.append(score_ps)

            y = self.hg[i](x)
            y = self.res[i](y)
            y = self.fc[i](y)
            score = self.score[i](y)
            if self.upsample_seg:
                out.append(score[:, :-2, :, :])
                out_seg.append(score[:, -2:, :, :])
            else:
                out.append(score)

            if i < last_stack:
                # Intermediate supervision: feed features and scores forward.
                fc_ = self.fc_[i](y)
                score_ = self.score_[i](score)
                x = x + fc_ + score_

        if not self.upsample_seg:
            return out

        # Refine the segmentation channels of the last stack's score map.
        seg_final = self._refine_segmentation(x_in, score[:, -2:, :, :])
        out_dict = {
            'out_list_kp': out,
            # Return the collected segmentation scores here; the keypoint list
            # used to be returned under this key by mistake.
            'out_list_seg': out_seg,
            'seg_final': seg_final,
        }
        if self.add_partseg:
            out_dict['out_list_partseg'] = out_partseg
            out_dict['partseg_final'] = self.ups_upsampling_ps(score_ps)
        return out_dict
def hg(**kwargs):
    """Build a ``HourglassNet`` made of ``Bottleneck`` blocks.

    All of ``num_stacks``, ``num_blocks``, ``num_classes``, ``upsample_seg``,
    ``add_partseg`` and ``num_partseg`` must be given as keyword arguments;
    a missing one raises ``KeyError``. Any other keyword is ignored.
    """
    option_names = (
        'num_stacks',
        'num_blocks',
        'num_classes',
        'upsample_seg',
        'add_partseg',
        'num_partseg',
    )
    options = {name: kwargs[name] for name in option_names}
    return HourglassNet(Bottleneck, **options)
def _hg(arch, pretrained, progress, **kwargs):
    """Build an hourglass network and optionally load pretrained weights.

    Args:
        arch: Key into ``model_urls`` identifying the checkpoint to download.
        pretrained: If true, download the checkpoint for ``arch`` and load it
            into the model.
        progress: Forwarded to ``load_state_dict_from_url``.
        **kwargs: Forwarded to ``hg``.

    Raises:
        KeyError: If ``pretrained`` is true and ``model_urls`` has no entry
            for ``arch``.
    """
    # Check before constructing the network, so that a request that cannot be
    # satisfied fails immediately and says which architectures are available.
    if pretrained and arch not in model_urls:
        raise KeyError(
            "no pretrained weights available for '{}' (available: {})".format(
                arch, ', '.join(sorted(model_urls))
            )
        )
    model = hg(**kwargs)
    if pretrained:
        state_dict = load_state_dict_from_url(model_urls[arch], progress=progress)
        model.load_state_dict(state_dict)
    return model
def hg1(pretrained=False, progress=True, num_blocks=1, num_classes=16, upsample_seg=False, add_partseg=False, num_partseg=None):
    """Build a 1-stack hourglass network.

    ``pretrained`` and ``progress`` control loading of the 'hg1' checkpoint;
    the remaining arguments are forwarded to the network constructor.
    """
    options = dict(
        num_stacks=1,
        num_blocks=num_blocks,
        num_classes=num_classes,
        upsample_seg=upsample_seg,
        add_partseg=add_partseg,
        num_partseg=num_partseg,
    )
    return _hg('hg1', pretrained, progress, **options)
def hg2(pretrained=False, progress=True, num_blocks=1, num_classes=16, upsample_seg=False, add_partseg=False, num_partseg=None):
    """Build a 2-stack hourglass network.

    ``pretrained`` and ``progress`` control loading of the 'hg2' checkpoint;
    the remaining arguments are forwarded to the network constructor.
    """
    options = dict(
        num_stacks=2,
        num_blocks=num_blocks,
        num_classes=num_classes,
        upsample_seg=upsample_seg,
        add_partseg=add_partseg,
        num_partseg=num_partseg,
    )
    return _hg('hg2', pretrained, progress, **options)
def hg4(pretrained=False, progress=True, num_blocks=1, num_classes=16, upsample_seg=False, add_partseg=False, num_partseg=None):
    """Build a 4-stack hourglass network.

    ``model_urls`` has no 'hg4' entry, so ``pretrained=True`` raises
    ``KeyError``. The remaining arguments are forwarded to the network
    constructor.
    """
    options = dict(
        num_stacks=4,
        num_blocks=num_blocks,
        num_classes=num_classes,
        upsample_seg=upsample_seg,
        add_partseg=add_partseg,
        num_partseg=num_partseg,
    )
    return _hg('hg4', pretrained, progress, **options)
def hg8(pretrained=False, progress=True, num_blocks=1, num_classes=16, upsample_seg=False, add_partseg=False, num_partseg=None):
    """Build an 8-stack hourglass network.

    ``pretrained`` and ``progress`` control loading of the 'hg8' checkpoint;
    the remaining arguments are forwarded to the network constructor.
    """
    options = dict(
        num_stacks=8,
        num_blocks=num_blocks,
        num_classes=num_classes,
        upsample_seg=upsample_seg,
        add_partseg=add_partseg,
        num_partseg=num_partseg,
    )
    return _hg('hg8', pretrained, progress, **options)