TheComputerMan commited on
Commit
e06adea
1 Parent(s): 385b498

Upload AudioPreprocessor.py

Browse files
Files changed (1) hide show
  1. AudioPreprocessor.py +166 -0
AudioPreprocessor.py ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import librosa
2
+ import librosa.core as lb
3
+ import librosa.display as lbd
4
+ import matplotlib.pyplot as plt
5
+ import numpy
6
+ import numpy as np
7
+ import pyloudnorm as pyln
8
+ import torch
9
+ from torchaudio.transforms import Resample
10
+
11
+
12
+ class AudioPreprocessor:
13
+
14
+ def __init__(self, input_sr, output_sr=None, melspec_buckets=80, hop_length=256, n_fft=1024, cut_silence=False, device="cpu"):
15
+ """
16
+ The parameters are by default set up to do well
17
+ on a 16kHz signal. A different sampling rate may
18
+ require different hop_length and n_fft (e.g.
19
+ doubling frequency --> doubling hop_length and
20
+ doubling n_fft)
21
+ """
22
+ self.cut_silence = cut_silence
23
+ self.device = device
24
+ self.sr = input_sr
25
+ self.new_sr = output_sr
26
+ self.hop_length = hop_length
27
+ self.n_fft = n_fft
28
+ self.mel_buckets = melspec_buckets
29
+ self.meter = pyln.Meter(input_sr)
30
+ self.final_sr = input_sr
31
+ if cut_silence:
32
+ torch.hub._validate_not_a_forked_repo = lambda a, b, c: True # torch 1.9 has a bug in the hub loading, this is a workaround
33
+ # careful: assumes 16kHz or 8kHz audio
34
+ self.silero_model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',
35
+ model='silero_vad',
36
+ force_reload=False,
37
+ onnx=False,
38
+ verbose=False)
39
+ (self.get_speech_timestamps,
40
+ self.save_audio,
41
+ self.read_audio,
42
+ self.VADIterator,
43
+ self.collect_chunks) = utils
44
+ self.silero_model = self.silero_model.to(self.device)
45
+ if output_sr is not None and output_sr != input_sr:
46
+ self.resample = Resample(orig_freq=input_sr, new_freq=output_sr).to(self.device)
47
+ self.final_sr = output_sr
48
+ else:
49
+ self.resample = lambda x: x
50
+
51
+ def cut_silence_from_audio(self, audio):
52
+ """
53
+ https://github.com/snakers4/silero-vad
54
+ """
55
+ return self.collect_chunks(self.get_speech_timestamps(audio, self.silero_model, sampling_rate=self.final_sr), audio)
56
+
57
+ def to_mono(self, x):
58
+ """
59
+ make sure we deal with a 1D array
60
+ """
61
+ if len(x.shape) == 2:
62
+ return lb.to_mono(numpy.transpose(x))
63
+ else:
64
+ return x
65
+
66
+ def normalize_loudness(self, audio):
67
+ """
68
+ normalize the amplitudes according to
69
+ their decibels, so this should turn any
70
+ signal with different magnitudes into
71
+ the same magnitude by analysing loudness
72
+ """
73
+ loudness = self.meter.integrated_loudness(audio)
74
+ loud_normed = pyln.normalize.loudness(audio, loudness, -30.0)
75
+ peak = numpy.amax(numpy.abs(loud_normed))
76
+ peak_normed = numpy.divide(loud_normed, peak)
77
+ return peak_normed
78
+
79
+ def logmelfilterbank(self, audio, sampling_rate, fmin=40, fmax=8000, eps=1e-10):
80
+ """
81
+ Compute log-Mel filterbank
82
+
83
+ one day this could be replaced by torchaudio's internal log10(melspec(audio)), but
84
+ for some reason it gives slightly different results, so in order not to break backwards
85
+ compatibility, this is kept for now. If there is ever a reason to completely re-train
86
+ all models, this would be a good opportunity to make the switch.
87
+ """
88
+ if isinstance(audio, torch.Tensor):
89
+ audio = audio.numpy()
90
+ # get amplitude spectrogram
91
+ x_stft = librosa.stft(audio, n_fft=self.n_fft, hop_length=self.hop_length, win_length=None, window="hann", pad_mode="reflect")
92
+ spc = np.abs(x_stft).T
93
+ # get mel basis
94
+ fmin = 0 if fmin is None else fmin
95
+ fmax = sampling_rate / 2 if fmax is None else fmax
96
+ mel_basis = librosa.filters.mel(sampling_rate, self.n_fft, self.mel_buckets, fmin, fmax)
97
+ # apply log and return
98
+ return torch.Tensor(np.log10(np.maximum(eps, np.dot(spc, mel_basis.T)))).transpose(0, 1)
99
+
100
+ def normalize_audio(self, audio):
101
+ """
102
+ one function to apply them all in an
103
+ order that makes sense.
104
+ """
105
+ audio = self.to_mono(audio)
106
+ audio = self.normalize_loudness(audio)
107
+ audio = torch.Tensor(audio).to(self.device)
108
+ audio = self.resample(audio)
109
+ if self.cut_silence:
110
+ audio = self.cut_silence_from_audio(audio)
111
+ return audio.to("cpu")
112
+
113
+ def visualize_cleaning(self, unclean_audio):
114
+ """
115
+ displays Mel Spectrogram of unclean audio
116
+ and then displays Mel Spectrogram of the
117
+ cleaned version.
118
+ """
119
+ fig, ax = plt.subplots(nrows=2, ncols=1)
120
+ unclean_audio_mono = self.to_mono(unclean_audio)
121
+ unclean_spec = self.audio_to_mel_spec_tensor(unclean_audio_mono, normalize=False).numpy()
122
+ clean_spec = self.audio_to_mel_spec_tensor(unclean_audio_mono, normalize=True).numpy()
123
+ lbd.specshow(unclean_spec, sr=self.sr, cmap='GnBu', y_axis='mel', ax=ax[0], x_axis='time')
124
+ ax[0].set(title='Uncleaned Audio')
125
+ ax[0].label_outer()
126
+ if self.new_sr is not None:
127
+ lbd.specshow(clean_spec, sr=self.new_sr, cmap='GnBu', y_axis='mel', ax=ax[1], x_axis='time')
128
+ else:
129
+ lbd.specshow(clean_spec, sr=self.sr, cmap='GnBu', y_axis='mel', ax=ax[1], x_axis='time')
130
+ ax[1].set(title='Cleaned Audio')
131
+ ax[1].label_outer()
132
+ plt.show()
133
+
134
+ def audio_to_wave_tensor(self, audio, normalize=True):
135
+ if normalize:
136
+ return self.normalize_audio(audio)
137
+ else:
138
+ if isinstance(audio, torch.Tensor):
139
+ return audio
140
+ else:
141
+ return torch.Tensor(audio)
142
+
143
+ def audio_to_mel_spec_tensor(self, audio, normalize=True, explicit_sampling_rate=None):
144
+ """
145
+ explicit_sampling_rate is for when
146
+ normalization has already been applied
147
+ and that included resampling. No way
148
+ to detect the current sr of the incoming
149
+ audio
150
+ """
151
+ if explicit_sampling_rate is None:
152
+ if normalize:
153
+ audio = self.normalize_audio(audio)
154
+ return self.logmelfilterbank(audio=audio, sampling_rate=self.final_sr)
155
+ return self.logmelfilterbank(audio=audio, sampling_rate=self.sr)
156
+ if normalize:
157
+ audio = self.normalize_audio(audio)
158
+ return self.logmelfilterbank(audio=audio, sampling_rate=explicit_sampling_rate)
159
+
160
+
161
+ if __name__ == '__main__':
162
+ import soundfile
163
+
164
+ wav, sr = soundfile.read("../audios/test.wav")
165
+ ap = AudioPreprocessor(input_sr=sr, output_sr=16000)
166
+ ap.visualize_cleaning(wav)