styletts2 / Utils /ASR /layers.py
mrfakename's picture
Initial Commit
635f007
raw history blame
No virus
13.9 kB
import math
import torch
from torch import nn
from typing import Optional, Any
from torch import Tensor
import torch.nn.functional as F
import torchaudio
import torchaudio.functional as audio_F
import random
random.seed(0)
def _get_activation_fn(activ):
if activ == "relu":
return nn.ReLU()
elif activ == "lrelu":
return nn.LeakyReLU(0.2)
elif activ == "swish":
return lambda x: x * torch.sigmoid(x)
else:
raise RuntimeError(
"Unexpected activ type %s, expected [relu, lrelu, swish]" % activ
)
class LinearNorm(torch.nn.Module):
    """Linear layer whose weight matrix is Xavier-uniform initialised.

    PARAMS
    ------
    in_dim: number of input features
    out_dim: number of output features
    bias: whether the wrapped ``torch.nn.Linear`` carries a bias term
    w_init_gain: nonlinearity name handed to
        ``torch.nn.init.calculate_gain`` to scale the initialisation
    """

    def __init__(self, in_dim, out_dim, bias=True, w_init_gain="linear"):
        super().__init__()
        projection = torch.nn.Linear(in_dim, out_dim, bias=bias)
        self.linear_layer = projection

        # Overwrite the default weight init; the bias (if any) is untouched.
        init_gain = torch.nn.init.calculate_gain(w_init_gain)
        torch.nn.init.xavier_uniform_(projection.weight, gain=init_gain)

    def forward(self, x):
        """Apply the linear projection to ``x``."""
        projected = self.linear_layer(x)
        return projected
class ConvNorm(torch.nn.Module):
    """1-D convolution whose weight is Xavier-uniform initialised.

    PARAMS
    ------
    in_channels: number of input channels
    out_channels: number of output channels
    kernel_size: convolution kernel width
    stride: convolution stride
    padding: padding handed to ``torch.nn.Conv1d``; when ``None`` it is
        derived as ``dilation * (kernel_size - 1) / 2``, which requires an
        odd ``kernel_size``
    dilation: convolution dilation
    bias: whether the convolution has a bias term
    w_init_gain: nonlinearity name for ``torch.nn.init.calculate_gain``
    param: optional parameter for the gain computation (e.g. a negative
        slope), forwarded to ``calculate_gain``
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=1,
        stride=1,
        padding=None,
        dilation=1,
        bias=True,
        w_init_gain="linear",
        param=None,
    ):
        super().__init__()

        if padding is None:
            # A symmetric padding can only be derived for odd kernel widths.
            assert kernel_size % 2 == 1
            padding = int(dilation * (kernel_size - 1) / 2)

        conv_options = {
            "kernel_size": kernel_size,
            "stride": stride,
            "padding": padding,
            "dilation": dilation,
            "bias": bias,
        }
        self.conv = torch.nn.Conv1d(in_channels, out_channels, **conv_options)

        init_gain = torch.nn.init.calculate_gain(w_init_gain, param=param)
        torch.nn.init.xavier_uniform_(self.conv.weight, gain=init_gain)

    def forward(self, signal):
        """Run the convolution over ``signal``."""
        return self.conv(signal)
class CausualConv(nn.Module):
    """1-D convolution that trims trailing frames so later inputs are not seen.

    The wrapped ``nn.Conv1d`` is padded on both sides by ``self.padding``
    (twice the requested ``padding``), and ``forward`` drops the last
    ``self.padding`` frames of the result.

    PARAMS
    ------
    in_channels: number of input channels
    out_channels: number of output channels
    kernel_size: convolution kernel width
    stride: convolution stride
    padding: half of the amount stored in ``self.padding``; when ``None`` it
        is derived as ``dilation * (kernel_size - 1) / 2``, which requires an
        odd ``kernel_size``
    dilation: convolution dilation
    bias: whether the convolution has a bias term
    w_init_gain: nonlinearity name for ``torch.nn.init.calculate_gain``
    param: optional parameter forwarded to ``calculate_gain``
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=1,
        stride=1,
        padding=1,
        dilation=1,
        bias=True,
        w_init_gain="linear",
        param=None,
    ):
        super(CausualConv, self).__init__()
        if padding is None:
            assert kernel_size % 2 == 1
            # Store on the instance: the original assigned a local here, so
            # self.padding was undefined and construction raised
            # AttributeError whenever padding=None was requested.
            self.padding = int(dilation * (kernel_size - 1) / 2) * 2
        else:
            self.padding = padding * 2
        self.conv = nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=self.padding,
            dilation=dilation,
            bias=bias,
        )
        torch.nn.init.xavier_uniform_(
            self.conv.weight,
            gain=torch.nn.init.calculate_gain(w_init_gain, param=param),
        )

    def forward(self, x):
        """Convolve ``x`` and drop the trailing ``self.padding`` frames."""
        x = self.conv(x)
        # Guard the zero case: ``x[:, :, :-0]`` is ``x[:, :, :0]`` and would
        # silently return an empty tensor instead of the untouched output.
        if self.padding > 0:
            x = x[:, :, : -self.padding]
        return x
