# CelebChat / rtvc / vocoder / inference.py
# Source: commit 6bc94ac by lhzstar ("initial commits"), 3.1 kB.
from vocoder.models.fatchord_version import WaveRNN
from vocoder import hparams as hp
from scipy.fft import rfft, rfftfreq
from scipy import signal
from denoiser.pretrained import master64
import librosa
import numpy as np
import torch
import torchaudio
import noisereduce as nr
# Module-level vocoder state, populated by load_model().
_model = None  # type: WaveRNN
# Device the vocoder was loaded onto. load_model() assigns it through `global`,
# so it is declared here to avoid a NameError if it is read before a load.
_device = None  # type: torch.device
def load_model(weights_fpath, verbose=True):
    """Build the WaveRNN vocoder, load trained weights, and install it as the module model.

    The network is built from the hyper-parameters in ``vocoder.hparams``. It is placed on
    CUDA when available (CPU otherwise), loaded from ``checkpoint['model_state']``, and
    switched to eval mode.

    The module-level ``_model`` / ``_device`` are only replaced once loading has fully
    succeeded. A failed load therefore never leaves a randomly initialised network behind
    (which would make ``is_loaded()`` report True), and it never discards a previously
    loaded model.

    :param weights_fpath: path to a checkpoint readable by ``torch.load`` containing a
        ``'model_state'`` entry.
    :param verbose: print progress messages when True.
    """
    global _model, _device
    if verbose:
        print("Building Wave-RNN")
    model = WaveRNN(
        rnn_dims=hp.voc_rnn_dims,
        fc_dims=hp.voc_fc_dims,
        bits=hp.bits,
        pad=hp.voc_pad,
        upsample_factors=hp.voc_upsample_factors,
        feat_dims=hp.num_mels,
        compute_dims=hp.voc_compute_dims,
        res_out_dims=hp.voc_res_out_dims,
        res_blocks=hp.voc_res_blocks,
        hop_length=hp.hop_length,
        sample_rate=hp.sample_rate,
        mode=hp.voc_mode,
    )
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    if verbose:
        print(f"Loading model weights at {weights_fpath}")
    # NOTE: torch.load unpickles the file -- only load checkpoints from trusted sources.
    checkpoint = torch.load(weights_fpath, map_location=device)
    model.load_state_dict(checkpoint["model_state"])
    model.eval()

    # Publish only after every step above succeeded.
    _model, _device = model, device
def is_loaded():
    """Tell whether load_model() has installed a vocoder in this module.

    :return: True once the module-level model has been set, False otherwise.
    """
    loaded = _model is not None
    return loaded
def infer_waveform(mel, normalize=True, batched=True, target=8000, overlap=800,
                   progress_callback=None, crossfade=True, denoise=True):
    """Generate a waveform from a synthesizer mel spectrogram using the loaded WaveRNN.

    The mel spectrogram must be in the same format the synthesizer outputs.

    :param mel: numpy mel spectrogram. A leading batch axis of size 1 is added here
        (assumed to be shaped (num_mels, frames) -- TODO confirm against the synthesizer).
    :param normalize: divide ``mel`` by ``hp.mel_max_abs_value`` before vocoding.
    :param batched: forwarded to ``WaveRNN.generate``.
    :param target: forwarded to ``WaveRNN.generate`` (presumably the per-fold length in
        samples used when ``batched`` -- confirm in the model).
    :param overlap: forwarded to ``WaveRNN.generate`` (presumably the overlap between
        folds in samples -- confirm in the model).
    :param progress_callback: forwarded to ``WaveRNN.generate``.
    :param crossfade: forwarded to ``WaveRNN.generate`` as ``crossfade``.
    :param denoise: when True (the default, matching the historical behaviour), pass the
        generated audio through ``waveform_denoising``. When False, return it unchanged.
    :return: the generated waveform (post-processed when ``denoise`` is True).
    :raises Exception: if ``load_model`` has not been called yet.
    """
    if _model is None:
        raise Exception("Please load Wave-RNN in memory before using it")

    if normalize:
        mel = mel / hp.mel_max_abs_value
    mel = torch.from_numpy(mel[None, ...])
    wav = _model.generate(mel, batched, target, overlap, hp.mu_law, progress_callback,
                          crossfade=crossfade)
    if denoise:
        wav = waveform_denoising(wav)
    return wav
