# Retrieved from a web file viewer: uploaded by DeepLearning101 ("Upload 17 files"), commit 109bb65, 7.86 kB.
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# author: adefossez
import random
import torch as th
from torch import nn
from torch.nn import functional as F
from . import dsp
class Remix(nn.Module):
    """Remix.

    Shuffles the noise signals across the batch so that each clean speech
    item ends up paired with a noise drawn from another item of the same
    batch. The clean signals are returned untouched.
    """

    def forward(self, sources):
        """Return `sources` with its noise entry permuted along the batch axis.

        :param sources: pair `(noise, clean)`; the first dimension of
            `noise` is treated as the batch dimension.
        :return: `th.stack([shuffled_noise, clean])`.
        """
        noise, clean = sources
        batch_size, *_ = noise.shape
        # Sorting uniform random scores yields a random permutation of the
        # batch indices, computed directly on the device of the input.
        scores = th.rand(batch_size, device=noise.device)
        order = scores.argsort(dim=0)
        shuffled = noise[order]
        return th.stack((shuffled, clean))
class RevEcho(nn.Module):
    """
    Hacky Reverb but runs on GPU without slowing down training.
    This reverb adds a succession of attenuated echoes of the input
    signal to itself. Intuitively, the delay of the first echo will happen
    after roughly 2x the radius of the room and is controlled by `first_delay`.
    Then RevEcho keeps adding echoes with the same delay and further attenuation
    until the amplitude ratio between the last and first echo is 1e-3.
    The attenuation factor and the number of echoes to add are controlled
    by RT60 (measured in seconds). RT60 is the average time to get to -60dB
    (remember volume is measured over the squared amplitude so this matches
    the 1e-3 ratio).

    At each call to RevEcho, `first_delay`, `initial` and `RT60` are
    sampled from their range. Then, to prevent this reverb from being too regular,
    the delay time is resampled uniformly within `first_delay +- 10%`,
    as controlled by the `jitter` parameter. Finally, for a denser reverb,
    multiple trains of echoes are added with different jitter noises.

    Args:
        - proba: probability of applying the reverb on a given call. With
            probability `1 - proba` the input is returned unchanged.
        - initial: amplitude of the first echo as a fraction
            of the input signal. For each call, actually sampled from
            `[0, initial]`. Larger values mean a louder reverb. Physically,
            this would depend on the absorption of the room walls.
        - rt60: range of values to sample the RT60 in seconds, i.e.
            after RT60 seconds, the echo amplitude is 1e-3 of the first echo.
            The default values follow the recommendations of
            https://arxiv.org/ftp/arxiv/papers/2001/2001.08662.pdf, Section 2.4.
            Physically this would also be related to the absorption of the
            room walls and there is likely a relation between `RT60` and
            `initial`, which we ignore here.
        - first_delay: range of values to sample the first echo delay in seconds.
            The default values are equivalent to sampling a room of 3 to 10 meters.
        - repeat: how many trains of echoes with different jitters to add.
            Higher values mean a denser reverb.
        - jitter: jitter used to make each repetition of the reverb echo train
            slightly different. For instance a jitter of 0.1 means
            the delay between two echoes will be in the range `first_delay +- 10%`,
            with the jittering noise being resampled after each single echo.
        - keep_clean: fraction of the reverb of the clean speech to add back
            to the ground truth. 0 = dereverberation, 1 = no dereverberation.
        - sample_rate: sample rate of the input signals.
    """

    def __init__(self, proba=0.5, initial=0.3, rt60=(0.3, 1.3), first_delay=(0.01, 0.03),
                 repeat=3, jitter=0.1, keep_clean=0.1, sample_rate=16000):
        super().__init__()
        self.proba = proba
        self.initial = initial
        self.rt60 = rt60
        self.first_delay = first_delay
        self.repeat = repeat
        self.jitter = jitter
        self.keep_clean = keep_clean
        self.sample_rate = sample_rate

    def _reverb(self, source, initial, first_delay, rt60):
        """
        Return the reverb for a single source.

        `source` is indexed as a 3-dimensional tensor whose last dimension is
        time. The returned tensor has the same shape as `source` and contains
        only the echoes (the dry signal is not included).

        NOTE: the inner loop only terminates if the attenuation is below 1,
        which requires `first_delay > 0`, a finite positive `rt60` and
        `self.jitter < 1`.
        """
        length = source.shape[-1]
        reverb = th.zeros_like(source)
        for _ in range(self.repeat):
            frac = 1  # what fraction of the first echo amplitude is still here
            echo = initial * source
            while frac > 1e-3:
                # First jitter noise for the delay
                jitter = 1 + self.jitter * random.uniform(-1, 1)
                # At least one sample of delay (so that `:-delay` is a valid
                # slice) and at most `length` (the echo is then all zeros).
                delay = min(
                    1 + int(jitter * first_delay * self.sample_rate),
                    length)
                # Delay the echo in time by padding with zero on the left
                echo = F.pad(echo[:, :, :-delay], (delay, 0))
                reverb += echo

                # Second jitter noise for the attenuation
                jitter = 1 + self.jitter * random.uniform(-1, 1)
                # we want, with `d` the attenuation, d**(rt60 / first_delay) = 1e-3
                # i.e. log10(d) = -3 * first_delay / rt60, so that
                attenuation = 10**(-3 * jitter * first_delay / rt60)
                echo *= attenuation
                frac *= attenuation
        return reverb

    def forward(self, wav):
        """
        Apply the reverb to `wav` with probability `self.proba`.

        `wav` unpacks into `(noise, clean)`. When the reverb is applied, a new
        stacked tensor `[noise, clean]` is returned and the input is left
        unmodified; otherwise `wav` itself is returned.
        """
        if random.random() >= self.proba:
            return wav
        noise, clean = wav
        # Sample characteristics for the reverb
        initial = random.random() * self.initial
        first_delay = random.uniform(*self.first_delay)
        rt60 = random.uniform(*self.rt60)

        reverb_noise = self._reverb(noise, initial, first_delay, rt60)
        reverb_clean = self._reverb(clean, initial, first_delay, rt60)

        # Out-of-place arithmetic on purpose: `noise` and `clean` are views of
        # the caller's `wav`, and in-place updates would silently modify it.
        # Reverb for the noise is always added back to the noise, and the
        # clean reverb is split among the clean speech and the noise.
        noise = noise + reverb_noise + (1 - self.keep_clean) * reverb_clean
        clean = clean + self.keep_clean * reverb_clean
        return th.stack([noise, clean])
