import numpy
import pyloudnorm as pyln
import torch
from torchaudio.transforms import MelSpectrogram
from torchaudio.transforms import Resample


class AudioPreprocessor:
    """
    Bundles the waveform clean-up steps (loudness normalization,
    resampling, trimming of leading/trailing silence) and the
    conversion of a waveform into a log-mel spectrogram.
    """

    def __init__(self, input_sr, output_sr=None, cut_silence=False, do_loudnorm=False, device="cpu"):
        """
        The parameters are by default set up to do well
        on a 16kHz signal. A different sampling rate may
        require different hop_length and n_fft (e.g.
        doubling frequency --> doubling hop_length and
        doubling n_fft)

        input_sr:    sampling rate of the incoming audio
        output_sr:   sampling rate to resample to; None keeps input_sr
        cut_silence: if True, the silero VAD model is loaded via torch.hub
                     and normalize_audio() trims leading/trailing silence
        do_loudnorm: if True, normalize_audio() applies loudness normalization
        device:      device the torch modules and created tensors live on
        """
        self.cut_silence = cut_silence
        self.do_loudnorm = do_loudnorm
        self.device = device
        self.input_sr = input_sr
        self.output_sr = output_sr
        self.meter = pyln.Meter(input_sr)
        # final_sr is the sampling rate the audio has after self.resample,
        # which is also the rate the spectrogram extractor is built for
        self.final_sr = input_sr
        self.wave_to_spectrogram = LogMelSpec(output_sr if output_sr is not None else input_sr).to(device)
        if cut_silence:
            torch.hub._validate_not_a_forked_repo = lambda a, b, c: True  # torch 1.9 has a bug in the hub loading, this is a workaround
            # careful: assumes 16kHz or 8kHz audio
            self.silero_model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',
                                                      model='silero_vad',
                                                      force_reload=False,
                                                      onnx=False,
                                                      verbose=False)
            (self.get_speech_timestamps,
             self.save_audio,
             self.read_audio,
             self.VADIterator,
             self.collect_chunks) = utils
            torch.set_grad_enabled(True)  # finding this issue was very infuriating: silero sets
            # this to false globally during model loading rather than using inference mode or no_grad
            self.silero_model = self.silero_model.to(self.device)
        if output_sr is not None and output_sr != input_sr:
            self.resample = Resample(orig_freq=input_sr, new_freq=output_sr).to(self.device)
            self.final_sr = output_sr
        else:
            # nothing to resample, so resampling is the identity
            self.resample = lambda x: x

    def cut_leading_and_trailing_silence(self, audio):
        """
        https://github.com/snakers4/silero-vad

        Returns the slice of the audio from the start of
        the first detected speech segment to the end of
        the last one. If no speech segment is detected,
        a notice is printed and the audio is returned
        unchanged.
        """
        with torch.inference_mode():
            speech_timestamps = self.get_speech_timestamps(audio, self.silero_model, sampling_rate=self.final_sr)
        if not speech_timestamps:
            # no speech segment found, there is nothing to anchor the cut on
            print("Audio might be too short to cut silences from front and back.")
            return audio
        return audio[speech_timestamps[0]['start']:speech_timestamps[-1]['end']]

    def normalize_loudness(self, audio):
        """
        normalize the amplitudes according to
        their decibels, so this should turn any
        signal with different magnitudes into
        the same magnitude by analysing loudness

        The audio is brought to a loudness of -30.0
        and then scaled such that its absolute peak
        is 1. Audio whose loudness cannot be measured
        is returned unchanged.
        """
        try:
            loudness = self.meter.integrated_loudness(audio)
        except ValueError:
            # if the audio is too short, a value error will arise
            return audio
        if not numpy.isfinite(loudness):
            # a non-finite loudness (e.g. for a silent signal) cannot be
            # turned into a meaningful gain and would only produce NaN/inf
            return audio
        loud_normed = pyln.normalize.loudness(audio, loudness, -30.0)
        peak = numpy.amax(numpy.abs(loud_normed))
        if peak == 0:
            # all samples are zero, dividing by the peak would yield NaN
            return loud_normed
        return numpy.divide(loud_normed, peak)

    def normalize_audio(self, audio):
        """
        one function to apply them all in an
        order that makes sense.

        Order: loudness normalization (optional) -->
        conversion to a float32 tensor on self.device -->
        resampling --> silence trimming (optional).
        """
        if self.do_loudnorm:
            audio = self.normalize_loudness(audio)
        audio = torch.tensor(audio, device=self.device, dtype=torch.float32)
        audio = self.resample(audio)
        if self.cut_silence:
            audio = self.cut_leading_and_trailing_silence(audio)
        return audio

    def audio_to_mel_spec_tensor(self, audio, normalize=False, explicit_sampling_rate=None):
        """
        explicit_sampling_rate is for when
        normalization has already been applied
        and that included resampling. No way
        to detect the current input_sr of the incoming
        audio

        If explicit_sampling_rate is None, the audio
        is NOT resampled and goes into the spectrogram
        extraction as it is. The normalize argument is
        accepted, but not used by this method.
        """
        if not isinstance(audio, torch.Tensor):
            audio = torch.tensor(audio, device=self.device)
        # final_sr equals output_sr whenever an output_sr was configured and falls
        # back to the native rate otherwise, so it is always a valid target rate
        if explicit_sampling_rate is None or explicit_sampling_rate == self.final_sr:
            return self.wave_to_spectrogram(audio.float())
        if explicit_sampling_rate != self.input_sr:
            print("WARNING: different sampling rate used, this will be very slow if it happens often. Consider creating a dedicated audio processor.")
            # NOTE: this replaces the resampler and input_sr of the whole object,
            # so normalize_audio() will also use the new resampler from now on
            self.resample = Resample(orig_freq=explicit_sampling_rate, new_freq=self.final_sr).to(self.device)
            self.input_sr = explicit_sampling_rate
        audio = self.resample(audio.float())
        return self.wave_to_spectrogram(audio)


