Spaces:
Sleeping
Sleeping
import torch | |
import torch.nn as nn | |
import math | |
import numpy as np | |
from lipreading.models.resnet import ResNet, BasicBlock | |
from lipreading.models.resnet1D import ResNet1D, BasicBlock1D | |
from lipreading.models.shufflenetv2 import ShuffleNetV2 | |
from lipreading.models.tcn import MultibranchTemporalConvNet, TemporalConvNet | |
# -- auxiliary functions | |
def threeD_to_2D_tensor(x):
    """Fold the time axis of a 5-D tensor into the batch axis.

    The input is unpacked as (B, C, T, H, W) and returned as (B*T, C, H, W).
    Row ``b * T + t`` of the result holds frame ``t`` of sample ``b``, so a 2D
    network can process every frame independently.
    """
    batch, channels, frames, height, width = x.shape
    frames_before_channels = x.transpose(1, 2)  # (B, T, C, H, W)
    return frames_before_channels.reshape(batch * frames, channels, height, width)
def _average_batch(x, lengths, B): | |
return torch.stack( [torch.mean( x[index][:,0:i], 1 ) for index, i in enumerate(lengths)],0 ) | |
class MultiscaleMultibranchTCN(nn.Module):
    """Multi-branch TCN back-end followed by temporal average pooling and a linear classifier.

    Wraps ``MultibranchTemporalConvNet`` (configured through ``tcn_options``,
    whose ``'kernel_size'`` list is also stored on the instance), averages its
    output over each sample's valid time steps with ``_average_batch`` and maps
    the pooled vector to ``num_classes`` outputs.
    """

    def __init__(self, input_size, num_channels, num_classes, tcn_options, dropout, relu_type, dwpw=False):
        """Build the multi-branch trunk and the classifier.

        Args:
            input_size: channel size of the incoming feature sequence.
            num_channels: per-layer channel widths; the last one feeds the classifier.
            num_classes: number of classifier outputs.
            tcn_options: dict that must contain ``'kernel_size'`` (a sequence).
            dropout, relu_type, dwpw: forwarded unchanged to ``MultibranchTemporalConvNet``.
        """
        super(MultiscaleMultibranchTCN, self).__init__()
        self.kernel_sizes = tcn_options['kernel_size']
        self.num_kernels = len(self.kernel_sizes)
        self.mb_ms_tcn = MultibranchTemporalConvNet(input_size, num_channels, tcn_options, dropout=dropout, relu_type=relu_type, dwpw=dwpw)
        self.tcn_output = nn.Linear(num_channels[-1], num_classes)
        self.consensus_func = _average_batch
        # Same flag as the single-branch ``TCN`` back-end, so code that queries
        # ``has_aux_losses`` works whichever back-end was selected.
        self.has_aux_losses = False

    def forward(self, x, lengths, B):
        """Run the trunk, pool over valid time steps and classify.

        Args:
            x: feature sequence; dims 1 and 2 are swapped before the trunk so it
                receives (N, C, L) — x is therefore presumably (N, L, C).
            lengths: per-sample number of valid time steps used for pooling.
            B: batch size, passed through to the consensus function.

        Returns:
            Classifier output, one row per entry of ``lengths``.
        """
        # x needs to have dimension (N, C, L) in order to be passed into CNN
        xtrans = x.transpose(1, 2)
        out = self.mb_ms_tcn(xtrans)
        out = self.consensus_func(out, lengths, B)
        return self.tcn_output(out)
class TCN(nn.Module):
    """Single-branch Temporal Convolutional Network (TCN) back-end.

    Runs a ``TemporalConvNet`` over the sequence, averages its output over each
    sample's valid time steps (``_average_batch``) and classifies the pooled
    vector with a linear layer.

    Reference: Bai et al., "An Empirical Evaluation of Generic Convolutional and
    Recurrent Networks for Sequence Modeling", https://arxiv.org/pdf/1803.01271.pdf
    """

    def __init__(self, input_size, num_channels, num_classes, tcn_options, dropout, relu_type, dwpw=False):
        super().__init__()
        self.tcn_trunk = TemporalConvNet(
            input_size,
            num_channels,
            dropout=dropout,
            tcn_options=tcn_options,
            relu_type=relu_type,
            dwpw=dwpw,
        )
        # The classifier consumes the channel width of the last TCN layer.
        self.tcn_output = nn.Linear(num_channels[-1], num_classes)
        self.consensus_func = _average_batch
        # Flag exposed to callers; always False for this back-end.
        self.has_aux_losses = False

    def forward(self, x, lengths, B):
        """Encode ``x``, average over each sample's valid steps, then classify.

        Dims 1 and 2 of ``x`` are swapped so the convolutional trunk receives
        (N, C, L); ``lengths`` and ``B`` are handed to the consensus function.
        """
        encoded = self.tcn_trunk(x.transpose(1, 2))
        return self.tcn_output(self.consensus_func(encoded, lengths, B))
