import math
from typing import Dict, List, Tuple

import torch
from torch import Tensor, nn
from torch.nn import functional as F

from .unet import UNet


def batchify(tensor: Tensor, T: int) -> Tensor:
    """
    Partition a tensor into segments of length T along the last axis,
    zero-padding any ragged tail segment.

    Args:
        tensor (Tensor): B x C x F x L

    Returns:
        Tensor of size (B * ceil(L / T)) x C x F x T
    """
    # Zero-pad the original tensor to an even multiple of T
    orig_size = tensor.size(-1)
    new_size = math.ceil(orig_size / T) * T
    tensor = F.pad(tensor, [0, new_size - orig_size])

    # Partition the tensor into multiple samples of length T and stack them into a batch
    return torch.cat(torch.split(tensor, T, dim=-1), dim=0)


class Splitter(nn.Module):
    def __init__(self, stem_names: List[str] = None):
        super().__init__()

        # STFT config
        self.F = 1024          # number of frequency bins kept
        self.T = 512           # segment length (frames) fed to each UNet
        self.win_length = 4096
        self.hop_length = 1024
        self.win = nn.Parameter(
            torch.hann_window(self.win_length), requires_grad=False
        )

        # One UNet per stem; default to the 2-stem vocals/accompaniment split
        stem_names = stem_names or ["vocals", "accompaniment"]
        self.stems = nn.ModuleDict(
            {name: UNet(in_channels=2) for name in stem_names}
        )

    def compute_stft(self, wav: Tensor) -> Tuple[Tensor, Tensor]:
        """
        Computes the STFT feature from a waveform.

        Args:
            wav (Tensor): B x L

        Returns:
            Tuple of the STFT (B x F x L x 2, real/imag in the last dim)
            and its magnitude (B x F x L).
        """
        stft = torch.stft(
            wav,
            n_fft=self.win_length,
            hop_length=self.hop_length,
            window=self.win,
            center=True,
            return_complex=False,
            pad_mode="constant",
        )

        # Only keep frequency bins below self.F
        stft = stft[:, : self.F, :, :]
        real = stft[:, :, :, 0]
        im = stft[:, :, :, 1]
        mag = torch.sqrt(real ** 2 + im ** 2)

        return stft, mag

    def inverse_stft(self, stft: Tensor) -> Tensor:
        """Inverts an STFT back to a waveform."""
        # Zero-pad the truncated frequency axis back to win_length // 2 + 1 bins
        pad = self.win_length // 2 + 1 - stft.size(1)
        stft = F.pad(stft, (0, 0, 0, 0, 0, pad))
        wav = torch.istft(
            stft,
            self.win_length,
            hop_length=self.hop_length,
            center=True,
            window=self.win,
        )
        return wav.detach()

    def forward(self, wav: Tensor) -> Dict[str, Tensor]:
        """
        Separates a stereo waveform into one masked STFT per stem.

        Args:
            wav (Tensor): 2 x L

        Returns:
            Masked STFTs keyed by stem name.
        """
        # stft     - 2 x F x L x 2
        # stft_mag - 2 x F x L
        stft, stft_mag = self.compute_stft(wav.squeeze())

        L = stft.size(2)

        # 1 x 2 x F x T
        stft_mag = stft_mag.unsqueeze(-1).permute([3, 0, 1, 2])
        stft_mag = batchify(stft_mag, self.T)  # B x 2 x F x T
        stft_mag = stft_mag.transpose(2, 3)  # B x 2 x T x F

        # Compute the per-stem soft masks
        masks = {name: net(stft_mag) for name, net in self.stems.items()}

        # Denominator of the Wiener-style ratio mask; the epsilon keeps the
        # division stable where all masks are near zero
        mask_sum = sum([m ** 2 for m in masks.values()])
        mask_sum += 1e-10

        def apply_mask(mask):
            mask = (mask ** 2 + 1e-10 / 2) / mask_sum
            mask = mask.transpose(2, 3)  # B x 2 x F x T

            # Un-batchify: stitch the T-length segments back along the time axis
            mask = torch.cat(torch.split(mask, 1, dim=0), dim=3)
            mask = mask.squeeze(0)[:, :, :L].unsqueeze(-1)  # 2 x F x L x 1
            stft_masked = stft * mask
            return stft_masked

        return {name: apply_mask(m) for name, m in masks.items()}

    def separate(self, wav: Tensor) -> Dict[str, Tensor]:
        """
        Separates a stereo waveform into different tracks (one predicted
        track per stem).

        Args:
            wav (Tensor): 2 x L

        Returns:
            Waveforms keyed by stem name.
        """
        stft_masks = self.forward(wav)
        return {
            name: self.inverse_stft(stft_masked)
            for name, stft_masked in stft_masks.items()
        }

    @classmethod
    def from_pretrained(cls, model_path: str):
        checkpoint = torch.load(model_path, map_location="cpu")
        model = cls()
        model.load_state_dict(checkpoint)
        return model
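

# A minimal usage sketch, not part of the library API. Assumptions:
# "checkpoints/splitter.pth" and "mix.wav" are hypothetical paths, and
# torchaudio is used here only for I/O; any stereo 2 x L float tensor works
# as input. Because this module uses relative imports, run it as a module
# (e.g. `python -m <package>.splitter`) rather than as a script.
if __name__ == "__main__":
    import torchaudio

    wav, sample_rate = torchaudio.load("mix.wav")  # 2 x L stereo waveform
    model = Splitter.from_pretrained("checkpoints/splitter.pth")

    with torch.no_grad():
        # {"vocals": 2 x L, "accompaniment": 2 x L} with the default stems
        stems = model.separate(wav)

    for name, track in stems.items():
        torchaudio.save(f"{name}.wav", track, sample_rate)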