| | import math |
| | import time |
| | from collections import OrderedDict |
| | from typing import Dict, List, Optional, Tuple |
| | import numpy as np |
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | from espnet2.torch_utils.get_layer_from_string import get_layer |
| | from torch.nn import init |
| | from torch.nn.parameter import Parameter |
| | import src.utils as utils |
| |
|
| |
|
class Lambda(nn.Module):
    """Wrap an arbitrary callable as an ``nn.Module`` so it can sit inside
    ``nn.Sequential`` pipelines.

    Args:
        lambd: any callable applied to the forward input (originally
            restricted to lambdas; generalized to any callable, which is
            backward-compatible).

    Raises:
        TypeError: if ``lambd`` is not callable.
    """

    def __init__(self, lambd):
        super().__init__()
        # Explicit check instead of `assert type(lambd) is types.LambdaType`:
        # asserts are stripped under -O, and plain functions / partials are
        # just as valid here as lambdas.
        if not callable(lambd):
            raise TypeError(f"Lambda expects a callable, got {type(lambd).__name__}")
        self.lambd = lambd

    def forward(self, x):
        return self.lambd(x)
| |
|
| |
|
class LayerNormPermuted(nn.LayerNorm):
    """LayerNorm over the channel dim of channel-first [B, C, T, F] tensors.

    ``nn.LayerNorm`` normalizes the trailing dim, so the channel axis is
    moved last, normalized, and moved back.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, x):
        """
        Args:
            x: [B, C, T, F]
        Returns:
            normalized tensor, same [B, C, T, F] layout.
        """
        channels_last = x.movedim(1, -1)        # [B, T, F, C]
        normed = super().forward(channels_last)
        return normed.movedim(-1, 1)            # back to [B, C, T, F]
| |
|
| |
|
| | |
class LayerNormalization4D(nn.Module):
    """Thin wrapper around ``nn.LayerNorm`` over the trailing channel dim.

    Args:
        C: size of the last (channel) dimension to normalize.
        eps: numerical-stability epsilon forwarded to ``nn.LayerNorm``.
        preserve_outdim: stored flag; not consulted inside this module.
    """

    def __init__(self, C, eps=1e-5, preserve_outdim=False):
        super().__init__()
        self.norm = nn.LayerNorm(C, eps=eps)
        self.preserve_outdim = preserve_outdim

    def forward(self, x: torch.Tensor):
        """Normalize input of shape (*, C) over its last dimension."""
        return self.norm(x)
| |
|
| |
|
class LayerNormalization4DCF(nn.Module):
    """LayerNorm over a flattened (frequency, channel) feature dimension.

    Args:
        input_dimension: pair (Q, C); normalization spans the Q*C trailing dim.
        eps: numerical-stability epsilon forwarded to ``nn.LayerNorm``.
    """

    def __init__(self, input_dimension, eps=1e-5):
        assert len(input_dimension) == 2
        n_freq, n_chan = input_dimension
        super().__init__()
        self.norm = nn.LayerNorm(n_freq * n_chan, eps=eps)

    def forward(self, x: torch.Tensor):
        """Normalize input of shape (B, T, Q * C) over its last dimension."""
        return self.norm(x)
| |
|
| |
|
class LayerNormalization4D_old(nn.Module):
    """Channel-wise layer normalization for [B, C, T, F] tensors with
    learnable per-channel affine parameters (legacy hand-rolled variant).

    Args:
        input_dimension: number of channels C.
        eps: numerical-stability constant added to the variance.
    """

    def __init__(self, input_dimension, eps=1e-5):
        super().__init__()
        param_size = [1, input_dimension, 1, 1]
        # Construct parameters with their initial values directly instead of
        # the legacy uninitialized `torch.Tensor(*size)` constructor followed
        # by init.ones_/zeros_ — identical result, modern idiom.
        self.gamma = Parameter(torch.ones(param_size, dtype=torch.float32))
        self.beta = Parameter(torch.zeros(param_size, dtype=torch.float32))
        self.eps = eps

    def forward(self, x):
        """Normalize over the channel dim (dim 1) of a 4D input.

        Args:
            x: [B, C, T, F]
        Returns:
            Normalized tensor of the same shape.
        Raises:
            ValueError: if ``x`` is not 4-dimensional.
        """
        if x.ndim == 4:
            _, C, _, _ = x.shape
            stat_dim = (1,)
        else:
            raise ValueError("Expect x to have 4 dimensions, but got {}".format(x.ndim))
        mu_ = x.mean(dim=stat_dim, keepdim=True)
        # Biased (population) variance, matching nn.LayerNorm's convention.
        std_ = torch.sqrt(x.var(dim=stat_dim, unbiased=False, keepdim=True) + self.eps)
        x_hat = ((x - mu_) / std_) * self.gamma + self.beta
        return x_hat
| |
|
| |
|
def mod_pad(x, chunk_size, pad):
    """Right-pad the last dim of ``x`` to a multiple of ``chunk_size``, then
    apply the extra ``pad`` spec.

    Args:
        x: input tensor.
        chunk_size: alignment target for the last dimension.
        pad: additional padding tuple forwarded to ``F.pad``.
    Returns:
        (padded_tensor, mod) where ``mod`` is the number of alignment samples
        appended (0 if already aligned).
    """
    remainder = x.shape[-1] % chunk_size
    mod = (chunk_size - remainder) if remainder else 0

    x = F.pad(x, (0, mod))
    x = F.pad(x, pad)

    return x, mod
| |
|
| |
|
class Attention_STFT_causal(nn.Module):
    """Causal multi-head self-attention over time frames of STFT features.

    Input/output layout: [B, T, Q, C] with Q = n_freqs frequency bins and
    C = emb_dim channels. Frequency and channel dims are flattened into the
    attention feature dimension; a lower-triangular mask prevents attending
    to future frames.
    """

    def __getitem__(self, key):
        # Allow self["attn_conv_Q"]-style access to registered submodules.
        return getattr(self, key)

    def __init__(
        self,
        emb_dim,
        n_freqs,
        approx_qk_dim=512,
        n_head=4,
        activation="prelu",
        eps=1e-5,
        skip_conn=True,
        use_flash_attention=False,
        dim_feedforward=-1,
    ):
        """
        Args:
            emb_dim: channel dimension C of the input (must be divisible by
                n_head).
            n_freqs: number of frequency bins Q.
            approx_qk_dim: target total query/key width; per-frequency width
                E is ceil(approx_qk_dim / n_freqs).
            n_head: number of attention heads.
            activation: activation layer name resolved via ``get_layer``.
            eps: epsilon for the layer norms.
            skip_conn: add a residual connection around the whole module.
            use_flash_attention: accepted for interface compatibility; not
                used by this implementation.
            dim_feedforward: -1 selects the linear concat-projection output
                stage; otherwise the hidden width of a transformer-style
                feed-forward block.
        """
        super().__init__()
        self.position_code = utils.PositionalEncoding(emb_dim * n_freqs, max_len=5000)

        self.skip_conn = skip_conn
        self.n_freqs = n_freqs
        self.E = math.ceil(approx_qk_dim * 1.0 / n_freqs)
        self.n_head = n_head
        self.V_dim = emb_dim // n_head
        self.emb_dim = emb_dim
        assert emb_dim % n_head == 0
        E = self.E

        # Q/K/V projections: Linear over C, activation, then fold heads into
        # the batch dim and flatten (Q, width) into the feature dim, yielding
        # [B * n_head, T, Q * width], followed by a LayerNorm over that dim.
        self.add_module(
            "attn_conv_Q",
            nn.Sequential(
                nn.Linear(emb_dim, E * n_head),
                get_layer(activation)(),
                Lambda(
                    lambda x: x.reshape(x.shape[0], x.shape[1], x.shape[2], n_head, E)
                    .permute(0, 3, 1, 2, 4)
                    .reshape(x.shape[0] * n_head, x.shape[1], x.shape[2] * E)
                ),
                LayerNormalization4DCF((n_freqs, E), eps=eps),
            ),
        )
        self.add_module(
            "attn_conv_K",
            nn.Sequential(
                nn.Linear(emb_dim, E * n_head),
                get_layer(activation)(),
                Lambda(
                    lambda x: x.reshape(x.shape[0], x.shape[1], x.shape[2], n_head, E)
                    .permute(0, 3, 1, 2, 4)
                    .reshape(x.shape[0] * n_head, x.shape[1], x.shape[2] * E)
                ),
                LayerNormalization4DCF((n_freqs, E), eps=eps),
            ),
        )
        self.add_module(
            "attn_conv_V",
            nn.Sequential(
                nn.Linear(emb_dim, (emb_dim // n_head) * n_head),
                get_layer(activation)(),
                Lambda(
                    lambda x: x.reshape(x.shape[0], x.shape[1], x.shape[2], n_head, (emb_dim // n_head))
                    .permute(0, 3, 1, 2, 4)
                    .reshape(x.shape[0] * n_head, x.shape[1], x.shape[2] * (emb_dim // n_head))
                ),
                LayerNormalization4DCF((n_freqs, emb_dim // n_head), eps=eps),
            ),
        )

        self.dim_feedforward = dim_feedforward

        if dim_feedforward == -1:
            # Output stage A: linear projection + norm over flattened (Q, C).
            self.add_module(
                "attn_concat_proj",
                nn.Sequential(
                    nn.Linear(emb_dim, emb_dim),
                    get_layer(activation)(),
                    Lambda(lambda x: x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3])),
                    LayerNormalization4DCF((n_freqs, emb_dim), eps=eps),
                ),
            )
        else:
            # Output stage B: transformer-style feed-forward block + norm.
            self.linear1 = nn.Linear(emb_dim, dim_feedforward)
            self.dropout = nn.Dropout(p=0.1)
            self.activation = nn.ReLU()
            self.linear2 = nn.Linear(dim_feedforward, emb_dim)
            self.dropout2 = nn.Dropout(p=0.1)
            self.norm = LayerNormalization4DCF((n_freqs, emb_dim), eps=eps)

    def _ff_block(self, x):
        """Position-wise feed-forward sub-block (linear-ReLU-drop-linear-drop)."""
        x = self.linear2(self.dropout(self.activation(self.linear1(x))))
        return self.dropout2(x)

    def get_lookahead_mask(self, seq_len, device):
        """Create a boolean causal mask that hides future frames.

        Arguments
        ---------
        seq_len: int
            Length of the sequence.
        device: torch.device
            The device on which to create the mask.

        Returns a lower-triangular (seq_len, seq_len) boolean tensor where
        True marks positions a frame may attend to (itself and the past) and
        False marks future frames; callers fill the False positions with -inf.
        """
        mask = (torch.triu(torch.ones((seq_len, seq_len), device=device)) == 1).transpose(0, 1)

        return mask.detach().to(device)

    def forward(self, batch):
        """
        Args:
            batch: [B, T, Q, C]
        Returns:
            [B, T, Q, C] (residual added when ``skip_conn`` is True).
        """
        inputs = batch
        B0, T0, Q0, C0 = batch.shape

        # Additive positional encoding; the encoder works over the flattened
        # (Q*C) feature dim, so reshape its output back to [1, T, Q, C].
        pos_code = self.position_code(batch)
        _, T, QC = pos_code.shape
        pos_code = pos_code.reshape(1, T, Q0, C0)
        batch = batch + pos_code

        Q = self["attn_conv_Q"](batch)  # [B*h, T, Q0*E]
        K = self["attn_conv_K"](batch)  # [B*h, T, Q0*E]
        V = self["attn_conv_V"](batch)  # [B*h, T, Q0*V_dim]

        emb_dim = Q.shape[-1]

        local_mask = self.get_lookahead_mask(batch.shape[1], batch.device)

        # Scaled dot-product attention with future frames masked to -inf.
        attn_mat = torch.matmul(Q, K.transpose(1, 2)) / (emb_dim**0.5)
        attn_mat.masked_fill_(local_mask == 0, -float("Inf"))
        attn_mat = F.softmax(attn_mat, dim=2)

        V = torch.matmul(attn_mat, V)
        V = V.reshape(-1, T0, V.shape[-1])
        V = V.transpose(1, 2)

        # Unfold heads back out of the batch dim: [B, T, Q, C].
        batch = V.reshape(B0, self.n_head, self.n_freqs, self.V_dim, T0)
        batch = batch.transpose(2, 3)
        batch = batch.reshape(B0, self.n_head * self.V_dim, self.n_freqs, T0)
        batch = batch.permute(0, 3, 2, 1)

        if self.dim_feedforward == -1:
            batch = self["attn_concat_proj"](batch)
            # BUGFIX: attn_concat_proj flattens its output to [B, T, Q*C];
            # restore the 4D layout so the residual below (and callers that
            # permute a 4D result) receive [B, T, Q, C]. The else-branch
            # already performed this reshape; this branch was missing it.
            batch = batch.reshape(batch.shape[0], batch.shape[1], Q0, C0)
        else:
            batch = batch + self._ff_block(batch)
            batch = batch.reshape(batch.shape[0], batch.shape[1], batch.shape[2] * batch.shape[3])
            batch = self.norm(batch)
            batch = batch.reshape(batch.shape[0], batch.shape[1], Q0, C0)

        if self.skip_conn:
            return batch + inputs
        else:
            return batch
| |
|
| |
|
class GridNetBlock(nn.Module):
    """One TF-GridNet-style block: an intra-frame (frequency) bidirectional
    LSTM, a causal inter-frame (time) LSTM, and an optional causal
    self-attention stage, each with a residual connection.

    Input/output layout is [B, C, T, Q] (batch, embedding channels, time
    frames, frequency bins).
    """

    def __getitem__(self, key):
        # Allow dict-style access to registered submodules/attributes.
        return getattr(self, key)

    def __init__(
        self,
        emb_dim,
        emb_ks,
        emb_hs,
        n_freqs,
        hidden_channels,
        n_head=4,
        approx_qk_dim=512,
        activation="prelu",
        eps=1e-5,
        pool="mean",
        use_attention=False,
    ):
        """
        Args:
            emb_dim: embedding channel count C.
            emb_ks: kernel (chunk) size for unfolding the frequency axis.
            emb_hs: hop size, used as the ConvTranspose1d stride when folding
                chunks back.
            n_freqs: number of frequency bins Q.
            hidden_channels: LSTM hidden size.
            n_head: attention heads (forwarded to Attention_STFT_causal).
            approx_qk_dim: target query/key width for the attention stage.
            activation: activation name for the attention projections.
            eps: layer-norm epsilon.
            pool: pooling-mode identifier; stored but not read in this block.
            use_attention: add the causal attention stage after the time LSTM.
        """
        super().__init__()
        # The inter (time) LSTM is kept unidirectional for causality.
        bidirectional = False

        self.global_atten_causal = True

        self.pool = pool

        # Per-frequency query/key width so total Q/K dim ~= approx_qk_dim.
        self.E = math.ceil(approx_qk_dim * 1.0 / n_freqs)

        self.V_dim = emb_dim // n_head
        self.H = hidden_channels
        in_channels = emb_dim * emb_ks
        self.in_channels = in_channels
        self.n_freqs = n_freqs

        # Intra-frame path: BLSTM across the frequency axis.
        self.intra_norm = LayerNormalization4D_old(emb_dim, eps=eps)
        self.intra_rnn = nn.LSTM(in_channels, hidden_channels, 1, batch_first=True, bidirectional=True)
        self.intra_linear = nn.ConvTranspose1d(hidden_channels * 2, emb_dim, emb_ks, stride=emb_hs)
        self.emb_dim = emb_dim
        self.emb_ks = emb_ks
        self.emb_hs = emb_hs

        # Inter-frame path: causal (unidirectional) LSTM across time.
        self.inter_norm = LayerNormalization4D_old(emb_dim, eps=eps)
        self.inter_rnn = nn.LSTM(in_channels, hidden_channels, 1, batch_first=True, bidirectional=bidirectional)
        self.inter_linear = nn.ConvTranspose1d(hidden_channels * (bidirectional + 1), emb_dim, emb_ks, stride=emb_hs)

        self.use_attention = use_attention

        if self.use_attention:
            self.pool_atten_causal = Attention_STFT_causal(
                emb_dim=emb_dim,
                n_freqs=n_freqs,
                approx_qk_dim=approx_qk_dim,
                n_head=n_head,
                activation=activation,
                eps=eps,
            )

    def init_buffers(self, batch_size, device):
        # Streaming-state hook for API compatibility; this block keeps no
        # recurrent buffers here, so there is nothing to initialize.
        return None

    def forward(self, x, init_state=None):
        """GridNetBlock Forward.

        Args:
            x: [B, C, T, Q]
            init_state: passed through unchanged (streaming-API placeholder).
        Returns:
            (out, init_state) with out: [B, C, T, Q]
        """
        B, C, old_T, old_Q = x.shape

        # Pad T and Q up so the emb_ks/emb_hs chunking covers them exactly;
        # the padding is trimmed again before returning.
        T = math.ceil((old_T - self.emb_ks) / self.emb_hs) * self.emb_hs + self.emb_ks
        Q = math.ceil((old_Q - self.emb_ks) / self.emb_hs) * self.emb_hs + self.emb_ks
        x = F.pad(x, (0, Q - old_Q, 0, T - old_T))

        # ---- intra-frame (frequency) BLSTM with residual ----
        input_ = x
        intra_rnn = self.intra_norm(input_)
        intra_rnn = intra_rnn.transpose(1, 2).contiguous().view(B * T, C, Q)

        # Unfold the frequency axis into chunks of emb_ks.
        # NOTE(review): torch.split produces NON-overlapping chunks, i.e.
        # this effectively assumes emb_hs == emb_ks — confirm against the
        # configurations actually used.
        intra_rnn = torch.split(intra_rnn, self.emb_ks, dim=-1)
        intra_rnn = torch.stack(intra_rnn, dim=0)  # [n_chunks, B*T, C, emb_ks]
        intra_rnn = intra_rnn.permute(1, 2, 3, 0).flatten(1, 2)  # [B*T, C*emb_ks, n_chunks]
        intra_rnn = intra_rnn.transpose(1, 2)  # [B*T, n_chunks, C*emb_ks]
        self.intra_rnn.flatten_parameters()

        intra_rnn, _ = self.intra_rnn(intra_rnn)
        intra_rnn = intra_rnn.transpose(1, 2)
        # Fold the chunks back to Q frequency bins.
        intra_rnn = self.intra_linear(intra_rnn)
        intra_rnn = intra_rnn.view([B, T, C, Q])
        intra_rnn = intra_rnn.transpose(1, 2).contiguous()
        intra_rnn = intra_rnn + input_
        # Drop the frequency padding before the time path.
        intra_rnn = intra_rnn[:, :, :, :old_Q]
        Q = old_Q

        # ---- inter-frame (time) causal LSTM with residual ----
        input_ = intra_rnn

        inter_rnn = self.inter_norm(intra_rnn)
        inter_rnn = inter_rnn.transpose(1, 3).reshape(B * Q, T, C)
        # NOTE(review): the inter LSTM was constructed with input size
        # emb_dim * emb_ks, but the features fed here have size C == emb_dim;
        # these only match when emb_ks == 1 — confirm intended configuration.

        self.inter_rnn.flatten_parameters()

        inter_rnn, _ = self.inter_rnn(inter_rnn)
        inter_rnn = inter_rnn.transpose(1, 2)
        inter_rnn = self.inter_linear(inter_rnn)

        _, new_C, new_T = inter_rnn.shape
        inter_rnn = inter_rnn.reshape(B, Q, new_C, new_T)
        inter_rnn = inter_rnn.permute(0, 2, 3, 1)  # back to [B, C, T, Q]
        inter_rnn = inter_rnn + input_

        # ---- optional causal attention over time ----
        if self.use_attention:
            out = inter_rnn

            inter_rnn = inter_rnn.permute(0, 2, 3, 1)  # [B, T, Q, C]
            inter_rnn = self.pool_atten_causal(inter_rnn)
            inter_rnn = inter_rnn.permute(0, 3, 1, 2)  # [B, C, T, Q]
            # NOTE: the attention module adds its own residual by default
            # (skip_conn=True), so this is a second residual on top of it.
            inter_rnn = out + inter_rnn

        # Drop the time padding added at the top.
        inter_rnn = inter_rnn[..., :old_T, :]

        return inter_rnn, init_state
| |
|