# TCN_UL_acitivity / tryout_v3_model.py
# Uploaded by liesdillen ("Upload 4 files", commit da7cd93, verified); 7.41 kB.
import torch.nn as nn
import torch.nn.functional as F
class TemporalModelBase(nn.Module):
    """
    Shared scaffolding for the temporal convolutional models in this file.

    Do not instantiate this class directly. Subclasses are expected to define
    ``expand_conv``, ``layers_conv``, ``layers_bn``, ``causal_shift`` and
    ``_forward_blocks``; this base only builds the layers common to every
    variant and provides the receptive-field bookkeeping.
    """

    def __init__(self, num_joints_in, in_features, num_joints_out,
                 filter_widths, causal, dropout, channels, sagittal=0, freezing=0, fusion=0):
        """
        Arguments:
        num_joints_in -- number of input joints
        in_features -- number of input features per joint
        num_joints_out -- number of output channels produced by ``shrink``
        filter_widths -- odd convolution widths, one per temporal block
        causal -- accepted for interface compatibility; not used by this base
        dropout -- dropout probability
        channels -- number of hidden convolution channels
        sagittal, freezing, fusion -- accepted but not used by this base
        """
        super().__init__()

        # Validate input: the padding arithmetic assumes symmetric (odd) kernels.
        for fw in filter_widths:
            assert fw % 2 != 0, 'Only odd filter widths are supported'

        self.num_joints_in = num_joints_in
        self.in_features = in_features
        # Previously hard-coded to 2, which disagreed with the width of
        # ``shrink`` whenever another num_joints_out was passed.
        self.num_joints_out = num_joints_out
        self.filter_widths = filter_widths

        # Initialize layers
        self.drop = nn.Dropout(dropout)  # regularization
        self.relu = nn.ReLU(inplace=True)  # non-linearity
        # One-sided padding (in frames) consumed by each block; subclasses append to it.
        self.pad = [filter_widths[0] // 2]
        self.expand_bn = nn.BatchNorm1d(channels, momentum=0.1)  # normalize channels
        # NOTE(review): no forward path in this file uses expand_bn2.
        self.expand_bn2 = nn.BatchNorm2d(in_features, momentum=0.1)  # normalize features
        self.shrink = nn.Conv1d(channels, num_joints_out, 1)
        self.flatten = nn.Flatten()

    def set_bn_momentum(self, momentum):
        """
        Set the momentum of every batch-normalization layer in the model.

        In PyTorch, the running statistics are updated as
        ``running = (1 - momentum) * running + momentum * batch_stat``. A
        larger momentum therefore makes the running mean/variance track new
        data faster, and a smaller one makes them more stable. Adjusting it
        can help when fine-tuning on a new data distribution.

        Requires the subclass to have defined ``self.layers_bn``.
        """
        self.expand_bn.momentum = momentum
        # Kept in sync with the other BN layers so the setting applies model-wide.
        self.expand_bn2.momentum = momentum
        for bn in self.layers_bn:
            bn.momentum = momentum

    def receptive_field(self):
        """
        Return the total receptive field of this model as a number of frames.

        Each block consumes ``pad`` frames on both sides of its input. The
        number of input frames that influence one output frame is therefore
        ``1 + 2 * sum(pad)``, which is set by the filter widths and dilations.
        """
        return 1 + 2 * sum(self.pad)

    def total_causal_shift(self):
        """
        Return the asymmetric offset for sequence padding.

        The result is 0 when causal convolutions are disabled, because every
        ``causal_shift`` entry is then 0. This generic formula assumes that
        ``causal_shift[i]`` for i > 0 is expressed in units of that block's
        dilation. Subclasses that store per-block shifts directly in frames
        must override this method.
        """
        frames = self.causal_shift[0]
        next_dilation = self.filter_widths[0]
        for i in range(1, len(self.filter_widths)):
            frames += self.causal_shift[i] * next_dilation
            next_dilation *= self.filter_widths[i]
        return frames

    def forward(self, x):
        """Run the subclass-defined temporal blocks on ``x`` and return the result."""
        x = self._forward_blocks(x)
        return x
class TemporalModel(TemporalModelBase):
    """
    Reference 3D pose estimation model with temporal convolutions.
    This implementation can be used for all use-cases.

    The model is built from an initial convolution (``expand_conv``) followed
    by one residual block per additional entry in ``filter_widths``. Each
    residual block is a dilated convolution plus a 1x1 convolution, each with
    BatchNorm, ReLU and dropout. A final 1x1 ``shrink`` convolution and a
    flatten produce the output.
    """

    def __init__(self, num_joints_in, in_features, num_joints_out,
                 filter_widths, causal=False, dropout=0.25, channels=1024, dense=False,
                 sagittal=0, freezing=0, fusion=0):
        """
        Initialize this model.

        Arguments:
        num_joints_in -- number of input joints (e.g. 17 for Human3.6M)
        in_features -- number of input features for each joint (typically 2 for 2D input)
        num_joints_out -- number of output channels of the final 1x1 convolution
        filter_widths -- list of convolution widths, which also determines the # of blocks and receptive field
        causal -- use causal convolutions instead of symmetric convolutions (for real-time applications)
        dropout -- dropout probability
        channels -- number of convolution channels
        dense -- use regular dense convolutions instead of dilated convolutions (ablation experiment)
        sagittal, freezing, fusion -- forwarded to the base class, which does not use them
        """
        super().__init__(num_joints_in, in_features, num_joints_out, filter_widths, causal,
                         dropout, channels, sagittal=sagittal, freezing=freezing, fusion=fusion)

        # The input's channel axis holds all joints' features flattened together.
        self.expand_conv = nn.Conv1d(num_joints_in * in_features, channels, filter_widths[0], bias=False)

        layers_conv = []
        layers_bn = []
        # Per-block offset, in frames, of the residual slice when causal.
        self.causal_shift = [(filter_widths[0]) // 2 if causal else 0]
        next_dilation = filter_widths[0]
        for width in filter_widths[1:]:
            self.pad.append((width - 1) * next_dilation // 2)
            # Already multiplied by the dilation, i.e. a frame count.
            self.causal_shift.append((width // 2 * next_dilation) if causal else 0)

            # Even index: dilated temporal conv (or an equally wide dense conv when dense=True).
            layers_conv.append(nn.Conv1d(channels, channels,
                                         width if not dense else (2 * self.pad[-1] + 1),
                                         dilation=next_dilation if not dense else 1,
                                         bias=False))
            layers_bn.append(nn.BatchNorm1d(channels, momentum=0.1))
            # Odd index: 1x1 conv mixing channels within each frame.
            layers_conv.append(nn.Conv1d(channels, channels, 1, dilation=1, bias=False))
            layers_bn.append(nn.BatchNorm1d(channels, momentum=0.1))

            next_dilation *= width

        self.layers_conv = nn.ModuleList(layers_conv)
        self.layers_bn = nn.ModuleList(layers_bn)

    def total_causal_shift(self):
        """
        Return the asymmetric offset for sequence padding, in frames.

        This class stores each ``causal_shift`` entry already multiplied by its
        block's dilation, so the total is their plain sum. The generic
        base-class formula would apply the dilation a second time. The result
        is 0 when ``causal`` is False. Otherwise it is ``sum(self.pad)``, i.e.
        ``(receptive_field() - 1) // 2``, because all filter widths are odd.
        """
        return sum(self.causal_shift)

    def set_freezing(self, freezing, expand_bn, expand_bn2, shrink):
        """
        Disable gradients for part of the network, e.g. for fine-tuning.

        Always frozen: ``shrink`` and ``expand_bn``.
        If ``freezing > 0``: ``expand_conv`` is also frozen. For each index
        ``i`` in ``range(freezing)``, ``layers_bn[i]`` is frozen, and so is
        ``layers_conv[i]`` when ``i`` is even (the dilated convolutions).
        The 1x1 convolutions at odd indices stay trainable.

        NOTE(review): the ``expand_bn``, ``expand_bn2`` and ``shrink``
        arguments are not read; the corresponding freezing is unconditional.
        Confirm whether they were meant to be on/off switches.

        Setting ``requires_grad=False`` on a BatchNorm only freezes its affine
        weight/bias. Its running statistics still update in training mode.
        """
        for param in self.shrink.parameters():
            param.requires_grad = False
        for param in self.expand_bn.parameters():
            param.requires_grad = False

        if freezing > 0:
            # Freeze the initial layer (expand_conv) and the first `freezing` block layers.
            for param in self.expand_conv.parameters():
                param.requires_grad = False
            for i in range(freezing):
                if i % 2 == 0:
                    for param in self.layers_conv[i].parameters():
                        param.requires_grad = False  # Freeze Conv1d layer
                for param in self.layers_bn[i].parameters():
                    param.requires_grad = False  # Freeze BatchNorm affine parameters

    def _forward_blocks(self, x):
        """
        Apply the convolutional stack to ``x``.

        ``expand_conv`` is an ``nn.Conv1d`` with ``num_joints_in * in_features``
        input channels. ``x`` must therefore be shaped
        (batch, num_joints_in * in_features, frames). No convolution is padded,
        so the temporal length shrinks by ``receptive_field() - 1``. The
        ``(batch, num_joints_out, frames_out)`` output of ``shrink`` is
        flattened to ``(batch, num_joints_out * frames_out)``.
        """
        # First block: Conv1D + BatchNorm1D, ReLU, Dropout
        x = self.drop(self.relu(self.expand_bn(self.expand_conv(x))))

        # Repeated residual blocks; self.pad has one entry per filter width.
        for i in range(len(self.pad) - 1):
            pad = self.pad[i + 1]
            shift = self.causal_shift[i + 1]
            # Crop the skip connection to the frames that survive the dilated
            # conv, shifted forward in time when causal.
            res = x[:, :, pad + shift: x.shape[2] - pad + shift]
            x = self.drop(self.relu(self.layers_bn[2 * i](self.layers_conv[2 * i](x))))
            x = res + self.drop(self.relu(self.layers_bn[2 * i + 1](self.layers_conv[2 * i + 1](x))))

        x = self.shrink(x)  # 1x1 Conv1d down to num_joints_out channels
        x = self.flatten(x)
        return x