import os
import json

import torch
import numpy as np

import hifigan
from model import FastSpeech2, ScheduledOptim


def get_model(args, configs, device, train=False):
    """Build a FastSpeech2 model, optionally restoring it from a checkpoint.

    Args:
        args: Object with a ``restore_step`` attribute. When truthy, the
            checkpoint ``<train_config["path"]["ckpt_path"]>/<restore_step>.pth.tar``
            is loaded into the model (and into the optimizer when ``train``).
        configs: Tuple ``(preprocess_config, model_config, train_config)``.
        device: Device the model is moved to and the checkpoint is mapped to.
        train: When True, also build a ``ScheduledOptim`` and put the model in
            training mode.

    Returns:
        ``(model, scheduled_optim)`` when ``train`` is True; otherwise the
        model alone, in eval mode with gradients disabled on its parameters.
    """
    preprocess_config, model_config, train_config = configs

    model = FastSpeech2(preprocess_config, model_config).to(device)

    ckpt = None
    if args.restore_step:
        ckpt_path = os.path.join(
            train_config["path"]["ckpt_path"],
            "{}.pth.tar".format(args.restore_step),
        )
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        ckpt = torch.load(ckpt_path, map_location=device)
        model.load_state_dict(ckpt["model"])

    if train:
        scheduled_optim = ScheduledOptim(
            model, train_config, model_config, args.restore_step
        )
        if ckpt is not None:
            scheduled_optim.load_state_dict(ckpt["optimizer"])
        model.train()
        return model, scheduled_optim

    model.eval()
    # requires_grad_ is a method and must be called; assigning to it would
    # only shadow the method and leave every parameter trainable.
    model.requires_grad_(False)
    return model


def get_param_num(model):
    """Return the total number of elements across all parameters of ``model``."""
    return sum(param.numel() for param in model.parameters())


def get_vocoder(config, device):
    """Load the pretrained vocoder selected by ``config["vocoder"]``.

    Args:
        config: Config dict; ``config["vocoder"]["model"]`` must be
            ``"MelGAN"`` or ``"HiFi-GAN"`` and ``config["vocoder"]["speaker"]``
            must be ``"LJSpeech"`` or ``"universal"``.
        device: Device the vocoder network is moved to.

    Returns:
        For MelGAN, the object returned by ``torch.hub.load`` (its ``mel2wav``
        submodule is put in eval mode and moved to ``device``). For HiFi-GAN,
        a ``hifigan.Generator`` with its weights loaded, weight norm removed,
        in eval mode on ``device``.

    Raises:
        ValueError: If the vocoder model or speaker is not supported.
    """
    name = config["vocoder"]["model"]
    speaker = config["vocoder"]["speaker"]

    if name == "MelGAN":
        if speaker == "LJSpeech":
            vocoder = torch.hub.load(
                "descriptinc/melgan-neurips", "load_melgan", "linda_johnson"
            )
        elif speaker == "universal":
            vocoder = torch.hub.load(
                "descriptinc/melgan-neurips",
                "load_melgan",
                "multi_speaker",
                map_location=device,
            )
        else:
            raise ValueError(f"Unsupported MelGAN speaker: {speaker!r}")
        vocoder.mel2wav.eval()
        vocoder.mel2wav.to(device)

    elif name == "HiFi-GAN":
        ckpt_paths = {
            "LJSpeech": "hifigan/generator_LJSpeech.pth.tar",
            "universal": "hifigan/generator_universal.pth.tar",
        }
        if speaker not in ckpt_paths:
            raise ValueError(f"Unsupported HiFi-GAN speaker: {speaker!r}")

        # Use a separate name so the caller's ``config`` is not shadowed.
        with open("hifigan/config.json", "r", encoding="utf-8") as f:
            hifigan_config = hifigan.AttrDict(json.load(f))
        vocoder = hifigan.Generator(hifigan_config)

        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        ckpt = torch.load(ckpt_paths[speaker], map_location=device)
        vocoder.load_state_dict(ckpt["generator"])
        vocoder.eval()
        vocoder.remove_weight_norm()
        vocoder.to(device)

    else:
        raise ValueError(f"Unsupported vocoder: {name!r}")

    return vocoder


def vocoder_infer(mels, vocoder, model_config, preprocess_config, lengths=None):
    """Run the vocoder on a batch of mel-spectrograms and return int16 audio.

    Args:
        mels: Batch of mel-spectrogram tensors accepted by ``vocoder``
            (presumably (batch, n_mels, frames) — TODO confirm with callers).
        vocoder: Vocoder object as returned by ``get_vocoder``.
        model_config: Config whose ``["vocoder"]["model"]`` selects the
            inference path (``"MelGAN"`` or ``"HiFi-GAN"``).
        preprocess_config: Supplies
            ``["preprocessing"]["audio"]["max_wav_value"]``, the factor used
            to scale vocoder output to integer samples.
        lengths: Optional per-item sample counts; item ``i`` is truncated to
            ``lengths[i]`` samples.

    Returns:
        A list of int16 numpy arrays, one per batch item.

    Raises:
        ValueError: If the vocoder model name is not supported.
    """
    name = model_config["vocoder"]["model"]
    with torch.no_grad():
        if name == "MelGAN":
            # Dividing by ln(10) rescales natural-log values to log10;
            # presumably MelGAN expects log10 mels — confirm.
            wavs = vocoder.inverse(mels / np.log(10))
        elif name == "HiFi-GAN":
            wavs = vocoder(mels).squeeze(1)
        else:
            raise ValueError(f"Unsupported vocoder: {name!r}")

    max_wav_value = preprocess_config["preprocessing"]["audio"]["max_wav_value"]
    scaled = wavs.cpu().numpy() * max_wav_value
    # Clip before casting: out-of-range floats (e.g. a full-scale sample
    # times max_wav_value) would otherwise overflow int16 and wrap around.
    int16_info = np.iinfo(np.int16)
    wavs = list(np.clip(scaled, int16_info.min, int16_info.max).astype("int16"))

    if lengths is not None:
        for i, wav in enumerate(wavs):
            wavs[i] = wav[: lengths[i]]

    return wavs