Serhiy Stetskovych
Initial commit
d48b9e1
raw
history blame
5.16 kB
import torch
import audiosr.hifigan as hifigan
def get_vocoder_config():
    """Return HiFi-GAN hyper-parameters for 16 kHz audio with 64 mel bins.

    The product of ``upsample_rates`` (5 * 4 * 2 * 2 * 2 = 160) equals
    ``hop_size``. A brand-new dict (with new nested lists/dicts) is built on
    every call, so callers may mutate the result freely.
    """
    dist_config = dict(
        dist_backend="nccl",
        dist_url="tcp://localhost:54321",
        world_size=1,
    )
    return dict(
        resblock="1",
        num_gpus=6,
        batch_size=16,
        learning_rate=0.0002,
        adam_b1=0.8,
        adam_b2=0.99,
        lr_decay=0.999,
        seed=1234,
        upsample_rates=[5, 4, 2, 2, 2],
        upsample_kernel_sizes=[16, 16, 8, 4, 4],
        upsample_initial_channel=1024,
        resblock_kernel_sizes=[3, 7, 11],
        # One dilation pattern per entry of resblock_kernel_sizes.
        resblock_dilation_sizes=[[1, 3, 5] for _ in range(3)],
        segment_size=8192,
        num_mels=64,
        # NOTE(review): 1025 == 2048 // 2 + 1, but n_fft=1024 yields 513
        # frequency bins; confirm whether this key is read anywhere.
        num_freq=1025,
        n_fft=1024,
        hop_size=160,
        win_size=1024,
        sampling_rate=16000,
        fmin=0,
        fmax=8000,
        fmax_for_loss=None,
        num_workers=4,
        dist_config=dist_config,
    )
def get_vocoder_config_48k():
    """Return HiFi-GAN hyper-parameters for 48 kHz audio with 256 mel bins.

    The product of ``upsample_rates`` (6 * 5 * 4 * 2 * 2 = 480) equals
    ``hop_size``. A brand-new dict (with new nested lists/dicts) is built on
    every call, so callers may mutate the result freely.
    """
    dist_config = dict(
        dist_backend="nccl",
        dist_url="tcp://localhost:18273",
        world_size=1,
    )
    return dict(
        resblock="1",
        num_gpus=8,
        batch_size=128,
        learning_rate=0.0001,
        adam_b1=0.8,
        adam_b2=0.99,
        lr_decay=0.999,
        seed=1234,
        upsample_rates=[6, 5, 4, 2, 2],
        upsample_kernel_sizes=[12, 10, 8, 4, 4],
        upsample_initial_channel=1536,
        resblock_kernel_sizes=[3, 7, 11, 15],
        # One dilation pattern per entry of resblock_kernel_sizes.
        resblock_dilation_sizes=[[1, 3, 5] for _ in range(4)],
        segment_size=15360,
        num_mels=256,
        n_fft=2048,
        hop_size=480,
        win_size=2048,
        sampling_rate=48000,
        fmin=20,
        fmax=24000,
        fmax_for_loss=None,
        num_workers=8,
        dist_config=dist_config,
    )
def get_available_checkpoint_keys(model, ckpt, map_location="cpu"):
    """Return the subset of a checkpoint's ``state_dict`` that fits ``model``.

    A checkpoint entry is kept only when ``model.state_dict()`` contains the
    same key with a tensor of the same size. Every other entry is reported
    with a printed warning and skipped, and a summary line is printed at the
    end.

    Args:
        model: Module whose ``state_dict()`` defines the accepted keys/sizes.
        ckpt: Anything ``torch.load`` accepts (e.g. a file path). The loaded
            object must be a mapping with a ``"state_dict"`` entry.
        map_location: Forwarded to ``torch.load``. Defaults to ``"cpu"`` so a
            checkpoint saved from GPU can be read on a CPU-only machine; pass
            ``None`` to keep tensors on the device they were saved from.

    Returns:
        dict mapping each matching key to its checkpoint tensor.
    """
    # torch.load unpickles the file: only use it on trusted checkpoints.
    state_dict = torch.load(ckpt, map_location=map_location)["state_dict"]
    current_state_dict = model.state_dict()
    new_state_dict = {}
    for key, value in state_dict.items():
        current = current_state_dict.get(key)
        if current is not None and current.size() == value.size():
            new_state_dict[key] = value
        else:
            print("==> WARNING: Skipping %s" % key)
    print(
        "%s out of %s keys are matched"
        % (len(new_state_dict), len(state_dict))
    )
    return new_state_dict
