# Modified from:
#   https://github.com/anibali/pytorch-stacked-hourglass
#   https://github.com/bearpaw/pytorch-pose
# Hourglass network inserted in the pre-activated ResNet.
# Use lr=0.01 for the current version.
# (c) YANG, Wei
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from torch.hub import load_state_dict_from_url | |
# Names exported by a star-import of this module; the per-depth factory
# functions are public entry points too, so they are listed as well.
__all__ = ['HourglassNet', 'hg', 'hg1', 'hg2', 'hg4', 'hg8']

# Download locations of pretrained checkpoints, keyed by architecture name.
# NOTE: there is no 'hg4' entry, so no pretrained weights can be fetched for
# the 4-stack variant.
model_urls = {
    'hg1': 'https://github.com/anibali/pytorch-stacked-hourglass/releases/download/v0.0.0/bearpaw_hg1-ce125879.pth',
    'hg2': 'https://github.com/anibali/pytorch-stacked-hourglass/releases/download/v0.0.0/bearpaw_hg2-15e342d9.pth',
    'hg8': 'https://github.com/anibali/pytorch-stacked-hourglass/releases/download/v0.0.0/bearpaw_hg8-90e5d470.pth',
}
class Bottleneck(nn.Module):
    """Pre-activation bottleneck residual block.

    Three BatchNorm -> ReLU -> Conv stages (1x1, 3x3, 1x1) are applied to the
    input and the result is added to a shortcut branch. The block outputs
    ``planes * expansion`` channels.

    Args:
        inplanes: Number of channels of the input tensor.
        planes: Number of channels produced by the first two convolutions.
        stride: Stride of the 3x3 convolution.
        downsample: Optional module applied to the input to build the
            shortcut. When ``None`` the input itself is the shortcut, so it
            must already have the shape of the block output.
    """

    # Ratio between the number of output channels and ``planes``.
    expansion = 2

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()

        self.bn1 = nn.BatchNorm2d(inplanes)
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=True)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=True)
        self.bn3 = nn.BatchNorm2d(planes)
        # Use the class attribute rather than a literal 2 so that the output
        # width always agrees with what callers compute from ``expansion``.
        self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1,
                               bias=True)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return ``shortcut(x) + conv3(relu(bn3(conv2(relu(bn2(conv1(relu(bn1(x)))))))))``."""
        residual = x

        out = x
        stages = (
            (self.bn1, self.conv1),
            (self.bn2, self.conv2),
            (self.bn3, self.conv3),
        )
        for norm, conv in stages:
            # The in-place ReLU only touches the fresh tensor produced by the
            # normalisation layer, never the block input ``x``.
            out = conv(self.relu(norm(out)))

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        return out
class Hourglass(nn.Module):
    """Single recursive hourglass module.

    At every level the input is processed by a skip branch at full resolution
    and by a second branch that is max-pooled by 2, processed (recursively for
    all but the innermost level), upsampled again and added to the skip branch.

    Args:
        block: Residual block class. It is instantiated as
            ``block(planes * block.expansion, planes)`` and must therefore
            expose an ``expansion`` class attribute.
        num_blocks: Number of blocks chained in every residual branch.
        planes: ``planes`` argument passed to every block.
        depth: Number of nested down/up-sampling levels.
    """

    def __init__(self, block, num_blocks, planes, depth):
        super(Hourglass, self).__init__()
        self.depth = depth
        self.block = block
        self.hg = self._make_hour_glass(block, num_blocks, planes, depth)

    def _make_residual(self, block, num_blocks, planes):
        """Chain ``num_blocks`` blocks into one sequential branch."""
        return nn.Sequential(
            *[block(planes * block.expansion, planes) for _ in range(num_blocks)]
        )

    def _make_hour_glass(self, block, num_blocks, planes, depth):
        """Build the per-level branches.

        Every level owns three branches (index 0: skip, 1: after pooling,
        2: before upsampling). Level 0, the innermost one, owns a fourth
        branch (index 3) that replaces the recursive call.
        """
        levels = []
        for level in range(depth):
            branch_count = 4 if level == 0 else 3
            levels.append(nn.ModuleList(
                [self._make_residual(block, num_blocks, planes)
                 for _ in range(branch_count)]
            ))
        return nn.ModuleList(levels)

    def _hour_glass_forward(self, n, x):
        """Evaluate level ``n`` (1-based, ``n == 1`` is the innermost) on ``x``."""
        branches = self.hg[n - 1]

        up1 = branches[0](x)
        low1 = F.max_pool2d(x, 2, stride=2)
        low1 = branches[1](low1)

        if n > 1:
            low2 = self._hour_glass_forward(n - 1, low1)
        else:
            low2 = branches[3](low1)
        low3 = branches[2](low2)

        # Upsample to the exact size of the skip branch instead of using a
        # fixed factor of 2: max-pooling floors odd sizes, so doubling would
        # not match ``up1`` for odd inputs. For even sizes the result is the
        # same as ``scale_factor=2``.
        up2 = F.interpolate(low3, size=up1.shape[2:])
        return up1 + up2

    def forward(self, x):
        """Run the full hourglass, starting at the outermost level."""
        return self._hour_glass_forward(self.depth, x)
class HourglassNet(nn.Module):
    """Stacked hourglass model from Newell et al., ECCV 2016.

    A stem (7x7 stride-2 convolution, residual stages and one stride-2
    max-pool) downsamples the input by a factor of 4. ``num_stacks`` hourglass
    stages follow; each emits a ``num_classes``-channel score map and, except
    for the last one, adds its features and re-projected scores to the input
    of the next stage.

    Args:
        block: Residual block class with an ``expansion`` class attribute,
            instantiated as ``block(inplanes, planes, stride, downsample)``.
        num_stacks: Number of hourglass stages.
        num_blocks: Number of blocks per residual branch.
        num_classes: Number of channels of every score map.
        upsample_seg: If True, the last two score channels are treated as a
            segmentation prediction that is upsampled by a factor of 4 and
            corrected with features computed from the input image;
            ``forward`` then returns a dict instead of a list.
        add_partseg: If True, an extra hourglass branch fed with the input of
            the last stage predicts ``num_partseg`` channels. Its outputs are
            only returned when ``upsample_seg`` is True as well.
        num_partseg: Number of channels of the extra branch; required when
            ``add_partseg`` is True.

    Raises:
        ValueError: If ``add_partseg`` is True but ``num_partseg`` is None.
    """

    def __init__(self, block, num_stacks=2, num_blocks=4, num_classes=16, upsample_seg=False, add_partseg=False, num_partseg=None):
        super(HourglassNet, self).__init__()

        if add_partseg and num_partseg is None:
            # Fail with a clear message instead of an obscure error from Conv2d.
            raise ValueError('num_partseg must be given when add_partseg is True')

        self.inplanes = 64
        self.num_feats = 128
        self.num_stacks = num_stacks
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
                               bias=True)
        self.bn1 = nn.BatchNorm2d(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        # _make_residual updates self.inplanes, so the order of these calls matters.
        self.layer1 = self._make_residual(block, self.inplanes, 1)
        self.layer2 = self._make_residual(block, self.inplanes, 1)
        self.layer3 = self._make_residual(block, self.num_feats, 1)
        self.maxpool = nn.MaxPool2d(2, stride=2)
        self.upsample_seg = upsample_seg
        self.add_partseg = add_partseg

        # build hourglass modules
        ch = self.num_feats * block.expansion
        hg, res, fc, score, fc_, score_ = [], [], [], [], [], []
        for i in range(num_stacks):
            hg.append(Hourglass(block, num_blocks, self.num_feats, 4))
            res.append(self._make_residual(block, self.num_feats, num_blocks))
            fc.append(self._make_fc(ch, ch))
            score.append(nn.Conv2d(ch, num_classes, kernel_size=1, bias=True))
            if i < num_stacks - 1:
                # Projections that feed features and scores into the next stage.
                fc_.append(nn.Conv2d(ch, ch, kernel_size=1, bias=True))
                score_.append(nn.Conv2d(num_classes, ch, kernel_size=1, bias=True))
        self.hg = nn.ModuleList(hg)
        self.res = nn.ModuleList(res)
        self.fc = nn.ModuleList(fc)
        self.score = nn.ModuleList(score)
        self.fc_ = nn.ModuleList(fc_)
        self.score_ = nn.ModuleList(score_)

        if self.add_partseg:
            self.hg_ps = Hourglass(block, num_blocks, self.num_feats, 4)
            self.res_ps = self._make_residual(block, self.num_feats, num_blocks)
            self.fc_ps = self._make_fc(ch, ch)
            self.score_ps = nn.Conv2d(ch, num_partseg, kernel_size=1, bias=True)
            self.ups_upsampling_ps = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True)

        if self.upsample_seg:
            self.ups_upsampling = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True)
            self.ups_conv0 = nn.Conv2d(3, 32, kernel_size=7, stride=1, padding=3,
                                       bias=True)
            self.ups_bn1 = nn.BatchNorm2d(32)
            self.ups_conv1 = nn.Conv2d(32, 16, kernel_size=7, stride=1, padding=3,
                                       bias=True)
            # ups_bn2 is not applied in the forward pass; it stays registered
            # so that the set of parameters of the module is unchanged.
            self.ups_bn2 = nn.BatchNorm2d(16 + 2)
            self.ups_conv2 = nn.Conv2d(16 + 2, 16, kernel_size=5, stride=1, padding=2,
                                       bias=True)
            self.ups_bn3 = nn.BatchNorm2d(16)
            self.ups_conv3 = nn.Conv2d(16, 2, kernel_size=5, stride=1, padding=2,
                                       bias=True)

    def _make_residual(self, block, planes, blocks, stride=1):
        """Chain ``blocks`` residual blocks and update ``self.inplanes``.

        The first block receives a 1x1 convolution as shortcut whenever the
        stride or the channel count changes.
        """
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=True),
            )

        layers = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def _make_fc(self, inplanes, outplanes):
        """Return a 1x1 convolution followed by BatchNorm and the shared ReLU."""
        bn = nn.BatchNorm2d(inplanes)
        conv = nn.Conv2d(inplanes, outplanes, kernel_size=1, bias=True)
        return nn.Sequential(
            conv,
            bn,
            self.relu,
        )

    def _refine_segmentation(self, x_in, seg_score):
        """Upsample ``seg_score`` by 4 and add a correction computed from ``x_in``."""
        seg_upsampled = self.ups_upsampling(seg_score)

        # Features of the input image at its own resolution.
        ups_img = self.ups_conv0(x_in)
        ups_img = self.ups_bn1(ups_img)
        ups_img = self.relu(ups_img)
        ups_img = self.ups_conv1(ups_img)

        # torch.cat returns a new tensor, so the in-place ReLU below does not
        # alter seg_upsampled. ups_bn2 is intentionally not applied here.
        ups_conc = torch.cat((seg_upsampled, ups_img), 1)
        ups_conc = self.relu(ups_conc)
        ups_conc = self.ups_conv2(ups_conc)
        ups_conc = self.ups_bn3(ups_conc)
        ups_conc = self.relu(ups_conc)
        correction = self.ups_conv3(ups_conc)

        return seg_upsampled + correction

    def forward(self, x_in):
        """Run the network on an image batch.

        Returns:
            If ``upsample_seg`` is False: the list of score maps, one per
            stack. Otherwise a dict with the keys ``'out_list_kp'`` (score
            maps without their last two channels), ``'out_list_seg'`` (the
            last two channels of every score map), ``'seg_final'`` (refined
            upsampled segmentation) and, if ``add_partseg`` is True,
            ``'out_list_partseg'`` and ``'partseg_final'``.
        """
        out = []
        out_seg = []
        out_partseg = []
        score_ps = None

        x = self.conv1(x_in)
        x = self.bn1(x)
        x = self.relu(x)

        x = self.layer1(x)
        x = self.maxpool(x)
        x = self.layer2(x)
        x = self.layer3(x)

        last = self.num_stacks - 1
        for i in range(self.num_stacks):
            if i == last and self.add_partseg:
                # The extra branch sees the same input as the last stack.
                y_ps = self.hg_ps(x)
                y_ps = self.res_ps(y_ps)
                y_ps = self.fc_ps(y_ps)
                score_ps = self.score_ps(y_ps)
                out_partseg.append(score_ps)

            y = self.hg[i](x)
            y = self.res[i](y)
            y = self.fc[i](y)
            score = self.score[i](y)

            if self.upsample_seg:
                out.append(score[:, :-2, :, :])
                out_seg.append(score[:, -2:, :, :])
            else:
                out.append(score)

            if i < last:
                fc_ = self.fc_[i](y)
                score_ = self.score_[i](score)
                x = x + fc_ + score_

        if not self.upsample_seg:
            return out

        # Refine the segmentation predicted by the last stack.
        seg_final = self._refine_segmentation(x_in, score[:, -2:, :, :])

        out_dict = {'out_list_kp': out,
                    # Previously this key returned the keypoint list again and
                    # the collected segmentation maps were discarded.
                    'out_list_seg': out_seg,
                    'seg_final': seg_final}
        if self.add_partseg:
            out_dict['out_list_partseg'] = out_partseg
            out_dict['partseg_final'] = self.ups_upsampling_ps(score_ps)
        return out_dict
def hg(**kwargs):
    """Build a ``HourglassNet`` made of ``Bottleneck`` blocks.

    All of ``num_stacks``, ``num_blocks``, ``num_classes``, ``upsample_seg``,
    ``add_partseg`` and ``num_partseg`` must be present in ``kwargs`` (a
    missing one raises ``KeyError``); any other keyword is ignored.
    """
    required = ('num_stacks', 'num_blocks', 'num_classes',
                'upsample_seg', 'add_partseg', 'num_partseg')
    net_options = {key: kwargs[key] for key in required}
    return HourglassNet(Bottleneck, **net_options)
def _hg(arch, pretrained, progress, **kwargs):
    """Build an hourglass network and optionally load pretrained weights.

    Args:
        arch: Key into ``model_urls`` naming the checkpoint to download.
        pretrained: If true, download the checkpoint and load it into the
            model. Loading uses ``load_state_dict`` with its default
            settings, so the checkpoint has to match the built model.
        progress: Forwarded to ``load_state_dict_from_url``.
        **kwargs: Forwarded to ``hg``.

    Returns:
        The constructed model.

    Raises:
        KeyError: If ``pretrained`` is true and ``arch`` has no entry in
            ``model_urls``.
    """
    # Check first, so that an unknown architecture fails before the network
    # is built, and with a message that lists the valid names.
    if pretrained and arch not in model_urls:
        raise KeyError(
            'no pretrained weights registered for {!r}; available: {}'.format(
                arch, sorted(model_urls)))

    model = hg(**kwargs)
    if pretrained:
        state_dict = load_state_dict_from_url(model_urls[arch], progress=progress)
        model.load_state_dict(state_dict)
    return model
def hg1(pretrained=False, progress=True, num_blocks=1, num_classes=16, upsample_seg=False, add_partseg=False, num_partseg=None):
    """Build a 1-stack hourglass network, optionally with the 'hg1' weights."""
    options = dict(
        num_stacks=1,
        num_blocks=num_blocks,
        num_classes=num_classes,
        upsample_seg=upsample_seg,
        add_partseg=add_partseg,
        num_partseg=num_partseg,
    )
    return _hg('hg1', pretrained, progress, **options)
def hg2(pretrained=False, progress=True, num_blocks=1, num_classes=16, upsample_seg=False, add_partseg=False, num_partseg=None):
    """Build a 2-stack hourglass network, optionally with the 'hg2' weights."""
    options = dict(
        num_stacks=2,
        num_blocks=num_blocks,
        num_classes=num_classes,
        upsample_seg=upsample_seg,
        add_partseg=add_partseg,
        num_partseg=num_partseg,
    )
    return _hg('hg2', pretrained, progress, **options)
def hg4(pretrained=False, progress=True, num_blocks=1, num_classes=16, upsample_seg=False, add_partseg=False, num_partseg=None):
    """Build a 4-stack hourglass network.

    NOTE: ``model_urls`` in this file has no 'hg4' entry, so calling this with
    ``pretrained=True`` raises ``KeyError``.
    """
    options = dict(
        num_stacks=4,
        num_blocks=num_blocks,
        num_classes=num_classes,
        upsample_seg=upsample_seg,
        add_partseg=add_partseg,
        num_partseg=num_partseg,
    )
    return _hg('hg4', pretrained, progress, **options)
def hg8(pretrained=False, progress=True, num_blocks=1, num_classes=16, upsample_seg=False, add_partseg=False, num_partseg=None):
    """Build an 8-stack hourglass network, optionally with the 'hg8' weights."""
    options = dict(
        num_stacks=8,
        num_blocks=num_blocks,
        num_classes=num_classes,
        upsample_seg=upsample_seg,
        add_partseg=add_partseg,
        num_partseg=num_partseg,
    )
    return _hg('hg8', pretrained, progress, **options)