aiface's picture
Upload 11 files
907b7f3
raw
history blame
6.46 kB
import torch
import torch.nn as nn
import math
import numpy as np
from lipreading.models.resnet import ResNet, BasicBlock
from lipreading.models.resnet1D import ResNet1D, BasicBlock1D
from lipreading.models.shufflenetv2 import ShuffleNetV2
from lipreading.models.tcn import MultibranchTemporalConvNet, TemporalConvNet
# -- auxiliary functions
def threeD_to_2D_tensor(x):
n_batch, n_channels, s_time, sx, sy = x.shape
x = x.transpose(1, 2)
return x.reshape(n_batch*s_time, n_channels, sx, sy)
def _average_batch(x, lengths, B):
return torch.stack( [torch.mean( x[index][:,0:i], 1 ) for index, i in enumerate(lengths)],0 )
class MultiscaleMultibranchTCN(nn.Module):
def __init__(self, input_size, num_channels, num_classes, tcn_options, dropout, relu_type, dwpw=False):
super(MultiscaleMultibranchTCN, self).__init__()
self.kernel_sizes = tcn_options['kernel_size']
self.num_kernels = len( self.kernel_sizes )
self.mb_ms_tcn = MultibranchTemporalConvNet(input_size, num_channels, tcn_options, dropout=dropout, relu_type=relu_type, dwpw=dwpw)
self.tcn_output = nn.Linear(num_channels[-1], num_classes)
self.consensus_func = _average_batch
def forward(self, x, lengths, B):
# x needs to have dimension (N, C, L) in order to be passed into CNN
xtrans = x.transpose(1, 2)
out = self.mb_ms_tcn(xtrans)
out = self.consensus_func( out, lengths, B )
return self.tcn_output(out)
class TCN(nn.Module):
"""Implements Temporal Convolutional Network (TCN)
__https://arxiv.org/pdf/1803.01271.pdf
"""
def __init__(self, input_size, num_channels, num_classes, tcn_options, dropout, relu_type, dwpw=False):
super(TCN, self).__init__()
self.tcn_trunk = TemporalConvNet(input_size, num_channels, dropout=dropout, tcn_options=tcn_options, relu_type=relu_type, dwpw=dwpw)
self.tcn_output = nn.Linear(num_channels[-1], num_classes)
self.consensus_func = _average_batch
self.has_aux_losses = False
def forward(self, x, lengths, B):
# x needs to have dimension (N, C, L) in order to be passed into CNN
x = self.tcn_trunk(x.transpose(1, 2))
x = self.consensus_func( x, lengths, B )
return self.tcn_output(x)
class Lipreading(nn.Module):
def __init__( self, modality='video', hidden_dim=256, backbone_type='resnet', num_classes=30,
relu_type='prelu', tcn_options={}, width_mult=1.0, extract_feats=False):
super(Lipreading, self).__init__()
self.extract_feats = extract_feats
self.backbone_type = backbone_type
self.modality = modality
if self.modality == 'raw_audio':
self.frontend_nout = 1
self.backend_out = 512
self.trunk = ResNet1D(BasicBlock1D, [2, 2, 2, 2], relu_type=relu_type)
elif self.modality == 'video':
if self.backbone_type == 'resnet':
self.frontend_nout = 64
self.backend_out = 512
self.trunk = ResNet(BasicBlock, [2, 2, 2, 2], relu_type=relu_type)
elif self.backbone_type == 'shufflenet':
assert width_mult in [0.5, 1.0, 1.5, 2.0], "Width multiplier not correct"
shufflenet = ShuffleNetV2( input_size=96, width_mult=width_mult)
self.trunk = nn.Sequential( shufflenet.features, shufflenet.conv_last, shufflenet.globalpool)
self.frontend_nout = 24
self.backend_out = 1024 if width_mult != 2.0 else 2048
self.stage_out_channels = shufflenet.stage_out_channels[-1]
frontend_relu = nn.PReLU(num_parameters=self.frontend_nout) if relu_type == 'prelu' else nn.ReLU()
self.frontend3D = nn.Sequential(
nn.Conv3d(1, self.frontend_nout, kernel_size=(5, 7, 7), stride=(1, 2, 2), padding=(2, 3, 3), bias=False),
nn.BatchNorm3d(self.frontend_nout),
frontend_relu,
nn.MaxPool3d( kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1)))
else:
raise NotImplementedError
tcn_class = TCN if len(tcn_options['kernel_size']) == 1 else MultiscaleMultibranchTCN
self.tcn = tcn_class( input_size=self.backend_out,
num_channels=[hidden_dim*len(tcn_options['kernel_size'])*tcn_options['width_mult']]*tcn_options['num_layers'],
num_classes=num_classes,
tcn_options=tcn_options,
dropout=tcn_options['dropout'],
relu_type=relu_type,
dwpw=tcn_options['dwpw'],
)
# -- initialize
self._initialize_weights_randomly()
def forward(self, x, lengths):
if self.modality == 'video':
B, C, T, H, W = x.size()
x = self.frontend3D(x)
Tnew = x.shape[2] # output should be B x C2 x Tnew x H x W
x = threeD_to_2D_tensor( x )
x = self.trunk(x)
if self.backbone_type == 'shufflenet':
x = x.view(-1, self.stage_out_channels)
x = x.view(B, Tnew, x.size(1))
elif self.modality == 'raw_audio':
B, C, T = x.size()
x = self.trunk(x)
x = x.transpose(1, 2)
lengths = [_//640 for _ in lengths]
return x if self.extract_feats else self.tcn(x, lengths, B)
def _initialize_weights_randomly(self):
use_sqrt = True
if use_sqrt:
def f(n):
return math.sqrt( 2.0/float(n) )
else:
def f(n):
return 2.0/float(n)
for m in self.modules():
if isinstance(m, nn.Conv3d) or isinstance(m, nn.Conv2d) or isinstance(m, nn.Conv1d):
n = np.prod( m.kernel_size ) * m.out_channels
m.weight.data.normal_(0, f(n))
if m.bias is not None:
m.bias.data.zero_()
elif isinstance(m, nn.BatchNorm3d) or isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d):
m.weight.data.fill_(1)
m.bias.data.zero_()
elif isinstance(m, nn.Linear):
n = float(m.weight.data[0].nelement())
m.weight.data = m.weight.data.normal_(0, f(n))