|
import os |
|
import json |
|
|
|
import torch |
|
import numpy as np |
|
|
|
import hifigan |
|
from model import FastSpeech2, ScheduledOptim |
|
|
|
|
|
def get_model(args, configs, device, train=False):
    """Build a FastSpeech2 model, optionally restoring it from a checkpoint.

    Args:
        args: Object exposing ``restore_step``. When truthy, the checkpoint
            ``"<restore_step>.pth.tar"`` under ``train_config["path"]["ckpt_path"]``
            is loaded.
        configs: Triple ``(preprocess_config, model_config, train_config)``.
        device: Device the model is moved to and the checkpoint is mapped to.
        train: When true, also build a ``ScheduledOptim`` and return it.

    Returns:
        ``(model, scheduled_optim)`` with the model in training mode when
        ``train`` is true; otherwise just the model, in eval mode with
        gradients disabled on all of its parameters.
    """
    preprocess_config, model_config, train_config = configs

    model = FastSpeech2(preprocess_config, model_config).to(device)

    # Keep the loaded checkpoint around so the optimizer state can be restored
    # from the same file further down.
    ckpt = None
    if args.restore_step:
        ckpt_path = os.path.join(
            train_config["path"]["ckpt_path"],
            "{}.pth.tar".format(args.restore_step),
        )
        ckpt = torch.load(ckpt_path, map_location=device)
        model.load_state_dict(ckpt["model"])

    if train:
        scheduled_optim = ScheduledOptim(
            model, train_config, model_config, args.restore_step
        )
        if ckpt is not None:
            scheduled_optim.load_state_dict(ckpt["optimizer"])
        model.train()
        return model, scheduled_optim

    model.eval()
    # ``requires_grad_`` is a method: it must be called. Assigning ``False`` to
    # it only shadows the method and leaves every parameter trainable.
    model.requires_grad_(False)
    return model
|
|
|
|
|
def get_param_num(model):
    """Return the total number of elements over all parameters of ``model``."""
    total = 0
    for tensor in model.parameters():
        total += tensor.numel()
    return total
|
|
|
|
|
def get_vocoder(config, device):
    """Load the pretrained vocoder selected by ``config["vocoder"]``.

    Args:
        config: Mapping with ``config["vocoder"]["model"]`` (``"MelGAN"`` or
            ``"HiFi-GAN"``) and ``config["vocoder"]["speaker"]``
            (``"LJSpeech"`` or ``"universal"``).
        device: Device the vocoder is moved to (and, for HiFi-GAN, the device
            the checkpoint is mapped to while loading).

    Returns:
        The vocoder object, in eval mode and placed on ``device``.

    Raises:
        ValueError: If the model or speaker name is not one of the supported
            values listed above.
    """
    name = config["vocoder"]["model"]
    speaker = config["vocoder"]["speaker"]

    if name == "MelGAN":
        melgan_models = {
            "LJSpeech": "linda_johnson",
            "universal": "multi_speaker",
        }
        if speaker not in melgan_models:
            raise ValueError(
                "Unsupported MelGAN vocoder speaker: {!r}".format(speaker)
            )
        # torch.hub.load forwards extra keyword arguments to the hub entry
        # point; ``load_melgan`` is called with the model name only (as the
        # LJSpeech path always did), and the network is moved to ``device``
        # explicitly below.
        vocoder = torch.hub.load(
            "descriptinc/melgan-neurips", "load_melgan", melgan_models[speaker]
        )
        vocoder.mel2wav.eval()
        vocoder.mel2wav.to(device)
    elif name == "HiFi-GAN":
        hifigan_ckpts = {
            "LJSpeech": "hifigan/generator_LJSpeech.pth.tar",
            "universal": "hifigan/generator_universal.pth.tar",
        }
        if speaker not in hifigan_ckpts:
            raise ValueError(
                "Unsupported HiFi-GAN vocoder speaker: {!r}".format(speaker)
            )
        # NOTE: these paths are relative to the current working directory.
        with open("hifigan/config.json", "r") as f:
            hifigan_config = hifigan.AttrDict(json.load(f))
        vocoder = hifigan.Generator(hifigan_config)
        ckpt = torch.load(hifigan_ckpts[speaker], map_location=device)
        vocoder.load_state_dict(ckpt["generator"])
        vocoder.eval()
        vocoder.remove_weight_norm()
        vocoder.to(device)
    else:
        raise ValueError("Unsupported vocoder model: {!r}".format(name))

    return vocoder
|
|
|
|
|
def vocoder_infer(mels, vocoder, model_config, preprocess_config, lengths=None):
    """Run a vocoder on a batch of mel-spectrograms and return int16 waveforms.

    Args:
        mels: Batch of mel-spectrograms passed to the vocoder
            (presumably a tensor with the batch on the first axis -- confirm
            against the caller).
        vocoder: For ``"MelGAN"`` an object with an ``inverse`` method; for
            ``"HiFi-GAN"`` a callable whose output has a singleton axis 1.
        model_config: Mapping providing ``model_config["vocoder"]["model"]``.
        preprocess_config: Mapping providing
            ``preprocess_config["preprocessing"]["audio"]["max_wav_value"]``,
            the factor the vocoder output is scaled by.
        lengths: Optional per-item sample counts; when given, waveform ``i`` is
            truncated to ``lengths[i]`` samples.

    Returns:
        A list with one 1-D ``int16`` NumPy array per batch item.

    Raises:
        ValueError: If the configured vocoder model is not supported.
    """
    name = model_config["vocoder"]["model"]
    if name not in ("MelGAN", "HiFi-GAN"):
        raise ValueError("Unsupported vocoder model: {!r}".format(name))

    with torch.no_grad():
        if name == "MelGAN":
            # The mel input is rescaled by 1/ln(10) before inversion
            # (presumably natural-log -> log10 mels -- confirm).
            wavs = vocoder.inverse(mels / np.log(10))
        else:
            wavs = vocoder(mels).squeeze(1)

    max_wav_value = preprocess_config["preprocessing"]["audio"]["max_wav_value"]
    scaled = wavs.cpu().numpy() * max_wav_value
    # Saturate instead of letting out-of-range samples wrap around when cast
    # to int16 (e.g. a full-scale 1.0 * 32768 does not fit in int16).
    int16_info = np.iinfo(np.int16)
    wavs = list(np.clip(scaled, int16_info.min, int16_info.max).astype("int16"))

    if lengths is not None:
        wavs = [wav[: lengths[i]] for i, wav in enumerate(wavs)]

    return wavs
|
|