from vocoder.models.fatchord_version import WaveRNN
from vocoder import hparams as hp
from scipy.fft import rfft, rfftfreq
from scipy import signal
from denoiser.pretrained import master64
import librosa
import numpy as np
import torch
import torchaudio
import noisereduce as nr

_model = None   # type: WaveRNN
_device = None  # set by load_model()


def load_model(weights_fpath, verbose=True):
    global _model, _device

    if verbose:
        print("Building Wave-RNN")
    # Build the WaveRNN graph from the vocoder hyperparameters
    _model = WaveRNN(
        rnn_dims=hp.voc_rnn_dims,
        fc_dims=hp.voc_fc_dims,
        bits=hp.bits,
        pad=hp.voc_pad,
        upsample_factors=hp.voc_upsample_factors,
        feat_dims=hp.num_mels,
        compute_dims=hp.voc_compute_dims,
        res_out_dims=hp.voc_res_out_dims,
        res_blocks=hp.voc_res_blocks,
        hop_length=hp.hop_length,
        sample_rate=hp.sample_rate,
        mode=hp.voc_mode
    )

    # Run on GPU when available
    if torch.cuda.is_available():
        _model = _model.cuda()
        _device = torch.device('cuda')
    else:
        _device = torch.device('cpu')

    if verbose:
        print("Loading model weights at %s" % weights_fpath)
    checkpoint = torch.load(weights_fpath, map_location=_device)
    _model.load_state_dict(checkpoint['model_state'])
    _model.eval()


def is_loaded():
    return _model is not None
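

# Usage sketch (illustrative; the checkpoint path below is hypothetical):
# build the vocoder once at startup, then reuse it for every inference.
#
#   if not is_loaded():
#       load_model("saved_models/vocoder.pt")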


def infer_waveform(mel, normalize=True, batched=True, target=8000, overlap=800,
                   progress_callback=None, crossfade=True):
    """
    Infers the waveform of a mel spectrogram output by the synthesizer (the format must match
    that of the synthesizer!).

    :param mel: mel spectrogram as a numpy array of shape (num_mels, frames)
    :param normalize: rescale the mel by hp.mel_max_abs_value before inference
    :param batched: fold the mel into batches for faster generation, at a slight quality cost
    :param target: number of samples generated per batch entry
    :param overlap: number of samples shared between consecutive batch entries
    :param progress_callback: optional callable receiving generation progress updates
    :param crossfade: blend the overlapping samples to hide batch seams
    :return: the generated waveform as a 1-D numpy array, after denoising
    """
    if _model is None:
        raise Exception("Please load Wave-RNN in memory before using it")

    if normalize:
        mel = mel / hp.mel_max_abs_value
    mel = torch.from_numpy(mel[None, ...])
    wav = _model.generate(mel, batched, target, overlap, hp.mu_law,
                          progress_callback, crossfade=crossfade)
    wav = waveform_denoising(wav)
    return wav
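

# Usage sketch (illustrative; "synthesized_mel.npy" is a hypothetical file
# holding a (num_mels, frames) mel produced by the matching synthesizer):
#
#   mel = np.load("synthesized_mel.npy")
#   wav = infer_waveform(mel, batched=True, target=8000, overlap=800)
#   # wav is a 1-D numpy array at hp.sample_rate, already denoised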


def waveform_denoising(wav):
    # hp.sex selects between the low- and high-frequency noise-reduction presets
    prop_decrease = hp.prop_decrease_low_freq if hp.sex else hp.prop_decrease_high_freq

    if torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    # First pass: pretrained neural denoiser (facebook denoiser, master64 weights)
    model = master64().to(device)
    noisy = torch.from_numpy(np.array([wav])).to(device).float()
    with torch.no_grad():
        estimate = model(noisy)
    # Dry/wet mix: keep a fraction hp.dry of the raw signal to limit over-suppression
    estimate = estimate * (1 - hp.dry) + noisy * hp.dry
    estimate = estimate[0].cpu().detach().numpy()

    # Second pass: spectral-gating noise reduction
    return nr.reduce_noise(np.squeeze(estimate), hp.sample_rate, prop_decrease=prop_decrease)
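

# Standalone sketch (hypothetical file name): the same denoising chain can be
# applied to an existing recording, provided it is resampled to hp.sample_rate.
#
#   noisy_wav, _ = librosa.load("recording.wav", sr=hp.sample_rate)
#   clean_wav = waveform_denoising(noisy_wav)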


def get_dominant_freq(wav, name="fft"):
    # Dominant frequency of the magnitude spectrum, ignoring content below
    # 60 Hz (DC offset and low-frequency rumble). Note: the peak must be
    # picked on the magnitudes, since complex FFT bins are not orderable.
    N = len(wav)
    fft_mag = np.abs(rfft(wav))
    fft_freq = rfftfreq(N, 1 / hp.sample_rate)
    fft_least_index = np.where(fft_freq >= 60)[0][0]
    fft_max_index = fft_least_index + np.argmax(fft_mag[fft_least_index:])
    fft_max_freq = fft_freq[fft_max_index]
    # Debug plot of the spectrum (requires matplotlib):
    # import matplotlib.pyplot as plt
    # plt.clf()
    # plt.plot(fft_freq, fft_mag)
    # plt.savefig(f"{name}.png", dpi=300)
    return fft_max_freq
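

# Example (sketch): sanity-check the pitch of a generated waveform; assumes a
# loaded model and a `mel` array as in the infer_waveform sketch above.
#
#   wav = infer_waveform(mel)
#   print("dominant frequency: %.1f Hz" % get_dominant_freq(wav))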