import torch.nn as nn
import torch.nn.functional as F


class TemporalModelBase(nn.Module):
    """

    Do not instantiate this class.

    """

    def __init__(self, num_joints_in, in_features, num_joints_out,
                 filter_widths, causal, dropout, channels, sagittal=0, freezing=0, fusion=0):
        super().__init__()

        # Validate input
        for fw in filter_widths:
            assert fw % 2 != 0, 'Only odd filter widths are supported'

        self.num_joints_in = num_joints_in
        self.in_features = in_features
        self.num_joints_out = 2     # fixed to 2 in this variant, regardless of the num_joints_out argument
        self.filter_widths = filter_widths

        # Initialize layers
        self.drop = nn.Dropout(dropout)         # dropout for regularization
        self.relu = nn.ReLU(inplace=True)       # ReLU non-linearity

        self.pad = [filter_widths[0] // 2]
        self.expand_bn = nn.BatchNorm1d(channels, momentum=0.1)         # normalize channels
        self.expand_bn2 = nn.BatchNorm2d(in_features, momentum=0.1)     # normalize features

        self.shrink = nn.Conv1d(channels, num_joints_out, 1)
        self.flatten = nn.Flatten()

    def set_bn_momentum(self, momentum):
        """

        Batch normalization is a technique used to normalize the inputs of each layer, which helps stabilize and speed up the training process. The momentum parameter determines how much of the statistics from the current batch should contribute to the running mean and variance of the batch normalization layer.

        In some cases, especially during fine-tuning or transfer learning, it might be beneficial to adjust the momentum dynamically. For example, when fine-tuning on a new dataset, you may want to decrease the momentum to adapt faster to the new data distribution. This method provides a way to adjust the momentum value dynamically during training.

        """
        self.expand_bn.momentum = momentum
        for bn in self.layers_bn:
            bn.momentum = momentum

    def receptive_field(self):
        """

        Return the total receptive field of this model as # of frames.

        The receptive field represents the region in the input data that influences a particular output of the neural network. It is determined by the filter sizes and strides of the convolutional layers in the network.

        """
        frames = 0
        for f in self.pad:
            frames += f
        return 1 + 2*frames

    def total_causal_shift(self):
        """

        Return the asymmetric offset for sequence padding.

        The returned value is typically 0 if causal convolutions are disabled,

        otherwise it is half the receptive field.

        """
        frames = self.causal_shift[0]
        next_dilation = self.filter_widths[0]
        for i in range(1, len(self.filter_widths)):
            frames += self.causal_shift[i] * next_dilation
            next_dilation *= self.filter_widths[i]
        return frames

    def forward(self, x):
        # x is expected to have shape (batch, num_joints_in * in_features, frames)
        x = self._forward_blocks(x)
        return x

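
# A minimal, illustrative sketch (not part of the original model code) of how
# set_bn_momentum() above can be driven by a decay schedule during training.
# The function name, the default momentum values and the exponential schedule
# are assumptions for illustration only.
def example_bn_momentum_schedule(model, epoch, num_epochs,
                                 initial_momentum=0.1, final_momentum=0.001):
    # Exponentially decay the BatchNorm momentum from initial_momentum to
    # final_momentum over num_epochs, so that the running statistics become
    # more stable as training progresses.
    momentum = initial_momentum * (final_momentum / initial_momentum) ** (epoch / num_epochs)
    model.set_bn_momentum(momentum)
    return momentum
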
class TemporalModel(TemporalModelBase):
    """

    Reference 3D pose estimation model with temporal convolutions.

    This implementation can be used for all use-cases.

    """

    def __init__(self, num_joints_in, in_features, num_joints_out,
                 filter_widths, causal=False, dropout=0.25, channels=1024, dense=False, sagittal=0, freezing=0, fusion=0):
        """

        Initialize this model.



        Arguments:

        num_joints_in -- number of input joints (e.g. 17 for Human3.6M)

        in_features -- number of input features for each joint (typically 2 for 2D input)

        num_joints_out -- number of output joints (can be different than input)

        filter_widths -- list of convolution widths, which also determines the # of blocks and receptive field

        causal -- use causal convolutions instead of symmetric convolutions (for real-time applications)

        dropout -- dropout probability

        channels -- number of convolution channels

        dense -- use regular dense convolutions instead of dilated convolutions (ablation experiment)

        """
        super().__init__(num_joints_in, in_features, num_joints_out, filter_widths, causal, dropout, channels, sagittal=sagittal, freezing=freezing, fusion=fusion)

        self.expand_conv = nn.Conv1d(num_joints_in*in_features, channels, filter_widths[0], bias=False)

        layers_conv = []
        layers_bn = []

        self.causal_shift = [ (filter_widths[0]) // 2 if causal else 0 ]
        next_dilation = filter_widths[0]
        for i in range(1, len(filter_widths)):
            self.pad.append((filter_widths[i] - 1)*next_dilation // 2)
            self.causal_shift.append((filter_widths[i]//2 * next_dilation) if causal else 0)
            layers_conv.append(nn.Conv1d(channels, channels,
                                         filter_widths[i] if not dense else (2*self.pad[-1] + 1),
                                         dilation=next_dilation if not dense else 1,
                                         bias=False))
            layers_bn.append(nn.BatchNorm1d(channels, momentum=0.1))
            layers_conv.append(nn.Conv1d(channels, channels, 1, dilation=1, bias=False))
            layers_bn.append(nn.BatchNorm1d(channels, momentum=0.1))

            next_dilation *= filter_widths[i]
        
        self.layers_conv = nn.ModuleList(layers_conv)
        self.layers_bn = nn.ModuleList(layers_bn)
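        # For example (illustrative only): filter_widths=[3, 3, 3] produces
        # expand_conv with kernel width 3, then two residual blocks with
        # dilations 3 and 9, pads [1, 3, 9], and a receptive field of 27 frames.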

    def set_freezing(self, freezing, expand_bn, expand_bn2, shrink):
        # Note: the expand_bn, expand_bn2 and shrink arguments are currently
        # unused; the corresponding modules on self are frozen directly.
        for param in self.shrink.parameters():
            param.requires_grad = False
        for param in self.expand_bn.parameters():
            param.requires_grad = False
        if freezing > 0:
            # Freeze the initial expansion convolution and the first `freezing`
            # entries of layers_conv / layers_bn. Even indices of layers_conv
            # hold the dilated temporal convolutions, odd indices the 1x1
            # convolutions; only the dilated ones are frozen here.
            for param in self.expand_conv.parameters():
                param.requires_grad = False
            for i in range(freezing):
                if i % 2 == 0:
                    for param in self.layers_conv[i].parameters():
                        param.requires_grad = False  # freeze dilated Conv1d layer
                for param in self.layers_bn[i].parameters():
                    param.requires_grad = False      # freeze BatchNorm layer
        
    def _forward_blocks(self, x):

        # First block: Conv1d + BatchNorm1d + ReLU + Dropout
        x = self.drop(self.relu(self.expand_bn(self.expand_conv(x))))

        # Repeated residual blocks: a dilated Conv1d + BN + ReLU + Dropout,
        # followed by a 1x1 Conv1d + BN + ReLU + Dropout
        for i in range(len(self.pad) - 1):  # one block per filter width after the first
            pad = self.pad[i+1]
            shift = self.causal_shift[i+1]
            # Residual connection: crop the input to match the output length
            res = x[:, :, pad + shift : x.shape[2] - pad + shift]
            x = self.drop(self.relu(self.layers_bn[2*i](self.layers_conv[2*i](x))))
            x = res + self.drop(self.relu(self.layers_bn[2*i + 1](self.layers_conv[2*i + 1](x))))

        x = self.shrink(x)   # 1x1 convolution mapping channels to the output size
        x = self.flatten(x)
        return x
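

# Minimal usage sketch (not from the original code): the shapes and
# hyper-parameters below are assumptions chosen only to illustrate the API.
# With filter_widths=[3, 3, 3] the receptive field is 1 + 2*(1 + 3 + 9) = 27
# frames, and _forward_blocks() expects input already laid out as
# (batch, num_joints_in * in_features, frames).
if __name__ == "__main__":
    import torch

    model = TemporalModel(num_joints_in=17, in_features=2, num_joints_out=2,
                          filter_widths=[3, 3, 3], causal=False,
                          dropout=0.25, channels=1024)
    frames = model.receptive_field()        # 27 for filter_widths=[3, 3, 3]
    x = torch.randn(1, 17 * 2, frames)      # (batch, joints * features, frames)

    model.eval()
    with torch.no_grad():
        out = model(x)
    print(frames, out.shape)                # 27, torch.Size([1, 2])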