Spaces:
Runtime error
Runtime error
File size: 3,947 Bytes
4b8361a 0030bc6 4b8361a 0030bc6 b6800ef 0030bc6 b6800ef 0030bc6 4b8361a b6800ef 0030bc6 4b8361a 0030bc6 b6800ef 0030bc6 4b8361a 0030bc6 b6800ef 0030bc6 4b8361a 0030bc6 4b8361a 0030bc6 4b8361a 0030bc6 4b8361a 0030bc6 4b8361a 0030bc6 4b8361a 0030bc6 4b8361a 0030bc6 4b8361a 0030bc6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 |
import torch
import torchaudio
from torchaudio import transforms as taT, functional as taF
import torch.nn as nn
class AudioTrainingPipeline(torch.nn.Module):
    """Turn a raw waveform into an augmented, dB-scaled mel spectrogram.

    Steps, in order: resample, down-mix to mono and pad/trim to a fixed
    length, optionally mix in background noise, compute a dB-scaled mel
    spectrogram, then apply ``mask_count`` rounds of frequency and time
    masking.
    """

    def __init__(
        self,
        input_freq=16000,
        resample_freq=16000,
        expected_duration=6,
        freq_mask_size=10,
        time_mask_size=80,
        mask_count=2,
        snr_mean=6.0,
        noise_path=None,
    ):
        """
        Args:
            input_freq: sample rate (Hz) of waveforms passed to ``forward``.
            resample_freq: sample rate (Hz) the pipeline works at after
                resampling; the noise clip is also resampled to this rate.
            expected_duration: clip length in seconds; waveforms are padded
                or trimmed to ``resample_freq * expected_duration`` samples
                (fractional durations are truncated to whole samples).
            freq_mask_size: ``freq_mask_param`` for ``taT.FrequencyMasking``.
            time_mask_size: ``time_mask_param`` for ``taT.TimeMasking``.
            mask_count: number of frequency+time masking rounds per call.
            snr_mean: mean of the random SNR (dB) used by ``add_noise``.
            noise_path: optional audio file used as background noise;
                ``None`` disables noise augmentation.
        """
        super().__init__()
        self.input_freq = input_freq
        self.resample_freq = resample_freq
        self.snr_mean = snr_mean
        self.mask_count = mask_count
        # Non-persistent buffer: it follows .to(device) with the module but
        # stays out of state_dict, so existing checkpoints still load.
        self.register_buffer("noise", self.get_noise(noise_path), persistent=False)
        self.resample = taT.Resample(input_freq, resample_freq)
        self.preprocess_waveform = WaveformPreprocessing(
            int(resample_freq * expected_duration)
        )
        self.audio_to_spectrogram = AudioToSpectrogram(
            sample_rate=resample_freq,
        )
        self.freq_mask = taT.FrequencyMasking(freq_mask_size)
        self.time_mask = taT.TimeMasking(time_mask_size)

    def get_noise(self, path) -> torch.Tensor:
        """Load a noise clip as mono audio at ``self.resample_freq``.

        Returns ``None`` when ``path`` is ``None``.
        """
        if path is None:
            return None
        noise, sr = torchaudio.load(path)
        # Down-mix multi-channel noise to a single channel.
        if noise.shape[0] > 1:
            noise = noise.mean(0, keepdim=True)
        # Noise is mixed in *after* the waveform has been resampled, so it must
        # match the working rate (resample_freq), not the input rate.
        if sr != self.resample_freq:
            noise = taF.resample(noise, sr, self.resample_freq)
        return noise

    def add_noise(self, waveform: torch.Tensor) -> torch.Tensor:
        """Mix the loaded noise clip into ``waveform`` at a random SNR.

        The SNR in dB is drawn from N(snr_mean, 1.5) and clamped to at least
        1 dB. The signal is rescaled so its L2 norm is ``10 ** (snr_db / 20)``
        times the L2 norm of the noise segment; the two are then summed and
        halved. The output level therefore follows the noise level.

        Args:
            waveform: audio indexed as (channels, time); presumably mono
                (1, time) as produced by WaveformPreprocessing — confirm for
                other callers.

        Returns:
            The noisy waveform, same shape as ``waveform``.

        Raises:
            RuntimeError: if no noise file was provided.
        """
        if self.noise is None:
            raise RuntimeError("Cannot add noise because a noise file was not provided.")
        sample_count = waveform.shape[1]
        # Tile the noise clip until it covers the whole waveform, then trim.
        num_repeats = sample_count // self.noise.shape[1] + 1
        noise = self.noise.repeat(1, num_repeats)[:, :sample_count]
        noise_norm = noise.norm(p=2)
        # A silent clip has norm 0; without the floor the scale would be inf
        # and inf * 0 would turn the whole output into NaN.
        signal_norm = waveform.norm(p=2).clamp_min(1e-8)
        # .item() keeps the draw on the CPU RNG and yields a plain float, so the
        # scale works on any device.
        snr_db = torch.normal(self.snr_mean, 1.5, (1,)).clamp_min(1.0).item()
        # L2 norms are amplitude quantities, so dB -> linear uses a factor of 20.
        amplitude_ratio = 10 ** (snr_db / 20)
        scale = amplitude_ratio * noise_norm / signal_norm
        return (scale * waveform + noise) / 2

    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
        """Run the full augmentation pipeline and return the spectrogram.

        Args:
            waveform: audio at ``input_freq``, indexed as (channels, time).

        Returns:
            The masked, dB-scaled mel spectrogram.
        """
        # Resampling errors propagate. Continuing with audio at the wrong rate
        # would give clips of the wrong duration and rate-mismatched noise.
        waveform = self.resample(waveform)
        waveform = self.preprocess_waveform(waveform)
        if self.noise is not None:
            waveform = self.add_noise(waveform)
        spec = self.audio_to_spectrogram(waveform)
        # Spectrogram augmentation: mask_count rounds of freq + time masking.
        for _ in range(self.mask_count):
            spec = self.freq_mask(spec)
            spec = self.time_mask(spec)
        return spec
