from torch import nn

from spiga.models.cnn.layers import Conv, Residual
from spiga.models.cnn.hourglass import HourglassCore
from spiga.models.cnn.coord_conv import AddCoordsTh
from spiga.models.cnn.transform_e2p import E2Ptransform


class MultitaskCNN(nn.Module):

    def __init__(self, nstack=4, num_landmarks=98, num_edges=15, pose_req=True, **kwargs):
        super(MultitaskCNN, self).__init__()

        # Parameters
        self.img_res = 256                   # WxH input resolution
        self.ch_dim = 256                    # Default channel dimension
        self.out_res = 64                    # WxH output resolution
        self.nstack = nstack                 # Hourglass modules stacked
        self.num_landmarks = num_landmarks   # Number of landmarks
        self.num_edges = num_edges           # Number of edge subsets (eyeR, eyeL, nose, etc.)
        self.pose_required = pose_req        # Multitask flag

        # Image preprocessing
        self.pre = nn.Sequential(
            AddCoordsTh(x_dim=self.img_res, y_dim=self.img_res, with_r=True),
            Conv(6, 64, 7, 2, bn=True, relu=True),
            Residual(64, 128),
            Conv(128, 128, 2, 2, bn=True, relu=True),
            Residual(128, 128),
            Residual(128, self.ch_dim)
        )

        # Hourglass modules
        self.hgs = nn.ModuleList([HourglassCore(4, self.ch_dim) for i in range(self.nstack)])

        self.hgs_out = nn.ModuleList([
            nn.Sequential(
                Residual(self.ch_dim, self.ch_dim),
                Conv(self.ch_dim, self.ch_dim, 1, bn=True, relu=True)
            ) for i in range(nstack)])

        if self.pose_required:
            self.hgs_core = nn.ModuleList([
                nn.Sequential(
                    Residual(self.ch_dim, self.ch_dim),
                    Conv(self.ch_dim, self.ch_dim, 2, 2, bn=True, relu=True),
                    Residual(self.ch_dim, self.ch_dim),
                    Conv(self.ch_dim, self.ch_dim, 2, 2, bn=True, relu=True)
                ) for i in range(nstack)])

        # Attention module (ADnet style)
        self.outs_points = nn.ModuleList([nn.Sequential(Conv(self.ch_dim, self.num_landmarks, 1, relu=False, bn=False),
                                                        nn.Sigmoid()) for i in range(self.nstack - 1)])
        self.outs_edges = nn.ModuleList([nn.Sequential(Conv(self.ch_dim, self.num_edges, 1, relu=False, bn=False),
                                                       nn.Sigmoid()) for i in range(self.nstack - 1)])
        self.E2Ptransform = E2Ptransform(self.num_landmarks, self.num_edges, out_dim=self.out_res)
        self.outs_features = nn.ModuleList([Conv(self.ch_dim, self.num_landmarks, 1, relu=False, bn=False)
                                            for i in range(self.nstack - 1)])

        # Stacked Hourglass inputs (nstack > 1)
        self.merge_preds = nn.ModuleList([Conv(self.num_landmarks, self.ch_dim, 1, relu=False, bn=False)
                                          for i in range(self.nstack - 1)])
        self.merge_features = nn.ModuleList([Conv(self.ch_dim, self.ch_dim, 1, relu=False, bn=False)
                                             for i in range(self.nstack - 1)])

    def forward(self, imgs):
        x = self.pre(imgs)

        outputs = {'VisualField': [], 'HGcore': []}
        core_raw = []
        for i in range(self.nstack):
            # Hourglass
            hg, core_raw = self.hgs[i](x, core=core_raw)
            if self.pose_required:
                core = self.hgs_core[i](core_raw[-self.hgs[i].n])
                outputs['HGcore'].append(core)
            hg = self.hgs_out[i](hg)

            # Visual features
            outputs['VisualField'].append(hg)

            # Prepare next stacked input
            if i < self.nstack - 1:
                # Attentional modules
                points = self.outs_points[i](hg)
                edges = self.outs_edges[i](hg)
                edges_ext = self.E2Ptransform(edges)
                point_edges = points * edges_ext

                # Landmarks
                maps = self.outs_features[i](hg)
                preds = maps * point_edges

                # Outputs
                x = x + self.merge_preds[i](preds) + self.merge_features[i](hg)

        return outputs
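

# --- Minimal usage sketch (not part of the original module) ---------------
# Illustrative only: it assumes the SPIGA submodules imported above behave as
# the code expects (AddCoordsTh adds x/y/r channels, HourglassCore accepts a
# `core` keyword and exposes `.n`). The 256x256 input size matches img_res;
# each 'VisualField' entry is a ch_dim x 64 x 64 feature map per stack.
if __name__ == "__main__":
    import torch

    model = MultitaskCNN(nstack=4, num_landmarks=98, num_edges=15, pose_req=True)
    imgs = torch.randn(2, 3, 256, 256)   # batch of two 256x256 RGB face crops
    with torch.no_grad():
        out = model(imgs)

    # One visual feature map per stacked hourglass (expected: 2 x 256 x 64 x 64 each)
    print(len(out['VisualField']), out['VisualField'][0].shape)
    # One pose core per stack when pose_req=True (shape depends on HourglassCore)
    print(len(out['HGcore']), out['HGcore'][0].shape)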