|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import random |
|
import torch as th |
|
from torch import nn |
|
from torch.nn import functional as F |
|
|
|
from . import dsp |
|
|
|
|
|
class Remix(nn.Module): |
|
"""Remix. |
|
Mixes different noises with clean speech within a given batch |
|
""" |
|
|
|
def forward(self, sources): |
|
noise, clean = sources |
|
bs, *other = noise.shape |
|
device = noise.device |
|
perm = th.argsort(th.rand(bs, device=device), dim=0) |
|
return th.stack([noise[perm], clean]) |
|
|
|
|
|
class RevEcho(nn.Module): |
|
""" |
|
Hacky Reverb but runs on GPU without slowing down training. |
|
This reverb adds a succession of attenuated echos of the input |
|
signal to itself. Intuitively, the delay of the first echo will happen |
|
after roughly 2x the radius of the room and is controlled by `first_delay`. |
|
Then RevEcho keeps adding echos with the same delay and further attenuation |
|
until the amplitude ratio between the last and first echo is 1e-3. |
|
The attenuation factor and the number of echos to adds is controlled |
|
by RT60 (measured in seconds). RT60 is the average time to get to -60dB |
|
(remember volume is measured over the squared amplitude so this matches |
|
the 1e-3 ratio). |
|
|
|
At each call to RevEcho, `first_delay`, `initial` and `RT60` are |
|
sampled from their range. Then, to prevent this reverb from being too regular, |
|
the delay time is resampled uniformly within `first_delay +- 10%`, |
|
as controlled by the `jitter` parameter. Finally, for a denser reverb, |
|
multiple trains of echos are added with different jitter noises. |
|
|
|
Args: |
|
- initial: amplitude of the first echo as a fraction |
|
of the input signal. For each sample, actually sampled from |
|
`[0, initial]`. Larger values means louder reverb. Physically, |
|
this would depend on the absorption of the room walls. |
|
- rt60: range of values to sample the RT60 in seconds, i.e. |
|
after RT60 seconds, the echo amplitude is 1e-3 of the first echo. |
|
The default values follow the recommendations of |
|
https://arxiv.org/ftp/arxiv/papers/2001/2001.08662.pdf, Section 2.4. |
|
Physically this would also be related to the absorption of the |
|
room walls and there is likely a relation between `RT60` and |
|
`initial`, which we ignore here. |
|
- first_delay: range of values to sample the first echo delay in seconds. |
|
The default values are equivalent to sampling a room of 3 to 10 meters. |
|
- repeat: how many train of echos with differents jitters to add. |
|
Higher values means a denser reverb. |
|
- jitter: jitter used to make each repetition of the reverb echo train |
|
slightly different. For instance a jitter of 0.1 means |
|
the delay between two echos will be in the range `first_delay +- 10%`, |
|
with the jittering noise being resampled after each single echo. |
|
- keep_clean: fraction of the reverb of the clean speech to add back |
|
to the ground truth. 0 = dereverberation, 1 = no dereverberation. |
|
- sample_rate: sample rate of the input signals. |
|
""" |
|
|
|
def __init__(self, proba=0.5, initial=0.3, rt60=(0.3, 1.3), first_delay=(0.01, 0.03), |
|
repeat=3, jitter=0.1, keep_clean=0.1, sample_rate=16000): |
|
super().__init__() |
|
self.proba = proba |
|
self.initial = initial |
|
self.rt60 = rt60 |
|
self.first_delay = first_delay |
|
self.repeat = repeat |
|
self.jitter = jitter |
|
self.keep_clean = keep_clean |
|
self.sample_rate = sample_rate |
|
|
|
def _reverb(self, source, initial, first_delay, rt60): |
|
""" |
|
Return the reverb for a single source. |
|
""" |
|
length = source.shape[-1] |
|
reverb = th.zeros_like(source) |
|
for _ in range(self.repeat): |
|
frac = 1 |
|
echo = initial * source |
|
while frac > 1e-3: |
|
|
|
jitter = 1 + self.jitter * random.uniform(-1, 1) |
|
delay = min( |
|
1 + int(jitter * first_delay * self.sample_rate), |
|
length) |
|
|
|
echo = F.pad(echo[:, :, :-delay], (delay, 0)) |
|
reverb += echo |
|
|
|
|
|
jitter = 1 + self.jitter * random.uniform(-1, 1) |
|
|
|
|
|
attenuation = 10**(-3 * jitter * first_delay / rt60) |
|
echo *= attenuation |
|
frac *= attenuation |
|
return reverb |
|
|
|
def forward(self, wav): |
|
if random.random() >= self.proba: |
|
return wav |
|
noise, clean = wav |
|
|
|
initial = random.random() * self.initial |
|
first_delay = random.uniform(*self.first_delay) |
|
rt60 = random.uniform(*self.rt60) |
|
|
|
reverb_noise = self._reverb(noise, initial, first_delay, rt60) |
|
|
|
noise += reverb_noise |
|
reverb_clean = self._reverb(clean, initial, first_delay, rt60) |
|
|
|
clean += self.keep_clean * reverb_clean |
|
noise += (1 - self.keep_clean) * reverb_clean |
|
|
|
return th.stack([noise, clean]) |
|
|
|
|
|
class BandMask(nn.Module): |
|
"""BandMask. |
|
Maskes bands of frequencies. Similar to Park, Daniel S., et al. |
|
"Specaugment: A simple data augmentation method for automatic speech recognition." |
|
(https://arxiv.org/pdf/1904.08779.pdf) but over the waveform. |
|
""" |
|
|
|
def __init__(self, maxwidth=0.2, bands=120, sample_rate=16_000): |
|
"""__init__. |
|
|
|
:param maxwidth: the maximum width to remove |
|
:param bands: number of bands |
|
:param sample_rate: signal sample rate |
|
""" |
|
super().__init__() |
|
self.maxwidth = maxwidth |
|
self.bands = bands |
|
self.sample_rate = sample_rate |
|
|
|
def forward(self, wav): |
|
bands = self.bands |
|
bandwidth = int(abs(self.maxwidth) * bands) |
|
mels = dsp.mel_frequencies(bands, 40, self.sample_rate/2) / self.sample_rate |
|
low = random.randrange(bands) |
|
high = random.randrange(low, min(bands, low + bandwidth)) |
|
filters = dsp.LowPassFilters([mels[low], mels[high]]).to(wav.device) |
|
low, midlow = filters(wav) |
|
|
|
out = wav - midlow + low |
|
return out |
|
|
|
|
|
class Shift(nn.Module): |
|
"""Shift.""" |
|
|
|
def __init__(self, shift=8192, same=False): |
|
"""__init__. |
|
|
|
:param shift: randomly shifts the signals up to a given factor |
|
:param same: shifts both clean and noisy files by the same factor |
|
""" |
|
super().__init__() |
|
self.shift = shift |
|
self.same = same |
|
|
|
def forward(self, wav): |
|
sources, batch, channels, length = wav.shape |
|
length = length - self.shift |
|
if self.shift > 0: |
|
if not self.training: |
|
wav = wav[..., :length] |
|
else: |
|
offsets = th.randint( |
|
self.shift, |
|
[1 if self.same else sources, batch, 1, 1], device=wav.device) |
|
offsets = offsets.expand(sources, -1, channels, -1) |
|
indexes = th.arange(length, device=wav.device) |
|
wav = wav.gather(3, indexes + offsets) |
|
return wav |
|
|