# zejunyang
# update
# 9667e74
import torch
from torch import nn
from modules.conv import conv, conv_dw, conv_dw_no_bn
class Cpm(nn.Module):
    """Project input features to ``out_channels`` and refine them with a residual trunk.

    Layers, as built in ``__init__``:
      * ``align``: ``conv(in_channels, out_channels, kernel_size=1, padding=0, bn=False)``;
      * ``trunk``: three ``conv_dw_no_bn(out_channels, out_channels)`` layers in sequence;
      * ``conv``: ``conv(out_channels, out_channels, bn=False)`` applied to the
        residual sum ``align(x) + trunk(align(x))``.

    NOTE(review): the residual add assumes ``conv_dw_no_bn`` preserves the shape
    of its input (channels and spatial size) — confirm in ``modules.conv``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Attribute names and creation order are kept unchanged so that
        # state_dict keys (align.*, trunk.0-2.*, conv.*) stay stable.
        self.align = conv(in_channels, out_channels, kernel_size=1, padding=0, bn=False)
        self.trunk = nn.Sequential(
            *[conv_dw_no_bn(out_channels, out_channels) for _ in range(3)]
        )
        self.conv = conv(out_channels, out_channels, bn=False)

    def forward(self, x):
        """Return ``conv(a + trunk(a))`` where ``a = align(x)``."""
        aligned = self.align(x)
        residual_sum = aligned + self.trunk(aligned)
        return self.conv(residual_sum)
class InitialStage(nn.Module):
    """First prediction stage: a shared conv trunk feeding two 1x1-conv heads.

    ``trunk`` is three ``conv(num_channels, num_channels, bn=False)`` layers.
    Each head is ``conv(num_channels -> 512, 1x1)`` followed by
    ``conv(512 -> k, 1x1, relu=False)``. Here ``k`` is ``num_heatmaps`` for
    ``heatmaps`` and ``num_pafs`` for ``pafs``; "pafs" are presumably
    part-affinity fields, going by the name.

    ``forward`` returns a two-element list ``[heatmaps, pafs]``. It is a list
    rather than a tuple because the caller may ``extend`` it in place.
    """

    def __init__(self, num_channels, num_heatmaps, num_pafs):
        super().__init__()
        self.trunk = nn.Sequential(
            *[conv(num_channels, num_channels, bn=False) for _ in range(3)]
        )
        # Heads are built heatmaps-first, then pafs, matching the original order.
        self.heatmaps = self._make_head(num_channels, num_heatmaps)
        self.pafs = self._make_head(num_channels, num_pafs)

    @staticmethod
    def _make_head(in_channels, out_channels):
        """Build one head: 1x1 conv to 512 channels, then 1x1 conv to ``out_channels`` with ``relu=False``."""
        return nn.Sequential(
            conv(in_channels, 512, kernel_size=1, padding=0, bn=False),
            conv(512, out_channels, kernel_size=1, padding=0, bn=False, relu=False),
        )

    def forward(self, x):
        """Run the trunk once and return ``[heatmaps(features), pafs(features)]``."""
        features = self.trunk(x)
        return [self.heatmaps(features), self.pafs(features)]
class RefinementStageBlock(nn.Module):
    """Residual block: a 1x1 projection plus a two-conv trunk added back onto it.

    ``initial`` is ``conv(in_channels, out_channels, kernel_size=1, padding=0, bn=False)``.
    ``trunk`` is ``conv(out_channels, out_channels)`` followed by
    ``conv(out_channels, out_channels, dilation=2, padding=2)``. Both use
    ``conv``'s default ``bn`` setting, unlike ``initial``.

    NOTE(review): the residual add assumes both trunk convs preserve the
    spatial size (padding matched to dilation) — confirm in ``modules.conv``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.initial = conv(in_channels, out_channels, kernel_size=1, padding=0, bn=False)
        trunk_layers = [
            conv(out_channels, out_channels),
            conv(out_channels, out_channels, dilation=2, padding=2),
        ]
        self.trunk = nn.Sequential(*trunk_layers)

    def forward(self, x):
        """Return ``p + trunk(p)`` with ``p = initial(x)``."""
        projected = self.initial(x)
        return projected + self.trunk(projected)