class Lipreading(nn.Module):
    """Lipreading model: a modality-specific trunk followed by a TCN back-end.

    For ``'video'`` a 3D convolutional front-end runs first, then the 2D trunk
    (ResNet or ShuffleNetV2) is applied to every frame independently. For
    ``'raw_audio'`` a 1D ResNet trunk runs on the input directly. The resulting
    batch-first feature sequence is either returned (``extract_feats=True``) or
    fed to a ``TCN`` / ``MultiscaleMultibranchTCN`` back-end.
    """

    def __init__(self, modality='video', hidden_dim=256, backbone_type='resnet', num_classes=30,
                 relu_type='prelu', tcn_options=None, width_mult=1.0, extract_feats=False):
        """Build the trunk for ``modality`` and the TCN back-end.

        Args:
            modality: ``'video'`` or ``'raw_audio'``.
            hidden_dim: base TCN width; each layer gets
                ``hidden_dim * len(tcn_options['kernel_size']) * tcn_options['width_mult']`` channels.
            backbone_type: video trunk, ``'resnet'`` or ``'shufflenet'`` (ignored for raw audio).
            num_classes: number of back-end outputs.
            relu_type: forwarded to the ResNet trunks and the TCN; ``'prelu'`` also
                selects ``nn.PReLU`` (otherwise ``nn.ReLU``) in the video front-end.
            tcn_options: dict providing ``'kernel_size'`` (sequence), ``'width_mult'``,
                ``'num_layers'``, ``'dropout'`` and ``'dwpw'``. ``None`` is treated as
                an empty dict, so a missing key raises ``KeyError``.
            width_mult: ShuffleNetV2 width multiplier, one of 0.5, 1.0, 1.5, 2.0.
            extract_feats: if True, ``forward`` returns trunk features instead of
                back-end outputs.

        Raises:
            NotImplementedError: unsupported ``modality`` or video ``backbone_type``.
            KeyError: a required ``tcn_options`` key is missing.
        """
        super(Lipreading, self).__init__()
        # ``None`` sentinel instead of a shared mutable ``{}`` default.
        if tcn_options is None:
            tcn_options = {}
        self.extract_feats = extract_feats
        self.backbone_type = backbone_type
        self.modality = modality

        if self.modality == 'raw_audio':
            self.frontend_nout = 1
            self.backend_out = 512
            self.trunk = ResNet1D(BasicBlock1D, [2, 2, 2, 2], relu_type=relu_type)
        elif self.modality == 'video':
            if self.backbone_type == 'resnet':
                self.frontend_nout = 64
                self.backend_out = 512
                self.trunk = ResNet(BasicBlock, [2, 2, 2, 2], relu_type=relu_type)
            elif self.backbone_type == 'shufflenet':
                assert width_mult in [0.5, 1.0, 1.5, 2.0], "Width multiplier not correct"
                shufflenet = ShuffleNetV2(input_size=96, width_mult=width_mult)
                self.trunk = nn.Sequential(shufflenet.features, shufflenet.conv_last, shufflenet.globalpool)
                self.frontend_nout = 24
                self.backend_out = 1024 if width_mult != 2.0 else 2048
                self.stage_out_channels = shufflenet.stage_out_channels[-1]
            else:
                # Without this, frontend_nout/backend_out are never set and the
                # failure surfaces below as a misleading AttributeError.
                raise NotImplementedError(
                    "Unsupported backbone_type for video: {!r}".format(self.backbone_type))

            frontend_relu = nn.PReLU(num_parameters=self.frontend_nout) if relu_type == 'prelu' else nn.ReLU()
            # Temporal kernel 5 / padding 2 / stride 1 (and a temporal pool of 1)
            # keep the number of frames unchanged; only H and W are downsampled.
            self.frontend3D = nn.Sequential(
                nn.Conv3d(1, self.frontend_nout, kernel_size=(5, 7, 7), stride=(1, 2, 2), padding=(2, 3, 3), bias=False),
                nn.BatchNorm3d(self.frontend_nout),
                frontend_relu,
                nn.MaxPool3d(kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1)))
        else:
            raise NotImplementedError("Unsupported modality: {!r}".format(self.modality))

        kernel_sizes = tcn_options['kernel_size']
        # One kernel size -> plain TCN; several -> multi-branch TCN.
        tcn_class = TCN if len(kernel_sizes) == 1 else MultiscaleMultibranchTCN
        self.tcn = tcn_class(
            input_size=self.backend_out,
            num_channels=[hidden_dim * len(kernel_sizes) * tcn_options['width_mult']] * tcn_options['num_layers'],
            num_classes=num_classes,
            tcn_options=tcn_options,
            dropout=tcn_options['dropout'],
            relu_type=relu_type,
            dwpw=tcn_options['dwpw'],
        )
        # -- initialize
        self._initialize_weights_randomly()

    def forward(self, x, lengths):
        """Compute trunk features or back-end outputs for a batch.

        Args:
            x: for ``'video'`` a 5-D tensor unpacked as (B, C, T, H, W), where C must
                be 1 to match the front-end ``Conv3d(1, ...)``; for ``'raw_audio'``
                a 3-D tensor unpacked as (B, C, T).
            lengths: per-sample valid lengths used by the back-end's pooling. For
                raw audio they are integer-divided by 640 first.

        Returns:
            The batch-first trunk feature sequence if ``extract_feats`` is set,
            otherwise the TCN back-end output.
        """
        if self.modality == 'video':
            B, _, _, _, _ = x.size()  # the unpack also enforces a 5-D input
            x = self.frontend3D(x)
            Tnew = x.shape[2]  # output should be B x C2 x Tnew x H x W
            x = threeD_to_2D_tensor(x)  # fold time into batch: (B*Tnew, C2, H, W)
            x = self.trunk(x)
            if self.backbone_type == 'shufflenet':
                x = x.view(-1, self.stage_out_channels)
            x = x.view(B, Tnew, x.size(1))  # unfold back to (B, Tnew, features)
        elif self.modality == 'raw_audio':
            B, _, _ = x.size()
            x = self.trunk(x)
            x = x.transpose(1, 2)
            # NOTE(review): 640 is presumably the ResNet1D trunk's overall temporal
            # downsampling (input samples per output step) — confirm.
            lengths = [length // 640 for length in lengths]
        return x if self.extract_feats else self.tcn(x, lengths, B)

    def _initialize_weights_randomly(self):
        """Re-initialise conv, batch-norm and linear layers of the model in place.

        * Conv1d/2d/3d: weight ~ N(0, sqrt(2 / (prod(kernel_size) * out_channels))),
          bias (if any) set to 0.
        * BatchNorm1d/2d/3d: weight set to 1 and bias to 0 when present
          (``affine=False`` layers have neither and are skipped).
        * Linear: weight ~ N(0, sqrt(2 / in_features)); the bias is left untouched.
        """
        def std_for(fan):
            return math.sqrt(2.0 / float(fan))

        conv_types = (nn.Conv3d, nn.Conv2d, nn.Conv1d)
        bn_types = (nn.BatchNorm3d, nn.BatchNorm2d, nn.BatchNorm1d)
        for m in self.modules():
            if isinstance(m, conv_types):
                n = np.prod(m.kernel_size) * m.out_channels
                m.weight.data.normal_(0, std_for(n))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, bn_types):
                # affine=False batch norms register weight/bias as None.
                if m.weight is not None:
                    m.weight.data.fill_(1)
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                # Row 0 of the (out, in) weight has in_features elements.
                n = float(m.weight.data[0].nelement())
                m.weight.data.normal_(0, std_for(n))