class CausualBlock(nn.Module):
    """Stack of residual sub-blocks built from ``CausualConv`` layers.

    Sub-block ``i`` uses dilation ``3**i`` in its first convolution.

    PARAMS
    ------
    hidden_dim: channel count of the input and of every convolution
    n_conv: number of residual sub-blocks
    dropout_p: dropout probability used inside every sub-block
    activ: activation name understood by ``_get_activation_fn``
    """

    def __init__(self, hidden_dim, n_conv=3, dropout_p=0.2, activ="lrelu"):
        super(CausualBlock, self).__init__()
        self.blocks = nn.ModuleList(
            [
                self._get_conv(
                    hidden_dim, dilation=3**i, activ=activ, dropout_p=dropout_p
                )
                for i in range(n_conv)
            ]
        )

    def forward(self, x):
        """Pass ``x`` through every sub-block, adding a skip connection each time."""
        for block in self.blocks:
            # Out-of-place residual add. When dropout is a pass-through (eval
            # mode or p == 0) the sub-block returns the activation's own
            # output tensor; an in-place ``+=`` would overwrite a tensor that
            # autograd may have saved and make backward() fail.
            x = block(x) + x
        return x

    def _get_conv(self, hidden_dim, dilation, activ="lrelu", dropout_p=0.2):
        """Build one sub-block: dilated conv, activation, batch norm, dropout,
        then an undilated conv, activation and dropout."""
        layers = [
            CausualConv(
                hidden_dim,
                hidden_dim,
                kernel_size=3,
                padding=dilation,
                dilation=dilation,
            ),
            _get_activation_fn(activ),
            nn.BatchNorm1d(hidden_dim),
            nn.Dropout(p=dropout_p),
            CausualConv(hidden_dim, hidden_dim, kernel_size=3, padding=1, dilation=1),
            _get_activation_fn(activ),
            nn.Dropout(p=dropout_p),
        ]
        return nn.Sequential(*layers)
class ConvBlock(nn.Module):
    """Stack of residual sub-blocks built from ``ConvNorm`` layers.

    Sub-block ``i`` uses dilation ``3**i`` in its first convolution, and
    group normalisation with 8 groups.

    PARAMS
    ------
    hidden_dim: channel count of the input and of every convolution
    n_conv: number of residual sub-blocks
    dropout_p: dropout probability used inside every sub-block
    activ: activation name understood by ``_get_activation_fn``
    """

    def __init__(self, hidden_dim, n_conv=3, dropout_p=0.2, activ="relu"):
        super().__init__()
        self._n_groups = 8
        self.blocks = nn.ModuleList(
            [
                self._get_conv(
                    hidden_dim, dilation=3**i, activ=activ, dropout_p=dropout_p
                )
                for i in range(n_conv)
            ]
        )

    def forward(self, x):
        """Pass ``x`` through every sub-block, adding a skip connection each time."""
        for block in self.blocks:
            # Out-of-place residual add. When dropout is a pass-through (eval
            # mode or p == 0) the sub-block returns the ReLU's own output
            # tensor, which autograd saves for the backward pass; an in-place
            # ``+=`` would overwrite it and make backward() raise.
            x = block(x) + x
        return x

    def _get_conv(self, hidden_dim, dilation, activ="relu", dropout_p=0.2):
        """Build one sub-block: dilated conv, activation, group norm, dropout,
        then an undilated conv, activation and dropout."""
        layers = [
            ConvNorm(
                hidden_dim,
                hidden_dim,
                kernel_size=3,
                padding=dilation,
                dilation=dilation,
            ),
            _get_activation_fn(activ),
            nn.GroupNorm(num_groups=self._n_groups, num_channels=hidden_dim),
            nn.Dropout(p=dropout_p),
            ConvNorm(hidden_dim, hidden_dim, kernel_size=3, padding=1, dilation=1),
            _get_activation_fn(activ),
            nn.Dropout(p=dropout_p),
        ]
        return nn.Sequential(*layers)
