import torch
from torch import nn

from hw_asr.base import BaseModel


class RNNwBatchNorm(nn.Module):
    """A single bidirectional GRU layer with BatchNorm over the summed directions."""

    def __init__(self, input_size, hidden_size, rnn_dropout):
        super().__init__()
        # Note: with a single-layer GRU the dropout argument is a no-op, since
        # PyTorch only applies dropout between stacked recurrent layers.
        self.rnn = nn.GRU(input_size, hidden_size, dropout=rnn_dropout,
                          batch_first=False, bidirectional=True)
        self.norm = nn.BatchNorm1d(hidden_size)

    def forward(self, x, h=None):
        # T x N x input_size (batch_first=False)
        x, h = self.rnn(x, h)
        # T x N x (2 * hidden_size)
        # Sum the forward and backward directions instead of concatenating,
        # keeping the feature width at hidden_size.
        x = x.view(x.shape[0], x.shape[1], 2, -1).sum(2)
        # T x N x hidden_size
        # BatchNorm1d expects (batch, features), so fold time into the batch axis.
        t_dim, n_dim = x.shape[0], x.shape[1]
        x = x.view((t_dim * n_dim, -1))
        x = self.norm(x)
        x = x.view((t_dim, n_dim, -1)).contiguous()
        return x, h
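
# Shape walkthrough for RNNwBatchNorm (illustrative sketch only; the sizes
# below are made up for demonstration and are not a training config):
#
#   layer = RNNwBatchNorm(input_size=96, hidden_size=512, rnn_dropout=0.0)
#   x = torch.randn(50, 8, 96)    # T x N x input_size
#   out, h = layer(x)
#   out.shape == (50, 8, 512)     # directions summed, not concatenated
#   h.shape == (2, 8, 512)        # (num_directions, N, hidden_size)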


# Deep Speech 2 (Amodei et al., 2016): https://proceedings.mlr.press/v48/amodei16.pdf
class DeepSpeech2Model(BaseModel):
    def __init__(self, n_feats, n_rnn_layers, rnn_hidden_size, rnn_dropout, n_class):
        assert n_rnn_layers >= 1
        super().__init__(n_feats, n_class)
        # Three 2D convolutions over (frequency, time); the first two halve
        # the time axis, the third keeps it (time stride 1).
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=32, padding=(20, 5), kernel_size=(41, 11), stride=(2, 2)),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(in_channels=32, out_channels=32, padding=(10, 5), kernel_size=(21, 11), stride=(2, 2)),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(in_channels=32, out_channels=96, padding=(10, 5), kernel_size=(21, 11), stride=(2, 1)),
            nn.BatchNorm2d(96),
            nn.ReLU(),
        )
        # Frequency size after each conv: F_out = (F_in + 2 * pad - kernel) // stride + 1.
        # E.g. for n_feats = 128: 128 -> 64 -> 32 -> 16, so rnn_input_size = 96 * 16 = 1536.
        rnn_input_size = (n_feats + 2 * 20 - 41) // 2 + 1
        rnn_input_size = (rnn_input_size + 2 * 10 - 21) // 2 + 1
        rnn_input_size = (rnn_input_size + 2 * 10 - 21) // 2 + 1
        rnn_input_size *= 96  # flatten channels into the feature dimension
        self.rnns = nn.Sequential(
            RNNwBatchNorm(rnn_input_size, rnn_hidden_size, rnn_dropout),
            *(RNNwBatchNorm(rnn_hidden_size, rnn_hidden_size, rnn_dropout) for _ in range(n_rnn_layers - 1))
        )
        self.fc = nn.Linear(in_features=rnn_hidden_size, out_features=n_class)
        # Unused in forward: CTC training consumes raw logits; kept for
        # optional probability output at inference time.
        self.softmax = nn.Softmax(dim=2)

    def forward(self, spectrogram, spectrogram_length, **batch):
        # spectrogram: N x F x T; add a channel axis for Conv2d.
        x = self.conv(spectrogram.unsqueeze(1))
        # N x C x F x T
        x = x.view(x.shape[0], x.shape[1] * x.shape[2], x.shape[3])
        # N x (C * F) x T
        x = x.transpose(1, 2).transpose(0, 1).contiguous()
        # T x N x (C * F)
        # Each layer's final hidden state seeds the next layer's initial state.
        h = None
        for rnn in self.rnns:
            x, h = rnn(x, h)
        # T x N x rnn_hidden_size
        t_dim, n_dim = x.shape[0], x.shape[1]
        x = x.view((t_dim * n_dim, -1))
        x = self.fc(x)
        x = x.view((t_dim, n_dim, -1)).transpose(0, 1)
        # N x T x n_class
        return {"logits": x}

    def transform_input_lengths(self, input_lengths):
        # Time size after each conv: T_out = (T_in + 2 * pad - kernel) // stride + 1,
        # with time strides 2, 2, 1. All returned lengths are set to the transformed
        # maximum, matching the time dimension of the padded batch after the convs.
        t_dim = input_lengths.max()
        t_dim = (t_dim + 2 * 5 - 11) // 2 + 1
        t_dim = (t_dim + 2 * 5 - 11) // 2 + 1
        t_dim = (t_dim + 2 * 5 - 11) + 1
        return torch.zeros_like(input_lengths).fill_(t_dim)
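

if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of training): the hyperparameters
    # below are made up for illustration, and n_class=28 assumes a typical
    # English CTC alphabet (26 letters + space + blank). Assumes BaseModel from
    # hw_asr.base is an nn.Module subclass, as its usage above implies.
    model = DeepSpeech2Model(n_feats=128, n_rnn_layers=3, rnn_hidden_size=512,
                             rnn_dropout=0.0, n_class=28)
    spec = torch.randn(4, 128, 200)                    # N x F x T
    lengths = torch.full((4,), 200, dtype=torch.long)
    out = model(spectrogram=spec, spectrogram_length=lengths)
    print(out["logits"].shape)                         # N x T' x n_class, here (4, 50, 28)
    print(model.transform_input_lengths(lengths))      # all entries equal T' = 50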