Spaces:
Sleeping
Sleeping
import os | |
import json | |
import torch | |
import numpy as np | |
import hifigan | |
from model import FastSpeech2, ScheduledOptim | |
def get_model(args, configs, device, train=False): | |
(preprocess_config, model_config, train_config) = configs | |
model = FastSpeech2(preprocess_config, model_config).to(device) | |
if args.restore_step: | |
ckpt_path = os.path.join( | |
train_config["path"]["ckpt_path"], | |
"{}.pth.tar".format(args.restore_step), | |
) | |
ckpt = torch.load(ckpt_path, map_location=device) | |
model.load_state_dict(ckpt["model"]) | |
if train: | |
scheduled_optim = ScheduledOptim( | |
model, train_config, model_config, args.restore_step | |
) | |
if args.restore_step: | |
scheduled_optim.load_state_dict(ckpt["optimizer"]) | |
model.train() | |
return model, scheduled_optim | |
model.eval() | |
model.requires_grad_ = False | |
return model | |
def get_param_num(model): | |
num_param = sum(param.numel() for param in model.parameters()) | |
return num_param | |
def get_vocoder(config, device): | |
name = config["vocoder"]["model"] | |
speaker = config["vocoder"]["speaker"] | |
if name == "MelGAN": | |
if speaker == "LJSpeech": | |
vocoder = torch.hub.load( | |
"descriptinc/melgan-neurips", "load_melgan", "linda_johnson" | |
) | |
elif speaker == "universal": | |
vocoder = torch.hub.load( | |
"descriptinc/melgan-neurips", "load_melgan", "multi_speaker",map_location=device | |
) | |
vocoder.mel2wav.eval() | |
vocoder.mel2wav.to(device) | |
elif name == "HiFi-GAN": | |
with open("hifigan/config.json", "r") as f: | |
config = json.load(f) | |
config = hifigan.AttrDict(config) | |
vocoder = hifigan.Generator(config) | |
if speaker == "LJSpeech": | |
ckpt = torch.load("hifigan/generator_LJSpeech.pth.tar",map_location=device) | |
elif speaker == "universal": | |
ckpt = torch.load("hifigan/generator_universal.pth.tar",map_location=device) | |
vocoder.load_state_dict(ckpt["generator"]) | |
vocoder.eval() | |
vocoder.remove_weight_norm() | |
vocoder.to(device) | |
return vocoder | |
def vocoder_infer(mels, vocoder, model_config, preprocess_config, lengths=None): | |
name = model_config["vocoder"]["model"] | |
with torch.no_grad(): | |
if name == "MelGAN": | |
wavs = vocoder.inverse(mels / np.log(10)) | |
elif name == "HiFi-GAN": | |
wavs = vocoder(mels).squeeze(1) | |
wavs = ( | |
wavs.cpu().numpy() | |
* preprocess_config["preprocessing"]["audio"]["max_wav_value"] | |
).astype("int16") | |
wavs = [wav for wav in wavs] | |
for i in range(len(mels)): | |
if lengths is not None: | |
wavs[i] = wavs[i][: lengths[i]] | |
return wavs | |