class LocationLayer(nn.Module):
    """Turn stacked attention weights into location features.

    A bias-free two-input-channel convolution is followed by a bias-free
    linear projection to ``attention_dim``.

    PARAMS
    ------
    attention_n_filters: number of convolution output channels
    attention_kernel_size: convolution kernel width
    attention_dim: size of the projected location features
    """

    def __init__(self, attention_n_filters, attention_kernel_size, attention_dim):
        super().__init__()
        half_width = int((attention_kernel_size - 1) / 2)
        self.location_conv = ConvNorm(
            2,
            attention_n_filters,
            kernel_size=attention_kernel_size,
            padding=half_width,
            bias=False,
            stride=1,
            dilation=1,
        )
        self.location_dense = LinearNorm(
            attention_n_filters, attention_dim, bias=False, w_init_gain="tanh"
        )

    def forward(self, attention_weights_cat):
        """Convolve, move channels last, then project.

        PARAMS
        ------
        attention_weights_cat: stacked attention weights; the convolution
            expects 2 channels on dim 1

        RETURNS
        -------
        location features with ``attention_dim`` as the last dimension
        """
        conv_features = self.location_conv(attention_weights_cat)
        # Channels go last so the linear layer acts on the filter dimension.
        return self.location_dense(conv_features.transpose(1, 2))
class Attention(nn.Module):
    """Additive attention with a location term built from past attention weights.

    PARAMS
    ------
    attention_rnn_dim: size of the query vector
    embedding_dim: size of each memory (encoder output) vector
    attention_dim: size of the shared space in which energies are computed
    attention_location_n_filters: filter count of the location convolution
    attention_location_kernel_size: kernel width of the location convolution
    """

    def __init__(
        self,
        attention_rnn_dim,
        embedding_dim,
        attention_dim,
        attention_location_n_filters,
        attention_location_kernel_size,
    ):
        super().__init__()
        self.query_layer = LinearNorm(
            attention_rnn_dim, attention_dim, bias=False, w_init_gain="tanh"
        )
        self.memory_layer = LinearNorm(
            embedding_dim, attention_dim, bias=False, w_init_gain="tanh"
        )
        self.v = LinearNorm(attention_dim, 1, bias=False)
        self.location_layer = LocationLayer(
            attention_location_n_filters, attention_location_kernel_size, attention_dim
        )
        # Masked positions get -inf so that softmax assigns them zero weight.
        self.score_mask_value = -float("inf")

    def get_alignment_energies(self, query, processed_memory, attention_weights_cat):
        """
        PARAMS
        ------
        query: decoder output (batch, n_mel_channels * n_frames_per_step)
        processed_memory: processed encoder outputs (B, T_in, attention_dim)
        attention_weights_cat: cumulative and prev. att weights (B, 2, max_time)
        RETURNS
        -------
        alignment (batch, max_time)
        """
        # The query gets a singleton time axis so it broadcasts over memory.
        query_term = self.query_layer(query.unsqueeze(1))
        location_term = self.location_layer(attention_weights_cat)
        hidden = torch.tanh(query_term + location_term + processed_memory)
        return self.v(hidden).squeeze(-1)

    def forward(
        self,
        attention_hidden_state,
        memory,
        processed_memory,
        attention_weights_cat,
        mask,
    ):
        """
        PARAMS
        ------
        attention_hidden_state: attention rnn last output
        memory: encoder outputs
        processed_memory: processed encoder outputs
        attention_weights_cat: previous and cumulative attention weights
        mask: binary mask for padded data, or None to skip masking
        RETURNS
        -------
        (attention_context, attention_weights)
        """
        energies = self.get_alignment_energies(
            attention_hidden_state, processed_memory, attention_weights_cat
        )
        if mask is not None:
            # In-place fill on .data, i.e. not recorded by autograd.
            energies.data.masked_fill_(mask, self.score_mask_value)

        attention_weights = F.softmax(energies, dim=1)
        # Weighted sum over the time axis of memory.
        weighted = torch.bmm(attention_weights.unsqueeze(1), memory)
        return weighted.squeeze(1), attention_weights