class WaveformPreprocessing(torch.nn.Module):
    """Down-mix a (channels, time) waveform to mono and fit it to a fixed length."""

    def __init__(self, expected_sample_length: int):
        """
        Args:
            expected_sample_length: number of samples every output waveform has.
        """
        super().__init__()
        self.expected_sample_length = expected_sample_length

    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
        """Average channels (dim 0) when there are several, then pad or trim."""
        mono = waveform.mean(0, keepdim=True) if waveform.shape[0] > 1 else waveform
        return self._rectify_duration(mono)

    def _rectify_duration(self, waveform: torch.Tensor):
        """Right-pad with zeros or truncate along dim 1 to ``expected_sample_length``."""
        target = self.expected_sample_length
        shortfall = target - waveform.shape[1]
        if shortfall > 0:
            # Too short: append silence at the end.
            return torch.nn.functional.pad(waveform, (0, shortfall), mode="constant", value=0.0)
        if shortfall < 0:
            # Too long: keep only the leading samples.
            return waveform[:, :target]
        return waveform
class AudioToSpectrogram(torch.nn.Module):
    """Convert a waveform into a dB-scaled mel spectrogram."""

    def __init__(
        self,
        sample_rate=16000,
        n_mels=128,
        n_fft=1024,
    ):
        """
        Args:
            sample_rate: sample rate (Hz) of the incoming waveform.
            n_mels: number of mel filterbanks. The default is kept at 128 for
                compatibility; TODO: evaluate 64.
            n_fft: FFT size used by the underlying STFT.
        """
        super().__init__()
        self.spec = taT.MelSpectrogram(sample_rate=sample_rate, n_mels=n_mels, n_fft=n_fft)
        self.to_db = taT.AmplitudeToDB()

    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
        """Return the mel spectrogram of ``waveform`` converted to decibels."""
        spectrogram = self.spec(waveform)
        return self.to_db(spectrogram)