# Cache of pretrained master64 denoisers keyed by device string, so the network is built
# (and its weights loaded) once per device instead of on every call. Entries live for the
# lifetime of the process.
_denoiser_models = {}


def _get_denoiser(device):
    """Return a cached, eval-mode ``master64`` denoiser on *device*, building it on first use."""
    key = str(device)
    model = _denoiser_models.get(key)
    if model is None:
        model = master64().to(device)
        model.eval()
        _denoiser_models[key] = model
    return model


def waveform_denoising(wav):
    """Clean up a generated waveform in two stages.

    1. Run the pretrained ``master64`` denoiser and blend its output with the input as
       ``estimate * (1 - hp.dry) + noisy * hp.dry``.
    2. Apply ``noisereduce.reduce_noise`` at ``hp.sample_rate``. Its ``prop_decrease`` is
       ``hp.prop_decrease_low_freq`` when ``hp.sex`` is truthy, otherwise
       ``hp.prop_decrease_high_freq``.

    :param wav: 1-D numpy waveform (assumed float audio at ``hp.sample_rate`` -- TODO confirm).
    :return: the denoised waveform as returned by ``noisereduce.reduce_noise``.
    """
    prop_decrease = hp.prop_decrease_low_freq if hp.sex else hp.prop_decrease_high_freq
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = _get_denoiser(device)

    # Add a leading batch axis as the denoiser input.
    noisy = torch.from_numpy(np.array([wav])).to(device).float()
    # Inference only: skip building an autograd graph.
    with torch.no_grad():
        estimate = model(noisy)
        estimate = estimate * (1 - hp.dry) + noisy * hp.dry
    estimate = estimate[0].cpu().numpy()
    return nr.reduce_noise(np.squeeze(estimate), hp.sample_rate, prop_decrease=prop_decrease)
def get_dominant_freq(wav, name="fft", sample_rate=None, min_freq=60):
    """Return the frequency (Hz) of the strongest spectral component at or above ``min_freq``.

    The spectrum is the real FFT of ``wav``. "Strongest" means largest magnitude
    (``abs`` of the complex bin). Comparing the raw complex values would rank bins by
    their real part and miss energy that sits in the imaginary part, e.g. a pure sine.

    :param wav: 1-D sequence of audio samples.
    :param name: unused; kept for backward compatibility (it used to label a debug plot).
    :param sample_rate: sampling rate of ``wav`` in Hz; defaults to ``hp.sample_rate``.
    :param min_freq: lowest frequency (Hz) considered; bins below it are ignored.
    :return: the frequency of the peak bin, as a numpy float.
    :raises ValueError: if ``wav`` is empty or no FFT bin reaches ``min_freq``.
    """
    if sample_rate is None:
        sample_rate = hp.sample_rate
    samples = np.asarray(wav)
    n_samples = len(samples)
    if n_samples == 0:
        raise ValueError("cannot compute the dominant frequency of an empty waveform")

    magnitudes = np.abs(rfft(samples))
    freqs = rfftfreq(n_samples, 1 / sample_rate)

    in_band = np.flatnonzero(freqs >= min_freq)
    if in_band.size == 0:
        raise ValueError(
            f"no FFT bin at or above {min_freq} Hz for {n_samples} samples at {sample_rate} Hz"
        )
    start = in_band[0]
    # argmax on the slice, shifted back to an absolute bin index.
    peak = start + int(np.argmax(magnitudes[start:]))
    return freqs[peak]