class ForwardAttentionV2(nn.Module):
    """Location-sensitive attention with a log-domain forward recursion.

    Each step combines the previous ``log_alpha`` with a copy of itself
    shifted by one position, then adds the current alignment energies.

    PARAMS
    ------
    attention_rnn_dim: size of the query vector
    embedding_dim: size of each memory (encoder output) vector
    attention_dim: size of the shared space in which energies are computed
    attention_location_n_filters: filter count of the location convolution
    attention_location_kernel_size: kernel width of the location convolution
    """

    def __init__(
        self,
        attention_rnn_dim,
        embedding_dim,
        attention_dim,
        attention_location_n_filters,
        attention_location_kernel_size,
    ):
        super().__init__()
        self.query_layer = LinearNorm(
            attention_rnn_dim, attention_dim, bias=False, w_init_gain="tanh"
        )
        self.memory_layer = LinearNorm(
            embedding_dim, attention_dim, bias=False, w_init_gain="tanh"
        )
        self.v = LinearNorm(attention_dim, 1, bias=False)
        self.location_layer = LocationLayer(
            attention_location_n_filters, attention_location_kernel_size, attention_dim
        )
        # A large finite negative value rather than -inf is used for masking
        # and for padding the shifted log_alpha.
        self.score_mask_value = -float(1e20)

    def get_alignment_energies(self, query, processed_memory, attention_weights_cat):
        """
        PARAMS
        ------
        query: decoder output (batch, n_mel_channels * n_frames_per_step)
        processed_memory: processed encoder outputs (B, T_in, attention_dim)
        attention_weights_cat: prev. and cumulative att weights (B, 2, max_time)
        RETURNS
        -------
        alignment (batch, max_time)
        """
        # The query gets a singleton time axis so it broadcasts over memory.
        query_term = self.query_layer(query.unsqueeze(1))
        location_term = self.location_layer(attention_weights_cat)
        hidden = torch.tanh(query_term + location_term + processed_memory)
        return self.v(hidden).squeeze(-1)

    def forward(
        self,
        attention_hidden_state,
        memory,
        processed_memory,
        attention_weights_cat,
        mask,
        log_alpha,
    ):
        """
        PARAMS
        ------
        attention_hidden_state: attention rnn last output
        memory: encoder outputs
        processed_memory: processed encoder outputs
        attention_weights_cat: previous and cumulative attention weights
        mask: binary mask for padded data, or None to skip masking
        log_alpha: log-domain forward variable from the previous step,
            indexed as (batch, time)
        RETURNS
        -------
        (attention_context, attention_weights, log_alpha_new)
        """
        log_energy = self.get_alignment_energies(
            attention_hidden_state, processed_memory, attention_weights_cat
        )
        if mask is not None:
            # In-place fill on .data, i.e. not recorded by autograd.
            log_energy.data.masked_fill_(mask, self.score_mask_value)

        n_steps = log_energy.size(1)
        # Two candidates per position: stay (lag 0) or advance by one (lag 1).
        # The shifted copy is left-padded with the mask value.
        candidates = [
            F.pad(
                log_alpha[:, : n_steps - lag],
                (lag, 0),
                "constant",
                self.score_mask_value,
            ).unsqueeze(2)
            for lag in range(2)
        ]
        transition = torch.logsumexp(torch.cat(candidates, 2), 2)
        log_alpha_new = transition + log_energy

        attention_weights = F.softmax(log_alpha_new, dim=1)
        weighted = torch.bmm(attention_weights.unsqueeze(1), memory)
        return weighted.squeeze(1), attention_weights, log_alpha_new