class RefinementStage(nn.Module):
    """Refinement stage: five ``RefinementStageBlock``s feeding two 1x1-conv heads.

    The first block maps ``in_channels -> out_channels``; the remaining four
    keep ``out_channels``. Each head is ``conv(out_channels -> out_channels, 1x1)``
    followed by ``conv(out_channels -> k, 1x1, relu=False)``. Here ``k`` is
    ``num_heatmaps`` for ``heatmaps`` and ``num_pafs`` for ``pafs``.

    ``forward`` returns a two-element list ``[heatmaps, pafs]``.
    """

    def __init__(self, in_channels, out_channels, num_heatmaps, num_pafs):
        super().__init__()
        # Blocks are created in the same order as before, so state_dict keys
        # trunk.0 ... trunk.4 are unchanged.
        blocks = [RefinementStageBlock(in_channels, out_channels)]
        blocks.extend(RefinementStageBlock(out_channels, out_channels) for _ in range(4))
        self.trunk = nn.Sequential(*blocks)
        self.heatmaps = self._make_head(out_channels, num_heatmaps)
        self.pafs = self._make_head(out_channels, num_pafs)

    @staticmethod
    def _make_head(channels, num_outputs):
        """Build one head: 1x1 conv keeping ``channels``, then 1x1 conv to ``num_outputs`` with ``relu=False``."""
        return nn.Sequential(
            conv(channels, channels, kernel_size=1, padding=0, bn=False),
            conv(channels, num_outputs, kernel_size=1, padding=0, bn=False, relu=False),
        )

    def forward(self, x):
        """Run the trunk once and return ``[heatmaps(features), pafs(features)]``."""
        features = self.trunk(x)
        return [self.heatmaps(features), self.pafs(features)]
class PoseEstimationWithMobileNet(nn.Module):
    """Backbone + ``Cpm`` + an initial stage + ``num_refinement_stages`` refinement stages.

    Backbone (``self.model``):
      * a ``conv(3, 32, stride=2, bias=False)`` stem;
      * eleven ``conv_dw`` layers; the name suggests depthwise-separable
        convs, so confirm in ``modules.conv``.

    ``Cpm(512, num_channels)`` then reduces the 512 backbone channels.

    Each refinement stage takes ``num_channels + num_heatmaps + num_pafs`` input
    channels: the Cpm features concatenated (``dim=1``) with the previous stage's
    heatmaps and pafs. Presumably the input is NCHW, so ``dim=1`` is the
    channel axis — TODO confirm.

    Args:
        num_refinement_stages: number of ``RefinementStage`` modules (default 1).
        num_channels: channel width after ``Cpm`` (default 128).
        num_heatmaps: output channels of every heatmap head (default 19).
        num_pafs: output channels of every paf head (default 38).
    """

    def __init__(self, num_refinement_stages=1, num_channels=128, num_heatmaps=19, num_pafs=38):
        super().__init__()
        # (in_channels, out_channels, extra keyword args) for each conv_dw layer,
        # listed in backbone order.
        dw_layer_specs = (
            (32, 64, {}),
            (64, 128, {'stride': 2}),
            (128, 128, {}),
            (128, 256, {'stride': 2}),
            (256, 256, {}),
            (256, 512, {}),                             # conv4_2
            (512, 512, {'dilation': 2, 'padding': 2}),
            (512, 512, {}),
            (512, 512, {}),
            (512, 512, {}),
            (512, 512, {}),                             # conv5_5
        )
        stem = conv(3, 32, stride=2, bias=False)
        # Layers are created stem-first, then in spec order, so Sequential
        # indices (model.0 ... model.11) match the original layout.
        self.model = nn.Sequential(
            stem,
            *(conv_dw(c_in, c_out, **extra) for c_in, c_out, extra in dw_layer_specs),
        )
        self.cpm = Cpm(512, num_channels)
        self.initial_stage = InitialStage(num_channels, num_heatmaps, num_pafs)
        stage_in_channels = num_channels + num_heatmaps + num_pafs
        self.refinement_stages = nn.ModuleList([
            RefinementStage(stage_in_channels, num_channels, num_heatmaps, num_pafs)
            for _ in range(num_refinement_stages)
        ])

    def forward(self, x):
        """Return a flat list ``[heatmaps_0, pafs_0, heatmaps_1, pafs_1, ...]``.

        Entry pair 0 comes from the initial stage. Each later pair comes from one
        refinement stage, fed with the Cpm features plus the most recent pair.
        The list has ``2 * (1 + num_refinement_stages)`` entries.
        """
        features = self.cpm(self.model(x))
        outputs = self.initial_stage(features)
        for stage in self.refinement_stages:
            prev_heatmaps, prev_pafs = outputs[-2], outputs[-1]
            stage_input = torch.cat([features, prev_heatmaps, prev_pafs], dim=1)
            outputs.extend(stage(stage_input))
        return outputs