File size: 4,196 Bytes
d015578
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
from torch import nn
from spiga.models.cnn.layers import Conv, Residual
from spiga.models.cnn.hourglass import HourglassCore
from spiga.models.cnn.coord_conv import AddCoordsTh
from spiga.models.cnn.transform_e2p import E2Ptransform


class MultitaskCNN(nn.Module):
    """Stack of hourglass modules that emits per-stack feature maps and,
    when requested, an additional per-stack "core" tensor.

    Args:
        nstack: Number of ``HourglassCore`` modules chained one after another.
        num_landmarks: Channel count of the landmark-related heads.
        num_edges: Channel count of the edge head (fed to ``E2Ptransform``).
        pose_req: When true, an extra branch is built per stack and its result
            is collected under ``'HGcore'`` in the forward output.
        **kwargs: Accepted for call-site compatibility; not used here.

    ``forward`` returns a dict with two lists, ``'VisualField'`` and
    ``'HGcore'``; the latter stays empty when ``pose_req`` is false.
    """

    def __init__(self, nstack=4, num_landmarks=98, num_edges=15, pose_req=True, **kwargs):
        super().__init__()

        # ---- Configuration -------------------------------------------------
        self.img_res = 256                  # WxH input resolution
        self.ch_dim = 256                   # Default channel dimension
        self.out_res = 64                   # WxH output resolution
        self.nstack = nstack                # Hourglass modules stacked
        self.num_landmarks = num_landmarks  # Number of landmarks
        self.num_edges = num_edges          # Number of edges subsets (eyeR, eyeL, nose, etc)
        self.pose_required = pose_req       # Multitask flag

        ch = self.ch_dim
        # Modules that bridge stack k to stack k + 1 exist only nstack - 1 times.
        n_links = self.nstack - 1

        def plain_conv(in_ch, out_ch):
            # Conv with both batch norm and ReLU switched off.
            return Conv(in_ch, out_ch, 1, relu=False, bn=False)

        def sigmoid_head(out_ch):
            # Plain Conv followed by a sigmoid.
            return nn.Sequential(plain_conv(ch, out_ch), nn.Sigmoid())

        def stack_head():
            # Post-hourglass refinement applied to every stack output.
            return nn.Sequential(
                Residual(ch, ch),
                Conv(ch, ch, 1, bn=True, relu=True),
            )

        def core_branch():
            # Branch applied to the hourglass core tensor when pose is required.
            return nn.Sequential(
                Residual(ch, ch),
                Conv(ch, ch, 2, 2, bn=True, relu=True),
                Residual(ch, ch),
                Conv(ch, ch, 2, 2, bn=True, relu=True),
            )

        # ---- Image preprocessing -------------------------------------------
        stem_layers = [
            AddCoordsTh(x_dim=self.img_res, y_dim=self.img_res, with_r=True),
            # NOTE(review): 6 presumably = image channels + coordinate channels
            # appended by AddCoordsTh - confirm against that layer.
            Conv(6, 64, 7, 2, bn=True, relu=True),
            Residual(64, 128),
            Conv(128, 128, 2, 2, bn=True, relu=True),
            Residual(128, 128),
            Residual(128, ch),
        ]
        self.pre = nn.Sequential(*stem_layers)

        # ---- Hourglass modules ---------------------------------------------
        self.hgs = nn.ModuleList([HourglassCore(4, ch) for _ in range(self.nstack)])
        self.hgs_out = nn.ModuleList([stack_head() for _ in range(self.nstack)])
        if self.pose_required:
            self.hgs_core = nn.ModuleList([core_branch() for _ in range(self.nstack)])

        # ---- Attention module (ADnet style) --------------------------------
        self.outs_points = nn.ModuleList([sigmoid_head(self.num_landmarks) for _ in range(n_links)])
        self.outs_edges = nn.ModuleList([sigmoid_head(self.num_edges) for _ in range(n_links)])
        self.E2Ptransform = E2Ptransform(self.num_landmarks, self.num_edges, out_dim=self.out_res)

        self.outs_features = nn.ModuleList([plain_conv(ch, self.num_landmarks) for _ in range(n_links)])

        # ---- Stacked Hourglass inputs (nstack > 1) -------------------------
        self.merge_preds = nn.ModuleList([plain_conv(self.num_landmarks, ch) for _ in range(n_links)])
        self.merge_features = nn.ModuleList([plain_conv(ch, ch) for _ in range(n_links)])

    def forward(self, imgs):
        """Run every hourglass stack and gather the per-stack outputs.

        Args:
            imgs: Input passed straight to ``self.pre``.

        Returns:
            dict with ``'VisualField'`` (one entry per stack, the output of
            ``hgs_out``) and ``'HGcore'`` (one entry per stack when
            ``pose_required`` is set, otherwise empty).
        """
        visual_fields = []
        hg_cores = []
        outputs = {'VisualField': visual_fields,
                   'HGcore': hg_cores}

        stack_input = self.pre(imgs)
        core_raw = []  # handed to each hourglass and replaced by what it returns

        for stage in range(self.nstack):
            hourglass = self.hgs[stage]

            # Hourglass
            hg, core_raw = hourglass(stack_input, core=core_raw)
            if self.pose_required:
                # Entry is picked counting from the end, using the hourglass's `n`.
                hg_cores.append(self.hgs_core[stage](core_raw[-hourglass.n]))
            hg = self.hgs_out[stage](hg)

            # Visual features
            visual_fields.append(hg)

            # The last stack has no successor, so nothing to prepare.
            if stage >= self.nstack - 1:
                continue

            # Attentional modules: point gate scaled by the transformed edge gate
            points = self.outs_points[stage](hg)
            edges = self.outs_edges[stage](hg)
            attention = points * self.E2Ptransform(edges)

            # Landmarks: feature maps weighted by the attention
            preds = self.outs_features[stage](hg) * attention

            # Next stack input = current input + projected preds + projected features
            from_preds = self.merge_preds[stage](preds)
            from_feats = self.merge_features[stage](hg)
            stack_input = stack_input + from_preds + from_feats

        return outputs