# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import random

import torch as th
from torch import nn


class Shift(nn.Module):
    """
    Randomly shift audio in time by up to `shift` samples.
    """
    def __init__(self, shift=8192):
        super().__init__()
        self.shift = shift

    def forward(self, wav):
        batch, sources, channels, time = wav.size()
        length = time - self.shift
        if self.shift > 0:
            if not self.training:
                # At eval time, deterministically trim to the common valid length.
                wav = wav[..., :length]
            else:
                # Draw one random offset per (batch, source); `indexes` broadcasts
                # against it to select a random window of `length` samples.
                offsets = th.randint(self.shift, [batch, sources, 1, 1], device=wav.device)
                offsets = offsets.expand(-1, -1, channels, -1)
                indexes = th.arange(length, device=wav.device)
                wav = wav.gather(3, indexes + offsets)
        return wav


class FlipChannels(nn.Module):
    """
    Flip left-right channels.
    """
    def forward(self, wav):
        batch, sources, channels, time = wav.size()
        if self.training and wav.size(2) == 2:
            # Randomly swap the two channels, independently per (batch, source).
            left = th.randint(2, (batch, sources, 1, 1), device=wav.device)
            left = left.expand(-1, -1, -1, time)
            right = 1 - left
            wav = th.cat([wav.gather(2, left), wav.gather(2, right)], dim=2)
        return wav


class FlipSign(nn.Module):
    """
    Random sign flip.
    """
    def forward(self, wav):
        batch, sources, channels, time = wav.size()
        if self.training:
            signs = th.randint(2, (batch, sources, 1, 1), device=wav.device, dtype=th.float32)
            wav = wav * (2 * signs - 1)
        return wav


class Remix(nn.Module):
    """
    Shuffle sources to make new mixes.
    """
    def __init__(self, group_size=4):
        """
        Shuffle sources within one batch.
        Each batch is divided into groups of size `group_size`, and shuffling is done
        within each group separately. This keeps the probability distribution the same
        regardless of the number of GPUs. Without this grouping, using more GPUs would
        increase the probability of keeping two sources from the same track together,
        which can hurt performance.
        """
        super().__init__()
        self.group_size = group_size

    def forward(self, wav):
        batch, streams, channels, time = wav.size()
        device = wav.device

        if self.training:
            group_size = self.group_size or batch
            if batch % group_size != 0:
                raise ValueError(
                    f"Batch size {batch} must be divisible by group size {group_size}")
            groups = batch // group_size
            # Permute each stream independently within its group.
            wav = wav.view(groups, group_size, streams, channels, time)
            permutations = th.argsort(
                th.rand(groups, group_size, streams, 1, 1, device=device), dim=1)
            wav = wav.gather(1, permutations.expand(-1, -1, -1, channels, time))
            wav = wav.view(batch, streams, channels, time)
        return wav


class Scale(nn.Module):
    """
    Randomly rescale the volume of each source, with probability `proba`.
    """
    def __init__(self, proba=1., min=0.25, max=1.25):
        super().__init__()
        self.proba = proba
        self.min = min
        self.max = max

    def forward(self, wav):
        batch, streams, channels, time = wav.size()
        device = wav.device
        if self.training and random.random() < self.proba:
            scales = th.empty(batch, streams, 1, 1, device=device).uniform_(self.min, self.max)
            wav *= scales
        return wav
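

# Minimal usage sketch (not part of the original module): all of the augmentations
# above expect a tensor of shape (batch, sources, channels, time) and can be chained
# with `nn.Sequential` during training. The shapes, values, and module ordering below
# are illustrative assumptions, not taken from any training script in this repository.
if __name__ == "__main__":
    augment = nn.Sequential(
        Shift(shift=8192),
        FlipChannels(),
        FlipSign(),
        Remix(group_size=4),
        Scale(proba=1.0),
    )
    augment.train()
    # 8 tracks, 4 sources (e.g. drums/bass/other/vocals), stereo, ~1.5 s at 44.1 kHz.
    sources = th.randn(8, 4, 2, 65536)
    augmented = augment(sources)
    # Shift trims `shift` samples from the time axis; the other modules keep the shape.
    print(augmented.shape)  # torch.Size([8, 4, 2, 57344])
    # A new mixture for the model can then be formed by summing the augmented sources.
    mix = augmented.sum(dim=1)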