class PhaseShuffle2d(nn.Module):
    """Circularly shift a 4-D tensor along its last axis by a small offset.

    PARAMS
    ------
    n: maximum absolute offset; random offsets are drawn from [-n, n]
    """

    def __init__(self, n=2):
        super().__init__()
        self.n = n
        # Private generator with a fixed seed, independent of the global one.
        self.random = random.Random(1)

    def forward(self, x, move=None):
        """Rotate ``x`` along dim 3.

        PARAMS
        ------
        x: input tensor, x.size = (B, C, M, L)
        move: offset to apply; drawn at random from [-n, n] when None

        RETURNS
        -------
        ``x`` itself when the offset is 0, otherwise a new tensor in which
        the slice ``[move:]`` comes first and the slice ``[:move]`` follows
        """
        offset = move if move is not None else self.random.randint(-self.n, self.n)
        if offset == 0:
            return x
        leading = x[:, :, :, offset:]
        trailing = x[:, :, :, :offset]
        return torch.cat([leading, trailing], dim=3)
class PhaseShuffle1d(nn.Module):
    """Circularly shift a tensor along dim 2 by a small offset.

    PARAMS
    ------
    n: maximum absolute offset; random offsets are drawn from [-n, n]
    """

    def __init__(self, n=2):
        super().__init__()
        self.n = n
        # Private generator with a fixed seed, independent of the global one.
        self.random = random.Random(1)

    def forward(self, x, move=None):
        """Rotate ``x`` along dim 2.

        PARAMS
        ------
        x: input tensor; NOTE(review): presumably (B, C, L), since the
            shift is applied along dim 2 - confirm against callers
        move: offset to apply; drawn at random from [-n, n] when None

        RETURNS
        -------
        ``x`` itself when the offset is 0, otherwise a new tensor in which
        the slice ``[move:]`` comes first and the slice ``[:move]`` follows
        """
        offset = move if move is not None else self.random.randint(-self.n, self.n)
        if offset == 0:
            return x
        leading = x[:, :, offset:]
        trailing = x[:, :, :offset]
        return torch.cat([leading, trailing], dim=2)
class MFCC(nn.Module):
    """Project a mel spectrogram onto a DCT basis.

    The basis comes from ``torchaudio.functional.create_dct`` with "ortho"
    normalisation and is registered as the buffer ``dct_mat``.

    PARAMS
    ------
    n_mfcc: number of coefficients to produce
    n_mels: number of mel bins expected in the input
    """

    def __init__(self, n_mfcc=40, n_mels=80):
        super().__init__()
        self.n_mfcc = n_mfcc
        self.n_mels = n_mels
        self.norm = "ortho"
        # A buffer follows the module across devices but is not a parameter.
        self.register_buffer(
            "dct_mat", audio_F.create_dct(self.n_mfcc, self.n_mels, self.norm)
        )

    def forward(self, mel_specgram):
        """Apply the DCT along the mel axis.

        PARAMS
        ------
        mel_specgram: (n_mels, time) or (channel, n_mels, time)

        RETURNS
        -------
        (n_mfcc, time) for a 2-D input, otherwise (channel, n_mfcc, time)
        """
        # A 2-D input gets a temporary leading axis, removed again at the end.
        is_unbatched = mel_specgram.dim() == 2
        if is_unbatched:
            mel_specgram = mel_specgram.unsqueeze(0)

        # (channel, time, n_mels) @ (n_mels, n_mfcc) -> (channel, time, n_mfcc)
        time_major = mel_specgram.transpose(1, 2)
        mfcc = torch.matmul(time_major, self.dct_mat).transpose(1, 2)

        return mfcc.squeeze(0) if is_unbatched else mfcc