# Last commit by Flux9665 (9e275b8): "use explicit code instead of relying on release download"
# (header residue from the hosting page: raw | history blame | 6.45 kB)
import numpy
import pyloudnorm as pyln
import torch
from torchaudio.transforms import MelSpectrogram
from torchaudio.transforms import Resample
class AudioPreprocessor:
    """
    Waveform preprocessing: optional loudness/peak normalization, resampling
    to a target rate, optional trimming of leading/trailing silence with the
    silero VAD, and conversion to log-mel spectrograms.
    """

    def __init__(self, input_sr, output_sr=None, cut_silence=False, do_loudnorm=False, device="cpu"):
        """
        The parameters are by default set up to do well
        on a 16kHz signal. A different sampling rate may
        require different hop_length and n_fft (e.g.
        doubling frequency --> doubling hop_length and
        doubling n_fft)

        Args:
            input_sr: sampling rate of the waveforms given to normalize_audio.
            output_sr: rate to resample to; None keeps input_sr.
            cut_silence: trim leading/trailing non-speech with the silero VAD,
                which is fetched through torch.hub.
            do_loudnorm: run normalize_loudness before resampling.
            device: torch device for the resampler, spectrogram and VAD model.
        """
        self.cut_silence = cut_silence
        self.do_loudnorm = do_loudnorm
        self.device = device
        self.input_sr = input_sr
        self.output_sr = output_sr
        self.meter = pyln.Meter(input_sr)
        self.final_sr = input_sr
        self.wave_to_spectrogram = LogMelSpec(output_sr if output_sr is not None else input_sr).to(device)
        # resamplers for audio_to_mel_spec_tensor inputs at rates other than input_sr,
        # keyed by that rate, so the shared input_sr -> final_sr resampler is never replaced
        self._explicit_resamplers = {}
        if cut_silence:
            torch.hub._validate_not_a_forked_repo = lambda a, b, c: True  # torch 1.9 has a bug in the hub loading, this is a workaround
            # careful: assumes 16kHz or 8kHz audio
            self.silero_model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',
                                                      model='silero_vad',
                                                      force_reload=False,
                                                      onnx=False,
                                                      verbose=False)
            (self.get_speech_timestamps,
             self.save_audio,
             self.read_audio,
             self.VADIterator,
             self.collect_chunks) = utils
            torch.set_grad_enabled(True)  # finding this issue was very infuriating: silero sets
            # this to false globally during model loading rather than using inference mode or no_grad
            self.silero_model = self.silero_model.to(self.device)
        if output_sr is not None and output_sr != input_sr:
            self.resample = Resample(orig_freq=input_sr, new_freq=output_sr).to(self.device)
            self.final_sr = output_sr
        else:
            self.resample = lambda x: x

    def _to_float_tensor(self, audio):
        """
        Return audio as a float32 tensor on self.device.

        Tensors are converted with .to() instead of torch.tensor(), which would
        copy them and emit a UserWarning.
        """
        if isinstance(audio, torch.Tensor):
            return audio.to(device=self.device, dtype=torch.float32)
        return torch.tensor(audio, device=self.device, dtype=torch.float32)

    def cut_leading_and_trailing_silence(self, audio):
        """
        https://github.com/snakers4/silero-vad

        Keep only the span from the start of the first detected speech segment
        to the end of the last one. If no speech segment is detected, the
        audio is returned unchanged.
        """
        with torch.inference_mode():
            speech_timestamps = self.get_speech_timestamps(audio, self.silero_model, sampling_rate=self.final_sr)
        if not speech_timestamps:
            print("Audio might be too short to cut silences from front and back.")
            return audio
        return audio[speech_timestamps[0]['start']:speech_timestamps[-1]['end']]

    def normalize_loudness(self, audio):
        """
        normalize the amplitudes according to
        their decibels, so this should turn any
        signal with different magnitudes into
        the same magnitude by analysing loudness

        The input is returned unchanged when its loudness cannot be measured
        (too short) or is not finite (e.g. digital silence), instead of
        producing NaNs.
        """
        try:
            loudness = self.meter.integrated_loudness(audio)
        except ValueError:
            # if the audio is too short, a value error will arise
            return audio
        if not numpy.isfinite(loudness):
            # e.g. silent input measures -inf LUFS; the gain towards -30 LUFS would be
            # infinite and 0 * inf turns every sample into NaN
            return audio
        loud_normed = pyln.normalize.loudness(audio, loudness, -30.0)
        # NOTE(review): pyln.normalize.loudness applies one scalar gain, so the peak
        # normalization below overrides the -30 LUFS target (net effect: peak == 1.0)
        peak = numpy.amax(numpy.abs(loud_normed))
        if peak == 0:
            return loud_normed
        return numpy.divide(loud_normed, peak)

    def normalize_audio(self, audio):
        """
        one function to apply them all in an
        order that makes sense.

        Loudness normalization (if enabled), resampling from input_sr to
        final_sr, then silence trimming (if enabled). Returns a float32 tensor
        on self.device.
        """
        if self.do_loudnorm:
            if isinstance(audio, torch.Tensor):
                # pyloudnorm validates that its input is a numpy array; a tensor would be
                # rejected with a ValueError and loudness normalization silently skipped
                audio = audio.detach().cpu().numpy()
            audio = self.normalize_loudness(audio)
        audio = self._to_float_tensor(audio)
        audio = self.resample(audio)
        if self.cut_silence:
            audio = self.cut_leading_and_trailing_silence(audio)
        return audio

    def audio_to_mel_spec_tensor(self, audio, normalize=False, explicit_sampling_rate=None):
        """
        Convert a waveform to a log-mel spectrogram computed at final_sr.

        explicit_sampling_rate is for when
        normalization has already been applied
        and that included resampling. No way
        to detect the current input_sr of the incoming
        audio

        Args:
            audio: waveform as a tensor or anything torch.tensor accepts.
            normalize: if True, run normalize_audio first; the audio is then
                assumed to be at input_sr and explicit_sampling_rate is ignored.
            explicit_sampling_rate: sampling rate of audio. None means it is
                already at final_sr; any other rate is resampled to final_sr.
        Returns:
            the output of self.wave_to_spectrogram.
        """
        if normalize:
            # normalize_audio already resamples to final_sr
            return self.wave_to_spectrogram(self.normalize_audio(audio))
        audio = self._to_float_tensor(audio)
        if explicit_sampling_rate is None or explicit_sampling_rate == self.final_sr:
            return self.wave_to_spectrogram(audio)
        if explicit_sampling_rate == self.input_sr:
            resample = self.resample
        else:
            resample = self._explicit_resamplers.get(explicit_sampling_rate)
            if resample is None:
                print("WARNING: different sampling rate used, this will be very slow if it happens often. Consider creating a dedicated audio processor.")
                # cached per rate; self.resample / self.input_sr stay untouched so that
                # later normalize_audio calls and self.meter keep using the original input_sr
                resample = Resample(orig_freq=explicit_sampling_rate, new_freq=self.final_sr).to(self.device)
                self._explicit_resamplers[explicit_sampling_rate] = resample
        return self.wave_to_spectrogram(resample(audio))
