import math
from collections import OrderedDict
from typing import Optional

from torch import Tensor
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchmetrics.functional import (
    scale_invariant_signal_noise_ratio as si_snr,
    signal_noise_ratio as snr,
    signal_distortion_ratio as sdr,
    scale_invariant_signal_distortion_ratio as si_sdr)

from speechbrain.lobes.models.transformer.Transformer import PositionalEncoding


def mod_pad(x, chunk_size, pad):
    # Pad the input so its length is an integer multiple of `chunk_size`,
    # allowing an integer number of inferences, then apply the extra padding.
    mod = 0
    if (x.shape[-1] % chunk_size) != 0:
        mod = chunk_size - (x.shape[-1] % chunk_size)
    x = F.pad(x, (0, mod))
    x = F.pad(x, pad)

    return x, mod


class LayerNormPermuted(nn.LayerNorm):
    def __init__(self, *args, **kwargs):
        super(LayerNormPermuted, self).__init__(*args, **kwargs)

    def forward(self, x):
        """
        Args:
            x: [B, C, T]
        """
        x = x.permute(0, 2, 1)  # [B, T, C]
        x = super().forward(x)
        x = x.permute(0, 2, 1)  # [B, C, T]
        return x


class DepthwiseSeparableConv(nn.Module):
    """
    Depthwise separable convolutions
    """
    def __init__(self, in_channels, out_channels, kernel_size,
                 stride, padding, dilation):
        super(DepthwiseSeparableConv, self).__init__()

        self.layers = nn.Sequential(
            nn.Conv1d(in_channels, in_channels, kernel_size, stride,
                      padding, groups=in_channels, dilation=dilation),
            LayerNormPermuted(in_channels),
            nn.ReLU(),
            nn.Conv1d(in_channels, out_channels, kernel_size=1, stride=1,
                      padding=0),
            LayerNormPermuted(out_channels),
            nn.ReLU(),
        )

    def forward(self, x):
        return self.layers(x)


class DilatedCausalConvEncoder(nn.Module):
    """
    A dilated causal convolution based encoder for encoding
    time domain audio input into latent space.
    """
    def __init__(self, channels, num_layers, kernel_size=3):
        super(DilatedCausalConvEncoder, self).__init__()
        self.channels = channels
        self.num_layers = num_layers
        self.kernel_size = kernel_size

        # Compute buffer lengths for each layer
        # buf_length[i] = (kernel_size - 1) * dilation[i]
        self.buf_lengths = [(kernel_size - 1) * 2**i
                            for i in range(num_layers)]

        # Compute buffer start indices for each layer
        self.buf_indices = [0]
        for i in range(num_layers - 1):
            self.buf_indices.append(
                self.buf_indices[-1] + self.buf_lengths[i])

        # Dilated causal conv layers aggregate previous context to obtain
        # a context-rich encoded input.
        _dcc_layers = OrderedDict()
        for i in range(num_layers):
            dcc_layer = DepthwiseSeparableConv(
                channels, channels, kernel_size=kernel_size, stride=1,
                padding=0, dilation=2**i)
            _dcc_layers.update({'dcc_%d' % i: dcc_layer})
        self.dcc_layers = nn.Sequential(_dcc_layers)

    def init_ctx_buf(self, batch_size, device):
        """
        Returns an initialized context buffer for a given batch size.
        """
        return torch.zeros(
            (batch_size, self.channels,
             (self.kernel_size - 1) * (2**self.num_layers - 1)),
            device=device)

    def forward(self, x, ctx_buf):
        """
        Encodes input audio `x` into latent space, and aggregates
        contextual information in `ctx_buf`. Also generates a new context
        buffer with updated context.
        Args:
            x: [B, in_channels, T]
                Input multi-channel audio.
            ctx_buf: {[B, channels, self.buf_lengths[0]], ...}
                A list of tensors holding context for each dilated
                causal conv layer. (len(ctx_buf) == self.num_layers)
        Returns:
            x: [B, channels, T]
                Encoded output.
            ctx_buf: {[B, channels, self.buf_lengths[0]], ...}
                Updated context buffer with output as the last element.
""" T = x.shape[-1] # Sequence length for i in range(self.num_layers): buf_start_idx = self.buf_indices[i] buf_end_idx = self.buf_indices[i] + self.buf_lengths[i] # DCC input: concatenation of current output and context dcc_in = torch.cat( (ctx_buf[..., buf_start_idx:buf_end_idx], x), dim=-1) # Push current output to the context buffer ctx_buf[..., buf_start_idx:buf_end_idx] = \ dcc_in[..., -self.buf_lengths[i]:] # Residual connection x = x + self.dcc_layers[i](dcc_in) return x, ctx_buf class CausalTransformerDecoderLayer(torch.nn.TransformerDecoderLayer): """ Adapted from: "https://github.com/alexmt-scale/causal-transformer-decoder/blob/" "0caf6ad71c46488f76d89845b0123d2550ef792f/" "causal_transformer_decoder/model.py#L77" """ def forward( self, tgt: Tensor, memory: Optional[Tensor] = None, chunk_size: int = 1 ) -> Tensor: tgt_last_tok = tgt[:, -chunk_size:, :] # self attention part tmp_tgt, sa_map = self.self_attn( tgt_last_tok, tgt, tgt, attn_mask=None, # not needed because we only care about the last token key_padding_mask=None, ) tgt_last_tok = tgt_last_tok + self.dropout1(tmp_tgt) tgt_last_tok = self.norm1(tgt_last_tok) # encoder-decoder attention if memory is not None: tmp_tgt, ca_map = self.multihead_attn( tgt_last_tok, memory, memory, attn_mask=None, # Attend to the entire chunk key_padding_mask=None, ) tgt_last_tok = tgt_last_tok + self.dropout2(tmp_tgt) tgt_last_tok = self.norm2(tgt_last_tok) # final feed-forward network tmp_tgt = self.linear2( self.dropout(self.activation(self.linear1(tgt_last_tok))) ) tgt_last_tok = tgt_last_tok + self.dropout3(tmp_tgt) tgt_last_tok = self.norm3(tgt_last_tok) return tgt_last_tok, sa_map, ca_map class CausalTransformerDecoder(nn.Module): """ A casual transformer decoder which decodes input vectors using precisely `ctx_len` past vectors in the sequence, and using no future vectors at all. """ def __init__(self, model_dim, ctx_len, chunk_size, num_layers, nhead, use_pos_enc, ff_dim): super(CausalTransformerDecoder, self).__init__() self.num_layers = num_layers self.model_dim = model_dim self.ctx_len = ctx_len self.chunk_size = chunk_size self.nhead = nhead self.use_pos_enc = use_pos_enc self.unfold = nn.Unfold(kernel_size=(ctx_len + chunk_size, 1), stride=chunk_size) self.pos_enc = PositionalEncoding(model_dim, max_len=200) self.tf_dec_layers = nn.ModuleList([CausalTransformerDecoderLayer( d_model=model_dim, nhead=nhead, dim_feedforward=ff_dim, batch_first=True) for _ in range(num_layers)]) def init_ctx_buf(self, batch_size, device): return torch.zeros( (batch_size, self.num_layers + 1, self.ctx_len, self.model_dim), device=device) def _causal_unfold(self, x): """ Unfolds the sequence into a batch of sequences prepended with `ctx_len` previous values. 
        Args:
            x: [B, ctx_len + L, C]
        Returns:
            [B * (L // chunk_size), ctx_len + chunk_size, C]
        """
        B, T, C = x.shape
        x = x.permute(0, 2, 1)  # [B, C, ctx_len + L]
        x = self.unfold(x.unsqueeze(-1))  # [B, C * (ctx_len + chunk_size), -1]
        x = x.permute(0, 2, 1)
        x = x.reshape(B, -1, C, self.ctx_len + self.chunk_size)
        x = x.reshape(-1, C, self.ctx_len + self.chunk_size)
        x = x.permute(0, 2, 1)
        return x

    def forward(self, tgt, mem, ctx_buf, probe=False):
        """
        Args:
            tgt: [B, model_dim, T]
            mem: [B, model_dim, T]
            ctx_buf: [B, num_layers + 1, ctx_len, model_dim]
        """
        mem, _ = mod_pad(mem, self.chunk_size, (0, 0))
        tgt, mod = mod_pad(tgt, self.chunk_size, (0, 0))

        # Input sequence length
        B, C, T = tgt.shape

        tgt = tgt.permute(0, 2, 1)
        mem = mem.permute(0, 2, 1)

        # Prepend mem with the context
        mem = torch.cat((ctx_buf[:, 0, :, :], mem), dim=1)
        ctx_buf[:, 0, :, :] = mem[:, -self.ctx_len:, :]
        mem_ctx = self._causal_unfold(mem)
        if self.use_pos_enc:
            mem_ctx = mem_ctx + self.pos_enc(mem_ctx)

        # Attention chunk size: required to ensure the model
        # doesn't trigger an out-of-memory error when working
        # on long sequences.
        K = 1000

        for i, tf_dec_layer in enumerate(self.tf_dec_layers):
            # Update the tgt with context
            tgt = torch.cat((ctx_buf[:, i + 1, :, :], tgt), dim=1)
            ctx_buf[:, i + 1, :, :] = tgt[:, -self.ctx_len:, :]

            # Compute encoded output
            tgt_ctx = self._causal_unfold(tgt)
            if self.use_pos_enc and i == 0:
                tgt_ctx = tgt_ctx + self.pos_enc(tgt_ctx)
            tgt = torch.zeros_like(tgt_ctx)[:, -self.chunk_size:, :]
            for j in range(int(math.ceil(tgt.shape[0] / K))):
                tgt[j*K:(j+1)*K], _sa_map, _ca_map = tf_dec_layer(
                    tgt_ctx[j*K:(j+1)*K], mem_ctx[j*K:(j+1)*K],
                    self.chunk_size)

        tgt = tgt.reshape(B, T, C)
        tgt = tgt.permute(0, 2, 1)

        if mod != 0:
            tgt = tgt[..., :-mod]

        return tgt, ctx_buf


class MaskNet(nn.Module):
    def __init__(self, enc_dim, num_enc_layers, dec_dim, dec_buf_len,
                 dec_chunk_size, num_dec_layers, use_pos_enc,
                 skip_connection, proj):
        super(MaskNet, self).__init__()
        self.skip_connection = skip_connection
        self.proj = proj

        # Encoder based on dilated causal convolutions.
        self.encoder = DilatedCausalConvEncoder(channels=enc_dim,
                                                num_layers=num_enc_layers)

        # Project between encoder and decoder dimensions
        self.proj_e2d_e = nn.Sequential(
            nn.Conv1d(enc_dim, dec_dim, kernel_size=1, stride=1, padding=0,
                      groups=dec_dim),
            nn.ReLU())
        self.proj_e2d_l = nn.Sequential(
            nn.Conv1d(enc_dim, dec_dim, kernel_size=1, stride=1, padding=0,
                      groups=dec_dim),
            nn.ReLU())
        self.proj_d2e = nn.Sequential(
            nn.Conv1d(dec_dim, enc_dim, kernel_size=1, stride=1, padding=0,
                      groups=dec_dim),
            nn.ReLU())

        # Transformer decoder that operates on chunks of size
        # `dec_chunk_size`.
        self.decoder = CausalTransformerDecoder(
            model_dim=dec_dim, ctx_len=dec_buf_len, chunk_size=dec_chunk_size,
            num_layers=num_dec_layers, nhead=8, use_pos_enc=use_pos_enc,
            ff_dim=2 * dec_dim)

    def forward(self, x, l, enc_buf, dec_buf):
        """
        Generates a mask based on the encoded input `x` and the label
        embedding `l`.
        Args:
            x: [B, C, T]
                Input audio sequence
            l: [B, C]
                Label embedding
            enc_buf, dec_buf:
                Context buffers maintained by the DCC encoder and the
                transformer decoder.
        """
        # Encode the input audio sequence
        e, enc_buf = self.encoder(x, enc_buf)

        # Label integration
        l = l.unsqueeze(2) * e

        # Project to `dec_dim` dimensions
        if self.proj:
            e = self.proj_e2d_e(e)
            m = self.proj_e2d_l(l)
            # Cross-attention to predict the mask
            m, dec_buf = self.decoder(m, e, dec_buf)
        else:
            # Cross-attention to predict the mask
            m, dec_buf = self.decoder(l, e, dec_buf)

        # Project mask to encoder dimensions
        if self.proj:
            m = self.proj_d2e(m)

        # Final mask after residual connection
        if self.skip_connection:
            m = l + m

        return m, enc_buf, dec_buf


class Net(nn.Module):
    def __init__(self, label_len, L=8,
                 enc_dim=512, num_enc_layers=10,
                 dec_dim=256, dec_buf_len=100, num_dec_layers=2,
                 dec_chunk_size=72, out_buf_len=2,
                 use_pos_enc=True, skip_connection=True, proj=True,
                 lookahead=True):
        super(Net, self).__init__()
        self.L = L
        self.out_buf_len = out_buf_len
        self.enc_dim = enc_dim
        self.lookahead = lookahead

        # Input conv to convert input audio to a latent representation
        kernel_size = 3 * L if lookahead else L
        self.in_conv = nn.Sequential(
            nn.Conv1d(in_channels=1, out_channels=enc_dim,
                      kernel_size=kernel_size, stride=L,
                      padding=0, bias=False),
            nn.ReLU())

        # Label embedding layer
        self.label_embedding = nn.Sequential(
            nn.Linear(label_len, 512),
            nn.LayerNorm(512),
            nn.ReLU(),
            nn.Linear(512, enc_dim),
            nn.LayerNorm(enc_dim),
            nn.ReLU())

        # Mask generator
        self.mask_gen = MaskNet(
            enc_dim=enc_dim, num_enc_layers=num_enc_layers,
            dec_dim=dec_dim, dec_buf_len=dec_buf_len,
            dec_chunk_size=dec_chunk_size, num_dec_layers=num_dec_layers,
            use_pos_enc=use_pos_enc, skip_connection=skip_connection,
            proj=proj)

        # Output conv layer
        self.out_conv = nn.Sequential(
            nn.ConvTranspose1d(
                in_channels=enc_dim, out_channels=1,
                kernel_size=(out_buf_len + 1) * L,
                stride=L,
                padding=out_buf_len * L,
                bias=False),
            nn.Tanh())

    def init_buffers(self, batch_size, device):
        enc_buf = self.mask_gen.encoder.init_ctx_buf(batch_size, device)
        dec_buf = self.mask_gen.decoder.init_ctx_buf(batch_size, device)
        out_buf = torch.zeros(batch_size, self.enc_dim, self.out_buf_len,
                              device=device)
        return enc_buf, dec_buf, out_buf

    def forward(self, x, label, init_enc_buf=None, init_dec_buf=None,
                init_out_buf=None, pad=True):
        """
        Extracts the audio corresponding to the `label` in the given
        mixture `x`. Generates `chunk_size` samples per iteration.
        Args:
            x: [B, n_mics, T]
                Input audio mixture
            label: [B, num_labels]
                One-hot label
        Returns:
            out: [B, n_spk, T]
                Extracted audio with sounds corresponding to the `label`
        """
        mod = 0
        if pad:
            pad_size = (self.L, self.L) if self.lookahead else (0, 0)
            x, mod = mod_pad(x, chunk_size=self.L, pad=pad_size)

        if init_enc_buf is None or init_dec_buf is None or init_out_buf is None:
            assert init_enc_buf is None and \
                   init_dec_buf is None and \
                   init_out_buf is None, \
                "All buffers have to be initialized, or " \
                "all of them have to be None."
            enc_buf, dec_buf, out_buf = self.init_buffers(
                x.shape[0], x.device)
        else:
            enc_buf, dec_buf, out_buf = \
                init_enc_buf, init_dec_buf, init_out_buf

        # Generate latent space representation of the input
        x = self.in_conv(x)

        # Generate label embedding
        l = self.label_embedding(label)  # [B, label_len] --> [B, channels]

        # Generate mask corresponding to the label
        m, enc_buf, dec_buf = self.mask_gen(x, l, enc_buf, dec_buf)

        # Apply mask and decode
        x = x * m
        x = torch.cat((out_buf, x), dim=-1)
        out_buf = x[..., -self.out_buf_len:]
        x = self.out_conv(x)

        # Remove mod padding, if present.
        if mod != 0:
            x = x[:, :, :-mod]

        if init_enc_buf is None:
            return x
        else:
            return x, enc_buf, dec_buf, out_buf


# Define optimizer, loss and metrics

def optimizer(model, data_parallel=False, **kwargs):
    return optim.Adam(model.parameters(), **kwargs)


def loss(pred, tgt):
    return -0.9 * snr(pred, tgt).mean() - 0.1 * si_snr(pred, tgt).mean()


def metrics(mixed, output, gt):
    """ Function to compute metrics """
    metrics = {}

    def metric_i(metric, src, pred, tgt):
        _vals = []
        for s, t, p in zip(src, tgt, pred):
            _vals.append((metric(p, t) - metric(s, t)).cpu().item())
        return _vals

    for m_fn in [snr, si_snr]:
        metrics[m_fn.__name__] = metric_i(m_fn,
                                          mixed[:, :gt.shape[1], :],
                                          output,
                                          gt)
    return metrics
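

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original training pipeline):
# shows an offline forward pass and an equivalent chunked/streaming pass that
# carries the encoder, decoder, and output buffers across calls. The label
# vocabulary size, label index, signal length, and `lookahead=False` setting
# below are assumptions made for this example only.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)

    num_labels = 41  # assumed one-hot label vocabulary size
    model = Net(label_len=num_labels, lookahead=False)
    model.eval()

    # One chunk covers L * dec_chunk_size samples; use 4 chunks of audio.
    chunk_len = model.L * model.mask_gen.decoder.chunk_size
    x = torch.randn(1, 1, 4 * chunk_len)                      # [B, 1, T] mixture
    label = F.one_hot(torch.tensor([3]), num_labels).float()  # [B, num_labels]

    # Offline inference: context buffers are created internally.
    with torch.no_grad():
        y_offline = model(x, label)

    # Streaming inference: process fixed-size chunks, reusing the buffers.
    enc_buf, dec_buf, out_buf = model.init_buffers(x.shape[0], x.device)
    outputs = []
    with torch.no_grad():
        for start in range(0, x.shape[-1], chunk_len):
            chunk = x[..., start:start + chunk_len]
            y, enc_buf, dec_buf, out_buf = model(
                chunk, label, enc_buf, dec_buf, out_buf, pad=False)
            outputs.append(y)
    y_streaming = torch.cat(outputs, dim=-1)

    print("offline:", y_offline.shape, "streaming:", y_streaming.shape)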