class BandMask(nn.Module):
    """BandMask.
    Masks bands of frequencies. Similar to Park, Daniel S., et al.
    "Specaugment: A simple data augmentation method for automatic speech recognition."
    (https://arxiv.org/pdf/1904.08779.pdf) but over the waveform.
    """

    def __init__(self, maxwidth=0.2, bands=120, sample_rate=16_000):
        """__init__.

        :param maxwidth: the maximum width to remove, as a fraction of `bands`
        :param bands: number of bands
        :param sample_rate: signal sample rate
        """
        super().__init__()
        self.maxwidth = maxwidth
        self.bands = bands
        self.sample_rate = sample_rate

    def forward(self, wav):
        """Remove one randomly chosen band of frequencies from `wav`.

        :param wav: input waveform tensor; the filters are moved to its device.
        :return: `wav` minus the content between the two sampled cutoffs.
        """
        bands = self.bands
        bandwidth = int(abs(self.maxwidth) * bands)
        # NOTE(review): presumably `bands` mel-spaced frequencies between 40 Hz
        # and Nyquist, normalized by the sample rate -- confirm against `dsp`.
        mels = dsp.mel_frequencies(bands, 40, self.sample_rate / 2) / self.sample_rate
        low = random.randrange(bands)
        # `bandwidth` is 0 when `maxwidth * bands < 1`; keep the range
        # non-empty so that `randrange` does not raise and `high == low`
        # (presumably a no-op mask, both filters being identical). For
        # `bandwidth >= 1` this is the same range as before.
        high = random.randrange(low, min(bands, low + max(bandwidth, 1)))
        filters = dsp.LowPassFilters([mels[low], mels[high]]).to(wav.device)
        low, midlow = filters(wav)
        # Band-stop filtering: assuming `low` and `midlow` are `wav` low-passed
        # at the lower and upper cutoffs, `wav - midlow` keeps what is above
        # the upper cutoff and adding `low` restores what is below the lower.
        out = wav - midlow + low
        return out
class Shift(nn.Module):
    """Shift.

    Crops every signal to `length - shift` samples. In training mode the
    crop starts at a random offset in `[0, shift)`; in evaluation mode the
    crop always starts at 0.
    """

    def __init__(self, shift=8192, same=False):
        """__init__.

        :param shift: randomly shifts the signals up to a given factor
        :param same: shifts both clean and noisy files by the same factor
        """
        super().__init__()
        self.same = same
        self.shift = shift

    def forward(self, wav):
        """Return the (randomly) shifted and cropped version of `wav`.

        :param wav: tensor of shape `[sources, batch, channels, length]`.
        :return: `wav` itself when `shift <= 0`, otherwise a tensor whose last
            dimension has size `length - shift`.
        """
        sources, batch, channels, length = wav.shape
        max_shift = self.shift
        if max_shift <= 0:
            return wav

        out_length = length - max_shift
        if not self.training:
            # Deterministic at evaluation time: keep the leading samples.
            return wav[..., :out_length]

        # One start offset per batch item, either shared by all sources
        # (`same=True`) or drawn independently for each source.
        n_draws = 1 if self.same else sources
        starts = th.randint(max_shift, (n_draws, batch, 1, 1), device=wav.device)
        starts = starts.expand(sources, -1, channels, -1)
        # Broadcasting gives, for every signal, the indices
        # start, start + 1, ..., start + out_length - 1.
        positions = th.arange(out_length, device=wav.device) + starts
        return th.gather(wav, 3, positions)