class LogMelSpec(torch.nn.Module):
    """
    Log10 mel power spectrogram (HTK mel scale, f_min .. Nyquist).

    The defaults (1024-point FFT/window, hop 256, 128 mel bands, f_min 40 Hz)
    reproduce the previously hard-coded configuration; they can be scaled for
    other sampling rates.
    """

    def __init__(self, sr, *args, n_fft=1024, hop_length=256, n_mels=128, f_min=40.0, **kwargs):
        """
        Args:
            sr: sampling rate of the waveforms that will be passed to forward.
            n_fft: FFT size, also used as the window length.
            hop_length: number of samples between successive frames.
            n_mels: number of mel bands.
            f_min: lowest frequency (Hz) of the mel filterbank; the highest is sr // 2.
            *args, **kwargs: forwarded to torch.nn.Module.__init__.
        """
        super().__init__(*args, **kwargs)
        self.spec = MelSpectrogram(sample_rate=sr,
                                   n_fft=n_fft,
                                   win_length=n_fft,
                                   hop_length=hop_length,
                                   f_min=f_min,
                                   f_max=sr // 2,
                                   pad=0,
                                   n_mels=n_mels,
                                   power=2.0,
                                   normalized=False,
                                   center=True,
                                   pad_mode='reflect',
                                   mel_scale='htk')

    def forward(self, audio):
        """
        Args:
            audio: waveform tensor; it is cast to float32 before the transform.
        Returns:
            log10 of the mel power spectrogram.
        """
        melspec = self.spec(audio.float())
        # exact zeros would give -inf under log10; replace them with a tiny positive value
        melspec[melspec == 0] = 1e-8
        return torch.log10(melspec)
if __name__ == '__main__':
    # Demo: preprocess one recording and plot its log-mel spectrogram.
    import librosa.display as lbd
    import matplotlib.pyplot as plt
    import soundfile

    wav, sr = soundfile.read("../audios/ad00_0004.wav")
    if wav.ndim > 1:
        # soundfile returns (frames, channels) for multi-channel files, but the resampler
        # works on the last axis and the VAD presumably expects mono; mix down first
        wav = wav.mean(axis=1)
    ap = AudioPreprocessor(input_sr=sr, output_sr=16000, cut_silence=True)
    # resample to 16 kHz and trim silence first; the spectrogram (and the sr/hop_length
    # given to specshow) must describe the 16 kHz signal, not the raw wave at its native rate
    norm_wave = ap.normalize_audio(wav)
    spectrogram = ap.audio_to_mel_spec_tensor(norm_wave, explicit_sampling_rate=16000)
    fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(9, 6))
    lbd.specshow(spectrogram.cpu().numpy(),
                 ax=ax,
                 sr=16000,
                 cmap='GnBu',
                 y_axis='features',
                 x_axis=None,
                 hop_length=256)
    plt.show()