Spaces:
Sleeping
Sleeping
import torch.nn as nn | |
import torch.nn.functional as F | |
class TemporalModelBase(nn.Module):
    """
    Shared base for the temporal convolutional models. Do not instantiate this class.

    It builds the layers common to all variants (dropout, ReLU, ``expand_bn``,
    ``expand_bn2``, the 1x1 ``shrink`` convolution and a ``Flatten``) plus the
    bookkeeping helpers below. Subclasses must define ``layers_bn`` (read by
    ``set_bn_momentum``), ``causal_shift`` (read by ``total_causal_shift``) and
    ``_forward_blocks`` (called by ``forward``).
    """

    def __init__(self, num_joints_in, in_features, num_joints_out,
                 filter_widths, causal, dropout, channels, sagittal=0, freezing=0, fusion=0):
        """
        Arguments:
        num_joints_in -- number of input joints
        in_features -- number of input features per joint (also the channel count of expand_bn2)
        num_joints_out -- number of output channels of the final 1x1 ``shrink`` convolution
        filter_widths -- list of odd convolution widths; filter_widths[0] gives the first padding entry
        causal -- accepted for subclass compatibility; not used by this class
        dropout -- dropout probability
        channels -- number of hidden convolution channels
        sagittal, freezing, fusion -- accepted for interface compatibility; not used by this class
        """
        super().__init__()

        # Validate input
        for fw in filter_widths:
            assert fw % 2 != 0, 'Only odd filter widths are supported'

        self.num_joints_in = num_joints_in
        self.in_features = in_features
        # NOTE(review): hard-coded to 2 instead of ``num_joints_out``, even though
        # ``self.shrink`` below emits ``num_joints_out`` channels. Kept as-is so any
        # caller reading this attribute is unaffected -- confirm which is intended.
        self.num_joints_out = 2
        self.filter_widths = filter_widths

        # Initialize layers
        self.drop = nn.Dropout(dropout)  # regularization
        self.relu = nn.ReLU(inplace=True)  # non-linearity
        # One-sided padding (in frames) per layer; subclasses append one entry per block.
        self.pad = [filter_widths[0] // 2]
        self.expand_bn = nn.BatchNorm1d(channels, momentum=0.1)  # normalizes the hidden channels
        # NOTE(review): expand_bn2 is not used by any forward code visible in this file.
        self.expand_bn2 = nn.BatchNorm2d(in_features, momentum=0.1)
        self.shrink = nn.Conv1d(channels, num_joints_out, 1)
        self.flatten = nn.Flatten()

    def set_bn_momentum(self, momentum):
        """
        Set the momentum of every batch-normalization layer in the model.

        In PyTorch, ``momentum`` is the weight of the current batch when the running
        statistics are updated: ``running = (1 - momentum) * running + momentum * batch``.
        A larger momentum therefore makes the running mean/variance follow new data
        faster (useful when fine-tuning on a new distribution); a smaller one keeps
        them more stable. This method lets the value be changed during training.

        Requires the subclass to have defined ``self.layers_bn``.
        """
        self.expand_bn.momentum = momentum
        # Keep every BatchNorm layer in sync, including the 2D one.
        self.expand_bn2.momentum = momentum
        for bn in self.layers_bn:
            bn.momentum = momentum

    def receptive_field(self):
        """
        Return the total receptive field of this model as # of frames.

        Each entry of ``self.pad`` is the number of context frames a layer consumes
        on each side (set by its filter width and dilation), so the receptive field
        is ``1 + 2 * sum(self.pad)``.
        """
        return 1 + 2 * sum(self.pad)

    def total_causal_shift(self):
        """
        Return the asymmetric offset for sequence padding.

        The returned value is 0 if causal convolutions are disabled (all entries of
        ``self.causal_shift`` are 0). Otherwise it is ``(receptive_field() - 1) // 2``,
        provided each ``causal_shift[i]`` (i >= 1) is NOT already scaled by its
        layer's dilation -- this formula multiplies by the product of the preceding
        filter widths itself. Subclasses that store shifts already in frames must
        override this method.
        """
        frames = self.causal_shift[0]
        next_dilation = self.filter_widths[0]
        for i in range(1, len(self.filter_widths)):
            frames += self.causal_shift[i] * next_dilation
            next_dilation *= self.filter_widths[i]
        return frames

    def forward(self, x):
        """Delegate to the subclass's ``_forward_blocks``; no reshaping is done here."""
        x = self._forward_blocks(x)
        return x
class TemporalModel(TemporalModelBase):
    """
    Reference 3D pose estimation model with temporal convolutions.
    This implementation can be used for all use-cases.
    """

    def __init__(self, num_joints_in, in_features, num_joints_out,
                 filter_widths, causal=False, dropout=0.25, channels=1024, dense=False,
                 sagittal=0, freezing=0, fusion=0):
        """
        Initialize this model.

        Arguments:
        num_joints_in -- number of input joints (e.g. 17 for Human3.6M)
        in_features -- number of input features for each joint (typically 2 for 2D input)
        num_joints_out -- number of output channels of the final 1x1 convolution
        filter_widths -- list of convolution widths, which also determines the # of blocks and receptive field
        causal -- use causal convolutions instead of symmetric convolutions (for real-time applications)
        dropout -- dropout probability
        channels -- number of convolution channels
        dense -- use regular dense convolutions instead of dilated convolutions (ablation experiment)
        sagittal, freezing, fusion -- forwarded to the base class, which does not use them
        """
        super().__init__(num_joints_in, in_features, num_joints_out, filter_widths, causal, dropout, channels,
                         sagittal=sagittal, freezing=freezing, fusion=fusion)

        self.expand_conv = nn.Conv1d(num_joints_in * in_features, channels, filter_widths[0], bias=False)

        layers_conv = []
        layers_bn = []

        # Causal shifts are stored in frames (scaled by each layer's dilation).
        self.causal_shift = [(filter_widths[0]) // 2 if causal else 0]
        next_dilation = filter_widths[0]
        for i in range(1, len(filter_widths)):
            self.pad.append((filter_widths[i] - 1) * next_dilation // 2)
            self.causal_shift.append((filter_widths[i] // 2 * next_dilation) if causal else 0)

            # Each block is a dilated (or, if dense, wide undilated) conv followed by a 1x1 conv,
            # so even indices of layers_conv/layers_bn are the wide convs and odd indices the 1x1s.
            layers_conv.append(nn.Conv1d(channels, channels,
                                         filter_widths[i] if not dense else (2 * self.pad[-1] + 1),
                                         dilation=next_dilation if not dense else 1,
                                         bias=False))
            layers_bn.append(nn.BatchNorm1d(channels, momentum=0.1))
            layers_conv.append(nn.Conv1d(channels, channels, 1, dilation=1, bias=False))
            layers_bn.append(nn.BatchNorm1d(channels, momentum=0.1))

            next_dilation *= filter_widths[i]

        self.layers_conv = nn.ModuleList(layers_conv)
        self.layers_bn = nn.ModuleList(layers_bn)

    def total_causal_shift(self):
        """
        Return the asymmetric offset (in frames) for sequence padding.

        0 when causal convolutions are disabled, otherwise ``(receptive_field() - 1) // 2``.
        This model stores every ``causal_shift`` entry already multiplied by its
        layer's dilation (see ``__init__``), so the entries are summed directly. The
        generic formula would scale them by the dilation a second time (e.g. 91
        instead of 13 frames for filter_widths=[3, 3, 3]).
        """
        return sum(self.causal_shift)

    def set_freezing(self, freezing, expand_bn, expand_bn2, shrink):
        """
        Disable gradients (``requires_grad = False``) for parts of the network.

        Always freezes ``self.shrink`` and ``self.expand_bn``. When ``freezing > 0``, it also
        freezes ``self.expand_conv`` and, for each even index ``i`` in ``range(freezing)``,
        ``layers_conv[i]`` and ``layers_bn[i]``. Even indices are the wide/dilated conv of
        each block; odd indices are the 1x1 convs.

        Arguments:
        freezing -- number of leading entries of layers_conv/layers_bn to consider
        expand_bn, expand_bn2, shrink -- NOTE(review): accepted but currently ignored;
            shrink and expand_bn are frozen unconditionally and expand_bn2 never is.
            Confirm whether these were meant to be per-layer switches.

        Raises:
        IndexError -- if freezing > len(self.layers_conv).
        """
        for param in self.shrink.parameters():
            param.requires_grad = False
        for param in self.expand_bn.parameters():
            param.requires_grad = False

        if freezing > 0:
            # Freeze initial layers (expand_conv) and potentially the first n layers_conv
            for param in self.expand_conv.parameters():
                param.requires_grad = False
            # NOTE(review): the original indentation was lost; the BatchNorm freeze is
            # reconstructed inside the even-index branch so each wide conv is frozen
            # together with its own BatchNorm -- confirm against the original source.
            for i in range(freezing):
                if i % 2 == 0:
                    for param in self.layers_conv[i].parameters():
                        param.requires_grad = False  # Freeze Conv1d layer
                    for param in self.layers_bn[i].parameters():
                        param.requires_grad = False  # Freeze BatchNorm layer (optional)

    def _forward_blocks(self, x):
        """
        Run the convolutional stack and return the flattened output.

        ``x`` goes straight into ``self.expand_conv`` (a Conv1d with
        ``num_joints_in * in_features`` input channels), so it must be shaped
        (batch, num_joints_in * in_features, frames). None of the convolutions pad,
        so the time axis shrinks by ``receptive_field() - 1`` frames overall. The
        result is ``flatten(shrink(x))``: (batch, num_joints_out * out_frames).
        """
        # First block of Conv1D + BatchNorm1D, ReLU, Dropout
        x = self.drop(self.relu(self.expand_bn(self.expand_conv(x))))

        # Repeated residual blocks; self.pad has one entry more than there are blocks.
        for i in range(len(self.pad) - 1):
            pad = self.pad[i + 1]
            shift = self.causal_shift[i + 1]
            # Crop the skip connection to the frames that survive the unpadded conv;
            # a causal shift moves the kept window towards the end of the sequence.
            res = x[:, :, pad + shift : x.shape[2] - pad + shift]
            x = self.drop(self.relu(self.layers_bn[2 * i](self.layers_conv[2 * i](x))))
            x = res + self.drop(self.relu(self.layers_bn[2 * i + 1](self.layers_conv[2 * i + 1](x))))

        x = self.shrink(x)  # 1x1 Conv1d down to num_joints_out channels
        x = self.flatten(x)
        return x