import torch.nn as nn
import torch.nn.functional as F


class TemporalModel_Base_Doc:
    pass


del TemporalModel_Base_Doc


class TemporalModelBase(nn.Module):
    """
    Do not instantiate this class.

    Shared base for the temporal-convolution models in this file. Subclasses
    must define ``expand_conv``, ``layers_conv``, ``layers_bn`` and
    ``causal_shift`` and implement ``_forward_blocks``.
    """

    def __init__(self, num_joints_in, in_features, num_joints_out, filter_widths,
                 causal, dropout, channels, sagittal=0, freezing=0, fusion=0):
        super().__init__()

        # Validate input: symmetric padding (width // 2 per side) needs odd widths.
        for fw in filter_widths:
            assert fw % 2 != 0, 'Only odd filter widths are supported'

        self.num_joints_in = num_joints_in
        self.in_features = in_features
        # Previously hard-coded to 2, which disagreed with the ``shrink`` layer
        # below (it emits ``num_joints_out`` channels).
        self.num_joints_out = num_joints_out
        self.filter_widths = filter_widths
        # NOTE(review): ``causal``, ``sagittal``, ``freezing`` and ``fusion`` are
        # accepted but not used here.

        # Initialize layers
        self.drop = nn.Dropout(dropout)  # for regularization
        self.relu = nn.ReLU(inplace=True)  # introducing non-linearity
        # Per-block one-sided temporal padding; subclasses append one entry per
        # additional block. ``receptive_field`` is derived from this list.
        self.pad = [filter_widths[0] // 2]
        self.expand_bn = nn.BatchNorm1d(channels, momentum=0.1)  # normalize channels
        # NOTE(review): not used by the forward path defined in this file.
        self.expand_bn2 = nn.BatchNorm2d(in_features, momentum=0.1)  # normalize features
        # 1x1 conv projecting ``channels`` -> ``num_joints_out`` output channels.
        self.shrink = nn.Conv1d(channels, num_joints_out, 1)
        self.flatten = nn.Flatten()

    def set_bn_momentum(self, momentum):
        """
        Set the momentum of every BatchNorm layer owned by the model.

        In PyTorch the running statistics are updated as
        ``running = (1 - momentum) * running + momentum * batch_stat``, so a
        larger momentum makes the running mean/variance follow new data faster
        and a smaller one makes them change more slowly. Adjusting it during
        training can help, e.g., when fine-tuning on a new data distribution.

        Arguments:
        momentum -- new momentum value applied to ``expand_bn``, ``expand_bn2``
                    and every layer in ``layers_bn``.
        """
        self.expand_bn.momentum = momentum
        self.expand_bn2.momentum = momentum
        for bn in self.layers_bn:
            bn.momentum = momentum

    def receptive_field(self):
        """
        Return the total receptive field of this model as # of frames.

        The receptive field is the number of input frames that influence a
        single output frame. Each block widens it by ``pad`` frames on each
        side (determined by its filter width and dilation), so the total is
        ``1 + 2 * sum(self.pad)``.
        """
        return 1 + 2 * sum(self.pad)

    def total_causal_shift(self):
        """
        Return the asymmetric offset for sequence padding.

        The returned value is 0 if causal convolutions are disabled (every
        entry of ``causal_shift`` is then 0); otherwise it accumulates each
        block's shift weighted by the running product of filter widths.
        """
        frames = self.causal_shift[0]
        next_dilation = self.filter_widths[0]
        for i in range(1, len(self.filter_widths)):
            frames += self.causal_shift[i] * next_dilation
            next_dilation *= self.filter_widths[i]
        return frames

    def forward(self, x):
        """
        Run the temporal blocks on ``x``.

        ``x`` is fed directly to a ``Conv1d`` with
        ``num_joints_in * in_features`` input channels, so it must be shaped
        ``(batch, num_joints_in * in_features, frames)``. The result is the
        ``shrink`` output flattened to ``(batch, num_joints_out * out_frames)``.
        """
        return self._forward_blocks(x)


class TemporalModel(TemporalModelBase):
    """
    Reference 3D pose estimation model with temporal convolutions.

    This implementation can be used for all use-cases.
    """

    def __init__(self, num_joints_in, in_features, num_joints_out, filter_widths,
                 causal=False, dropout=0.25, channels=1024, dense=False,
                 sagittal=0, freezing=0, fusion=0):
        """
        Initialize this model.

        Arguments:
        num_joints_in -- number of input joints (e.g. 17 for Human3.6M)
        in_features -- number of input features for each joint (typically 2 for 2D input)
        num_joints_out -- number of output channels of the final 1x1 conv
                          (can be different than input)
        filter_widths -- list of convolution widths, which also determines the
                         # of blocks and receptive field
        causal -- use causal convolutions instead of symmetric convolutions
                  (for real-time applications)
        dropout -- dropout probability
        channels -- number of convolution channels
        dense -- use regular dense convolutions instead of dilated convolutions
                 (ablation experiment)
        sagittal, freezing, fusion -- forwarded to the base class, which does
                                      not use them.
        """
        super().__init__(num_joints_in, in_features, num_joints_out, filter_widths,
                         causal, dropout, channels,
                         sagittal=sagittal, freezing=freezing, fusion=fusion)

        self.expand_conv = nn.Conv1d(num_joints_in * in_features, channels,
                                     filter_widths[0], bias=False)

        layers_conv = []
        layers_bn = []

        # Causal mode shifts each block's residual crop by its one-sided padding.
        self.causal_shift = [filter_widths[0] // 2 if causal else 0]
        next_dilation = filter_widths[0]
        for i in range(1, len(filter_widths)):
            self.pad.append((filter_widths[i] - 1) * next_dilation // 2)
            self.causal_shift.append((filter_widths[i] // 2 * next_dilation) if causal else 0)

            # Even indices: dilated temporal conv (or an equally wide dense conv
            # in the ablation), which shortens the sequence by 2 * pad frames.
            layers_conv.append(nn.Conv1d(channels, channels,
                                         filter_widths[i] if not dense else (2 * self.pad[-1] + 1),
                                         dilation=next_dilation if not dense else 1,
                                         bias=False))
            layers_bn.append(nn.BatchNorm1d(channels, momentum=0.1))
            # Odd indices: 1x1 conv that keeps the sequence length.
            layers_conv.append(nn.Conv1d(channels, channels, 1, dilation=1, bias=False))
            layers_bn.append(nn.BatchNorm1d(channels, momentum=0.1))

            next_dilation *= filter_widths[i]

        self.layers_conv = nn.ModuleList(layers_conv)
        self.layers_bn = nn.ModuleList(layers_bn)

    def set_freezing(self, freezing, expand_bn, expand_bn2, shrink):
        """
        Disable gradients (``requires_grad = False``) for parts of the model.

        ``shrink`` and ``expand_bn`` are always frozen. If ``freezing > 0``,
        ``expand_conv`` is frozen too, as are ``layers_conv[i]`` and
        ``layers_bn[i]`` for every even ``i < freezing`` (the dilated convs);
        the odd-indexed 1x1 convs stay trainable.

        Arguments:
        freezing -- number of leading ``layers_conv``/``layers_bn`` entries to
                    consider; an even index that is out of range raises IndexError.
        expand_bn, expand_bn2, shrink -- currently ignored.
                    NOTE(review): presumably meant as per-layer toggles; confirm
                    with callers before wiring them up.

        Note that frozen BatchNorm layers still update their running statistics
        while the module is in training mode; only their affine parameters stop
        receiving gradients.
        """
        for module in (self.shrink, self.expand_bn):
            for param in module.parameters():
                param.requires_grad = False

        if freezing > 0:
            # Freeze the input layer (expand_conv) ...
            for param in self.expand_conv.parameters():
                param.requires_grad = False
            # ... and the even-indexed (dilated conv + its BatchNorm) pairs.
            for i in range(0, freezing, 2):
                for param in self.layers_conv[i].parameters():
                    param.requires_grad = False  # Freeze Conv1d layer
                for param in self.layers_bn[i].parameters():
                    param.requires_grad = False  # Freeze BatchNorm layer

    def _forward_blocks(self, x):
        """Apply the expand block, the residual blocks and the output head."""
        # First block: Conv1d -> BatchNorm1d -> ReLU -> Dropout.
        x = self.drop(self.relu(self.expand_bn(self.expand_conv(x))))

        # Repeated residual blocks; block i uses layers 2*i (dilated) and 2*i+1 (1x1).
        for i in range(len(self.pad) - 1):  # same length as filter_widths - 1
            pad = self.pad[i + 1]
            shift = self.causal_shift[i + 1]
            # The dilated conv drops 2 * pad frames, so crop the skip path to match
            # (shifted right by ``shift`` in causal mode).
            res = x[:, :, pad + shift: x.shape[2] - pad + shift]

            x = self.drop(self.relu(self.layers_bn[2 * i](self.layers_conv[2 * i](x))))
            x = res + self.drop(self.relu(
                self.layers_bn[2 * i + 1](self.layers_conv[2 * i + 1](x))))

        x = self.shrink(x)  # 1x1 conv: channels -> num_joints_out
        return self.flatten(x)