class LogMelSpec(torch.nn.Module):
    """
    Module that maps a waveform to the base-10 logarithm
    of its mel power spectrogram.
    """

    def __init__(self, sr, *args, **kwargs):
        """
        sr is the sampling rate of the audio that will be
        passed to forward(). It also determines the upper
        edge of the mel filterbank (sr // 2). Further
        arguments are handed on to torch.nn.Module.
        """
        super().__init__(*args, **kwargs)
        mel_settings = {
            "sample_rate": sr,
            "n_fft": 1024,
            "win_length": 1024,
            "hop_length": 256,
            "f_min": 40.0,
            "f_max": sr // 2,
            "pad": 0,
            "n_mels": 128,
            "power": 2.0,
            "normalized": False,
            "center": True,
            "pad_mode": 'reflect',
            "mel_scale": 'htk',
        }
        self.spec = MelSpectrogram(**mel_settings)

    def forward(self, audio):
        """
        Compute the log10 mel spectrogram of the audio.
        Bins that are exactly zero are set to 1e-8 before
        the logarithm, so the result contains no -inf.
        """
        power_mel = self.spec(audio.float())
        floored = power_mel.masked_fill(power_mel == 0, 1e-8)
        return torch.log10(floored)


if __name__ == '__main__':
    # Small visual smoke test: load one audio file, turn it into a
    # log-mel spectrogram and display it.
    import soundfile

    wav, sr = soundfile.read("../audios/ad00_0004.wav")
    ap = AudioPreprocessor(input_sr=sr, output_sr=16000, cut_silence=True)
    # The sampling rate of the file has to be passed explicitly. Without it,
    # audio_to_mel_spec_tensor skips the resampling and treats the waveform
    # as if it already were at 16kHz.
    mel_spec = ap.audio_to_mel_spec_tensor(wav, explicit_sampling_rate=sr)
    import matplotlib.pyplot as plt

    fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(9, 6))
    import librosa.display as lbd

    lbd.specshow(mel_spec.cpu().numpy(),
                 ax=ax,
                 sr=16000,
                 cmap='GnBu',
                 y_axis='features',
                 x_axis=None,
                 hop_length=256)
    plt.show()