def get_param_num(model):
    """Return the total number of scalar elements across ``model.parameters()``.

    Frozen parameters (``requires_grad=False``) are counted as well.
    """
    total = 0
    for tensor in model.parameters():
        total += tensor.numel()
    return total
def torch_version_orig_mod_remove(state_dict):
    """Strip the ``_orig_mod.`` substring from every generator key.

    ``_orig_mod.`` is the prefix that modules wrapped by ``torch.compile``
    add to their state-dict keys. Only the ``"generator"`` entry of
    ``state_dict`` is carried over; all other top-level entries are dropped.
    Keys without the substring are copied unchanged.

    Returns:
        ``{"generator": {cleaned_key: value, ...}}`` as a new dict.
    """
    generator = state_dict["generator"]
    # str.replace is a no-op when the substring is absent, so every key can
    # go through it unconditionally.
    cleaned = {
        key.replace("_orig_mod.", ""): generator[key] for key in generator.keys()
    }
    return {"generator": cleaned}
def get_vocoder(config, device, mel_bins):
    """Build a HiFi-GAN generator (``hifigan.Generator_old``) for inference.

    The architecture is selected from ``mel_bins``:
    64 -> ``get_vocoder_config()`` (16 kHz, hop 160);
    any other value -> ``get_vocoder_config_48k()`` (48 kHz, hop 480).

    No checkpoint weights are loaded by this function.

    Args:
        config: Ignored. The vocoder configuration is always derived from
            ``mel_bins``; the parameter is kept so existing callers still work.
        device: Target passed to ``Module.to``.
        mel_bins: Number of mel bins of the spectrograms to be vocoded.

    Returns:
        The generator, switched to eval mode, with weight norm removed and
        moved to ``device``.
    """
    # The former MelGAN branch was unreachable (the model name was hard-coded
    # to "HiFi-GAN"), and the two HiFi-GAN branches differed only in config.
    if mel_bins == 64:
        vocoder_config = get_vocoder_config()
    else:
        vocoder_config = get_vocoder_config_48k()
    vocoder = hifigan.Generator_old(hifigan.AttrDict(vocoder_config))
    vocoder.eval()
    vocoder.remove_weight_norm()
    vocoder.to(device)
    return vocoder
def vocoder_infer(mels, vocoder, lengths=None):
    """Run ``vocoder`` on ``mels`` and return 16-bit PCM samples.

    Args:
        mels: Input passed unchanged to ``vocoder``.
        vocoder: Callable returning a tensor whose dim 1 is squeezed away
            (presumably shaped ``(batch, 1, samples)`` — confirm with caller).
        lengths: Optional int. When given, every waveform is truncated to its
            first ``lengths`` samples.

    Returns:
        numpy ``int16`` array: the vocoder output scaled by 32768 and
        saturated to the int16 range.
    """
    with torch.no_grad():
        wavs = vocoder(mels).squeeze(1)
    # Scaling by 32768 maps 1.0 (and any overshoot) to >= 32768, which does
    # not fit in int16; clip first so such samples saturate instead of
    # wrapping around to large negative values on the cast.
    pcm = (wavs.cpu().numpy() * 32768).clip(-32768, 32767).astype("int16")
    if lengths is not None:
        pcm = pcm[:, :lengths]
    return pcm