import torch
from torch import nn

from modules.conv import conv, conv_dw, conv_dw_no_bn


def _conv_1x1(in_channels, out_channels, **kwargs):
    """Build a ``conv`` layer with ``kernel_size=1, padding=0, bn=False``.

    Extra keyword arguments (e.g. ``relu=False``) are forwarded to ``conv``
    unchanged.
    """
    return conv(in_channels, out_channels, kernel_size=1, padding=0, bn=False, **kwargs)


def _prediction_head(in_channels, mid_channels, out_channels):
    """Build a two-layer 1x1 head; the last layer is created with ``relu=False``."""
    return nn.Sequential(
        _conv_1x1(in_channels, mid_channels),
        _conv_1x1(mid_channels, out_channels, relu=False),
    )


class Cpm(nn.Module):
    """Channel-aligning block applied to the backbone output.

    A 1x1 ``align`` conv maps ``in_channels`` to ``out_channels``; the result
    is summed with the output of a three-layer ``conv_dw_no_bn`` trunk and
    passed through a final ``conv``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.align = _conv_1x1(in_channels, out_channels)
        self.trunk = nn.Sequential(
            *(conv_dw_no_bn(out_channels, out_channels) for _ in range(3))
        )
        self.conv = conv(out_channels, out_channels, bn=False)

    def forward(self, x):
        """Return ``conv(aligned + trunk(aligned))`` where ``aligned = align(x)``."""
        aligned = self.align(x)
        residual_sum = aligned + self.trunk(aligned)
        return self.conv(residual_sum)


class InitialStage(nn.Module):
    """First prediction stage: a shared trunk followed by two 1x1 heads.

    The heads produce ``num_heatmaps`` and ``num_pafs`` output channels
    respectively, each through a 512-channel intermediate layer.
    """

    def __init__(self, num_channels, num_heatmaps, num_pafs):
        super().__init__()
        self.trunk = nn.Sequential(
            *(conv(num_channels, num_channels, bn=False) for _ in range(3))
        )
        self.heatmaps = _prediction_head(num_channels, 512, num_heatmaps)
        self.pafs = _prediction_head(num_channels, 512, num_pafs)

    def forward(self, x):
        """Return ``[heatmaps, pafs]`` (a list) computed from shared trunk features."""
        shared = self.trunk(x)
        return [self.heatmaps(shared), self.pafs(shared)]


class RefinementStageBlock(nn.Module):
    """Residual block: a 1x1 ``initial`` conv plus a two-conv trunk.

    The second trunk conv is created with ``dilation=2, padding=2``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.initial = _conv_1x1(in_channels, out_channels)
        self.trunk = nn.Sequential(
            conv(out_channels, out_channels),
            conv(out_channels, out_channels, dilation=2, padding=2),
        )

    def forward(self, x):
        """Return ``initial(x) + trunk(initial(x))``."""
        projected = self.initial(x)
        return projected + self.trunk(projected)


class RefinementStage(nn.Module):
    """Refinement stage: five ``RefinementStageBlock``s and two 1x1 heads.

    Only the first block changes the channel count (``in_channels`` to
    ``out_channels``); the remaining four keep ``out_channels``.
    """

    def __init__(self, in_channels, out_channels, num_heatmaps, num_pafs):
        super().__init__()
        block_inputs = [in_channels] + [out_channels] * 4
        self.trunk = nn.Sequential(
            *(RefinementStageBlock(channels, out_channels) for channels in block_inputs)
        )
        self.heatmaps = _prediction_head(out_channels, out_channels, num_heatmaps)
        self.pafs = _prediction_head(out_channels, out_channels, num_pafs)

    def forward(self, x):
        """Return ``[heatmaps, pafs]`` (a list) computed from shared trunk features."""
        shared = self.trunk(x)
        return [self.heatmaps(shared), self.pafs(shared)]


class PoseEstimationWithMobileNet(nn.Module):
    """Full network: backbone, ``Cpm``, an initial stage and refinement stages.

    Each refinement stage takes ``num_channels + num_heatmaps + num_pafs``
    input channels, matching the concatenation performed in ``forward``.
    """

    def __init__(self, num_refinement_stages=1, num_channels=128, num_heatmaps=19, num_pafs=38):
        super().__init__()
        # (in_channels, out_channels, extra kwargs) for each conv_dw layer,
        # listed in the order they appear in the backbone.
        depthwise_specs = [
            (32, 64, {}),
            (64, 128, {'stride': 2}),
            (128, 128, {}),
            (128, 256, {'stride': 2}),
            (256, 256, {}),
            (256, 512, {}),  # conv4_2
            (512, 512, {'dilation': 2, 'padding': 2}),
            (512, 512, {}),
            (512, 512, {}),
            (512, 512, {}),
            (512, 512, {}),  # conv5_5
        ]
        self.model = nn.Sequential(
            conv(3, 32, stride=2, bias=False),
            *(conv_dw(n_in, n_out, **extra) for n_in, n_out, extra in depthwise_specs)
        )
        self.cpm = Cpm(512, num_channels)
        self.initial_stage = InitialStage(num_channels, num_heatmaps, num_pafs)

        refinement_in_channels = num_channels + num_heatmaps + num_pafs
        self.refinement_stages = nn.ModuleList([
            RefinementStage(refinement_in_channels, num_channels, num_heatmaps, num_pafs)
            for _ in range(num_refinement_stages)
        ])

    def forward(self, x):
        """Run all stages and return their outputs as one flat list.

        The list holds ``[heatmaps, pafs]`` of the initial stage followed by
        ``[heatmaps, pafs]`` of every refinement stage, in order.
        """
        features = self.cpm(self.model(x))
        outputs = self.initial_stage(features)
        for stage in self.refinement_stages:
            # Each refinement stage sees the features together with the most
            # recent heatmaps and pafs, concatenated along dim 1.
            latest_heatmaps, latest_pafs = outputs[-2], outputs[-1]
            stage_input = torch.cat([features, latest_heatmaps, latest_pafs], dim=1)
            outputs.extend(stage(stage_input))
        return outputs