import os
import json

import torch
import numpy as np

import audioldm.hifigan as hifigan

# HiFi-GAN generator hyperparameters for 16 kHz audio with 64 mel bins.
# The product of "upsample_rates" (5 * 4 * 2 * 2 * 2 = 160) equals "hop_size",
# so each mel frame is expanded to 160 waveform samples.
# NOTE(review): the training-only entries (num_gpus, batch_size, learning_rate,
# adam_*, lr_decay, seed, dist_config, ...) are presumably not read by the
# Generator at inference time; kept to mirror the original HiFi-GAN config.
HIFIGAN_16K_64 = {
    "resblock": "1",
    "num_gpus": 6,
    "batch_size": 16,
    "learning_rate": 0.0002,
    "adam_b1": 0.8,
    "adam_b2": 0.99,
    "lr_decay": 0.999,
    "seed": 1234,
    "upsample_rates": [5, 4, 2, 2, 2],
    "upsample_kernel_sizes": [16, 16, 8, 4, 4],
    "upsample_initial_channel": 1024,
    "resblock_kernel_sizes": [3, 7, 11],
    "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
    "segment_size": 8192,
    "num_mels": 64,
    "num_freq": 1025,
    "n_fft": 1024,
    "hop_size": 160,
    "win_size": 1024,
    "sampling_rate": 16000,
    "fmin": 0,
    "fmax": 8000,
    "fmax_for_loss": None,
    "num_workers": 4,
    "dist_config": {
        "dist_backend": "nccl",
        "dist_url": "tcp://localhost:54321",
        "world_size": 1,
    },
}


def get_available_checkpoint_keys(model, ckpt):
    """Return the subset of a checkpoint's weights that fit ``model``.

    Loads ``ckpt`` and keeps only the entries of its ``"state_dict"`` whose
    key also exists in ``model.state_dict()`` with an identical tensor size,
    so the result can be passed to ``model.load_state_dict`` without shape
    mismatches. Each skipped key is printed as a warning, followed by a
    summary count.

    Args:
        model: Module whose ``state_dict()`` defines the accepted keys/sizes.
        ckpt: Path to a checkpoint file with a top-level ``"state_dict"``.

    Returns:
        dict mapping parameter names to tensors, loaded onto the CPU.

    Raises:
        KeyError: if the loaded checkpoint has no ``"state_dict"`` entry.
    """
    print(f"==> Attemp to reload from {ckpt}")
    # map_location="cpu" lets a checkpoint saved on a GPU be loaded on any
    # machine; model.load_state_dict later copies values onto the model's
    # own device, so the caller does not depend on where they were saved.
    state_dict = torch.load(ckpt, map_location="cpu")["state_dict"]
    current_state_dict = model.state_dict()

    new_state_dict = {}
    for key, tensor in state_dict.items():
        current = current_state_dict.get(key)
        if current is not None and current.size() == tensor.size():
            new_state_dict[key] = tensor
        else:
            print(f"==> WARNING: Skipping {key}")

    print(f"{len(new_state_dict)} out of {len(state_dict)} keys are matched")
    return new_state_dict


def get_param_num(model) -> int:
    """Return the total number of scalar elements across ``model.parameters()``."""
    return sum(param.numel() for param in model.parameters())


def get_vocoder(config, device):
    """Build the 16 kHz / 64-mel HiFi-GAN generator ready for inference.

    Args:
        config: Ignored. The generator is always built from
            ``HIFIGAN_16K_64``; the parameter is kept so existing call
            sites keep working.
        device: Target device passed to ``Module.to``.

    Returns:
        A ``hifigan.Generator`` switched to eval mode, with
        ``remove_weight_norm()`` applied, moved to ``device``.
        NOTE(review): no pretrained weights are loaded here — presumably the
        caller restores them afterwards; confirm.
    """
    # Separate local name so the (unused) ``config`` argument is not shadowed.
    hifigan_config = hifigan.AttrDict(HIFIGAN_16K_64)
    vocoder = hifigan.Generator(hifigan_config)
    vocoder.eval()
    vocoder.remove_weight_norm()
    vocoder.to(device)
    return vocoder


def vocoder_infer(mels, vocoder, lengths=None):
    """Synthesize 16-bit PCM waveforms from a batch of mel spectrograms.

    Args:
        mels: Input fed directly to ``vocoder``; presumably
            (batch, num_mels, frames) — TODO confirm against callers.
        vocoder: Generator module (e.g. the one returned by ``get_vocoder``).
            It is put into eval mode and run without gradient tracking.
        lengths: Optional int. If given, every waveform is truncated to its
            first ``lengths`` samples (a single cut-off for the whole batch).

    Returns:
        numpy ``int16`` array. Assumes the vocoder outputs
        (batch, 1, samples); ``squeeze(1)`` drops that channel axis, giving
        (batch, samples).
    """
    vocoder.eval()
    with torch.no_grad():
        wavs = vocoder(mels).squeeze(1)

    # Scale float audio (nominally in [-1, 1]) to the int16 range. Clip before
    # casting: a sample of exactly 1.0 scales to 32768, which does not fit in
    # int16, and numpy's float->int cast of out-of-range values is undefined
    # (typically wrapping to -32768, an audible click). In-range values are
    # unaffected by the clip.
    wavs = np.clip(wavs.cpu().numpy() * 32768, -32768, 32767).astype("int16")

    if lengths is not None:
        wavs = wavs[:, :lengths]

    return wavs