from typing import Optional

import torch
import torchaudio
from torchaudio import transforms as taT, functional as taF

class AudioTrainingPipeline(torch.nn.Module):
    """Training-time audio pipeline: resample, fix the clip duration,
    optionally mix in background noise at a random SNR, convert to a
    dB-scaled mel spectrogram, and apply SpecAugment-style masking."""

    def __init__(self,
            input_freq=16000,
            resample_freq=16000,
            expected_duration=6,
            freq_mask_size=10,
            time_mask_size=80,
            mask_count=2,
            snr_mean=6.0,
            noise_path=None):
        super().__init__()
        self.input_freq = input_freq
        self.snr_mean = snr_mean
        self.mask_count = mask_count
        self.noise = self.get_noise(noise_path)
        self.resample = taT.Resample(input_freq, resample_freq)
        self.preprocess_waveform = WaveformPreprocessing(resample_freq * expected_duration)
        self.audio_to_spectrogram = AudioToSpectrogram(
            sample_rate=resample_freq,
        )
        self.freq_mask = taT.FrequencyMasking(freq_mask_size)
        self.time_mask = taT.TimeMasking(time_mask_size)

    def get_noise(self, path) -> Optional[torch.Tensor]:
        """Load an optional noise clip, downmix it to mono, and resample it
        to the pipeline's input rate."""
        if path is None:
            return None
        noise, sr = torchaudio.load(path)
        if noise.shape[0] > 1:
            noise = noise.mean(0, keepdim=True)
        if sr != self.input_freq:
            noise = taF.resample(noise, sr, self.input_freq)
        return noise

    def add_noise(self, waveform: torch.Tensor) -> torch.Tensor:
        assert self.noise is not None, "Cannot add noise because a noise file was not provided."
        # Tile the noise clip so it covers the waveform, then trim to length.
        num_repeats = waveform.shape[1] // self.noise.shape[1] + 1
        noise = self.noise.repeat(1, num_repeats)[:, :waveform.shape[1]]
        noise_power = noise.norm(p=2)
        signal_power = waveform.norm(p=2)
        # Draw a random SNR in dB around snr_mean, floored at 1 dB.
        snr_db = torch.normal(self.snr_mean, 1.5, (1,)).clamp_min(1.0)
        # Decibels are base 10, not base e: snr = 10^(snr_db / 10).
        snr = 10.0 ** (snr_db / 10)
        scale = snr * noise_power / signal_power
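        # Worked example (illustrative numbers, not from the original code):
        # snr_db = 6 gives snr = 10**(6/10) ~ 3.98, so after scaling the
        # waveform's L2 norm is ~3.98x the noise's before the two are mixed.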
        noisy_waveform = (scale * waveform + noise) / 2
        return noisy_waveform

    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
        # Let resampling errors propagate rather than silently swallowing
        # them and continuing with an unresampled waveform.
        waveform = self.resample(waveform)
        waveform = self.preprocess_waveform(waveform)
        if self.noise is not None:
            waveform = self.add_noise(waveform)
        spec = self.audio_to_spectrogram(waveform)

        # Spectrogram augmentation
        for _ in range(self.mask_count):
            spec = self.freq_mask(spec)
            spec = self.time_mask(spec)
        return spec


class WaveformPreprocessing(torch.nn.Module):
    """Downmix a waveform to mono and pad or trim it to a fixed sample count."""

    def __init__(self, expected_sample_length: int):
        super().__init__()
        self.expected_sample_length = expected_sample_length

    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
        # Downmix extra channels to mono
        if waveform.shape[0] > 1:
            waveform = waveform.mean(0, keepdim=True)

        # Ensure the waveform has the expected number of samples
        waveform = self._rectify_duration(waveform)
        return waveform

    def _rectify_duration(self, waveform: torch.Tensor) -> torch.Tensor:
        expected_samples = self.expected_sample_length
        sample_count = waveform.shape[1]
        if expected_samples == sample_count:
            return waveform
        elif expected_samples > sample_count:
            # Zero-pad short clips on the right
            pad_amount = expected_samples - sample_count
            return torch.nn.functional.pad(waveform, (0, pad_amount), mode="constant", value=0.0)
        else:
            # Trim long clips to the expected length
            return waveform[:, :expected_samples]
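
# Example (illustrative, not in the original file):
# WaveformPreprocessing(expected_sample_length=8) right-pads a (1, 5)
# waveform with zeros to (1, 8) and trims a (1, 10) waveform to (1, 8);
# exact-length inputs pass through unchanged.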


class AudioToSpectrogram(torch.nn.Module):
    """Convert a waveform to a dB-scaled mel spectrogram."""

    def __init__(
        self,
        sample_rate=16000,
    ):
        super().__init__()
        self.spec = taT.MelSpectrogram(sample_rate=sample_rate, n_mels=128, n_fft=1024)  # TODO: Change mels to 64
        self.to_db = taT.AmplitudeToDB()

    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
        spectrogram = self.spec(waveform)
        spectrogram = self.to_db(spectrogram)
        return spectrogram
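
# Minimal usage sketch (an assumption, not part of the original file). Shapes
# follow torchaudio conventions: (channels, samples) in,
# (channels, n_mels, time) out. With the defaults above (16 kHz, 6 s clips,
# n_mels=128, n_fft=1024, hop_length=512) the output is
# torch.Size([1, 128, 188]).
if __name__ == "__main__":
    pipeline = AudioTrainingPipeline()
    waveform = torch.randn(1, 16000 * 3)  # stand-in for a 3-second mono clip
    spec = pipeline(waveform)             # padded to 6 s, then transformed
    print(spec.shape)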