import numpy
import pyloudnorm as pyln
import torch
from torchaudio.transforms import MelSpectrogram
from torchaudio.transforms import Resample


class AudioPreprocessor:

    def __init__(self, input_sr, output_sr=None, cut_silence=False, do_loudnorm=False, device="cpu"):
        """
        The parameters are by default set up to do well
        on a 16kHz signal. A different sampling rate may
        require different hop_length and n_fft (e.g.
        doubling frequency --> doubling hop_length and
        doubling n_fft)
        """
        self.cut_silence = cut_silence
        self.do_loudnorm = do_loudnorm
        self.device = device
        self.input_sr = input_sr
        self.output_sr = output_sr
        self.meter = pyln.Meter(input_sr)
        self.final_sr = input_sr
        self.wave_to_spectrogram = LogMelSpec(output_sr if output_sr is not None else input_sr).to(device)
        if cut_silence:
            torch.hub._validate_not_a_forked_repo = lambda a, b, c: True  # torch 1.9 has a bug in the hub loading, this is a workaround
            # careful: assumes 16kHz or 8kHz audio
            self.silero_model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',
                                                      model='silero_vad',
                                                      force_reload=False,
                                                      onnx=False,
                                                      verbose=False)
            (self.get_speech_timestamps,
             self.save_audio,
             self.read_audio,
             self.VADIterator,
             self.collect_chunks) = utils
            torch.set_grad_enabled(True)  # finding this issue was very infuriating: silero sets
            # this to false globally during model loading rather than using inference mode or no_grad
            self.silero_model = self.silero_model.to(self.device)
        if output_sr is not None and output_sr != input_sr:
            self.resample = Resample(orig_freq=input_sr, new_freq=output_sr).to(self.device)
            self.final_sr = output_sr
        else:
            self.resample = lambda x: x

    def cut_leading_and_trailing_silence(self, audio):
        """
        https://github.com/snakers4/silero-vad
        """
        with torch.inference_mode():
            speech_timestamps = self.get_speech_timestamps(audio, self.silero_model, sampling_rate=self.final_sr)
        try:
            result = audio[speech_timestamps[0]['start']:speech_timestamps[-1]['end']]
            return result
        except IndexError:
            print("Audio might be too short to cut silences from front and back.")
        return audio

    def normalize_loudness(self, audio):
        """
        normalize the amplitudes according to
        their decibels, so this should turn any
        signal with different magnitudes into
        the same magnitude by analysing loudness
        """
        try:
            loudness = self.meter.integrated_loudness(audio)
        except ValueError:
            # if the audio is too short, a value error will arise
            return audio
        loud_normed = pyln.normalize.loudness(audio, loudness, -30.0)
        peak = numpy.amax(numpy.abs(loud_normed))
        peak_normed = numpy.divide(loud_normed, peak)
        return peak_normed

    def normalize_audio(self, audio):
        """
        one function to apply them all in an
        order that makes sense.
        """
        if self.do_loudnorm:
            audio = self.normalize_loudness(audio)
        audio = torch.tensor(audio, device=self.device, dtype=torch.float32)
        audio = self.resample(audio)
        if self.cut_silence:
            audio = self.cut_leading_and_trailing_silence(audio)
        return audio

    def audio_to_mel_spec_tensor(self, audio, normalize=False, explicit_sampling_rate=None):
        """
        explicit_sampling_rate is for when
        normalization has already been applied
        and that included resampling. No way
        to detect the current input_sr of the incoming
        audio
        """
        if type(audio) != torch.tensor and type(audio) != torch.Tensor:
            audio = torch.tensor(audio, device=self.device)
        if explicit_sampling_rate is None or explicit_sampling_rate == self.output_sr:
            return self.wave_to_spectrogram(audio.float())
        else:
            if explicit_sampling_rate != self.input_sr:
                print("WARNING: different sampling rate used, this will be very slow if it happens often. Consider creating a dedicated audio processor.")
                self.resample = Resample(orig_freq=explicit_sampling_rate, new_freq=self.output_sr).to(self.device)
                self.input_sr = explicit_sampling_rate
            audio = self.resample(audio.float())
            return self.wave_to_spectrogram(audio)
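
# Illustration of the explicit_sampling_rate path above (hypothetical values, not
# from the original file): a processor built with input_sr=22050, output_sr=16000
# can still consume a stray 44.1kHz file via
#     ap.audio_to_mel_spec_tensor(wav, explicit_sampling_rate=44100)
# in which case the resampler is rebuilt for 44100 -> 16000 and cached, so only
# the first such call pays the rebuild cost.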


class LogMelSpec(torch.nn.Module):
    def __init__(self, sr, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.spec = MelSpectrogram(sample_rate=sr,
                                   n_fft=1024,
                                   win_length=1024,
                                   hop_length=256,
                                   f_min=40.0,
                                   f_max=sr // 2,
                                   pad=0,
                                   n_mels=128,
                                   power=2.0,
                                   normalized=False,
                                   center=True,
                                   pad_mode='reflect',
                                   mel_scale='htk')

    def forward(self, audio):
        melspec = self.spec(audio.float())
        zero_mask = melspec == 0
        melspec[zero_mask] = 1e-8  # floor exact zeros so log10 doesn't produce -inf
        logmelspec = torch.log10(melspec)
        return logmelspec
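
# A possible alternative for the log step (an assumption, not what this file uses):
# torchaudio.transforms.AmplitudeToDB(stype="power") computes 10 * log10(x) with
# optional top_db clamping, i.e. the same quantity up to a factor of 10; the manual
# floor-and-log10 above keeps the raw log10 scale and explicit control of the floor.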


if __name__ == '__main__':
    import librosa.display as lbd
    import matplotlib.pyplot as plt
    import soundfile

    wav, sr = soundfile.read("../audios/ad00_0004.wav")
    ap = AudioPreprocessor(input_sr=sr, output_sr=16000, cut_silence=True)

    fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(9, 6))
    # pass the native sampling rate so the audio is resampled to 16kHz before
    # the spectrogram parameters (which are tuned for 16kHz) are applied
    lbd.specshow(ap.audio_to_mel_spec_tensor(wav, explicit_sampling_rate=sr).cpu().numpy(),
                 ax=ax,
                 sr=16000,
                 cmap='GnBu',
                 y_axis='mel',
                 x_axis=None,
                 hop_length=256)
    plt.show()
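
    # Optional extension of the demo (same assumption as above: the test file
    # exists): run the full normalization chain first, so that resampling to 16kHz
    # and VAD-based silence trimming are actually applied before the spectrogram.
    normed = ap.normalize_audio(wav)
    fig2, ax2 = plt.subplots(nrows=1, ncols=1, figsize=(9, 6))
    lbd.specshow(ap.audio_to_mel_spec_tensor(normed, explicit_sampling_rate=16000).cpu().numpy(),
                 ax=ax2,
                 sr=16000,
                 cmap='GnBu',
                 y_axis='mel',
                 x_axis=None,
                 hop_length=256)
    plt.show()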