# Spaces: Running
import json
import os

import numpy as np
import torch

import audioldm.hifigan as hifigan
# Hyperparameters for the 16 kHz / 64-mel HiFi-GAN generator.
# In this file the dict is only consumed by get_vocoder(), which wraps it in
# hifigan.AttrDict and hands it to hifigan.Generator.
HIFIGAN_16K_64 = {
    # --- run / optimizer settings -------------------------------------
    # NOTE(review): presumably only relevant to training; confirm before
    # relying on them at inference time.
    "resblock": "1",
    "num_gpus": 6,
    "batch_size": 16,
    "learning_rate": 0.0002,
    "adam_b1": 0.8,
    "adam_b2": 0.99,
    "lr_decay": 0.999,
    "seed": 1234,
    # --- generator architecture ---------------------------------------
    # The product of upsample_rates (5 * 4 * 2 * 2 * 2 = 160) equals
    # hop_size below.
    "upsample_rates": [5, 4, 2, 2, 2],
    "upsample_kernel_sizes": [16, 16, 8, 4, 4],
    "upsample_initial_channel": 1024,
    "resblock_kernel_sizes": [3, 7, 11],
    "resblock_dilation_sizes": [
        [1, 3, 5],
        [1, 3, 5],
        [1, 3, 5],
    ],
    # --- audio / spectrogram settings ---------------------------------
    "segment_size": 8192,
    "num_mels": 64,
    "num_freq": 1025,
    "n_fft": 1024,
    "hop_size": 160,
    "win_size": 1024,
    "sampling_rate": 16000,
    "fmin": 0,
    "fmax": 8000,  # half of sampling_rate
    "fmax_for_loss": None,
    # --- data loading / distributed settings --------------------------
    "num_workers": 4,
    "dist_config": {
        "dist_backend": "nccl",
        "dist_url": "tcp://localhost:54321",
        "world_size": 1,
    },
}
def get_available_checkpoint_keys(model, ckpt):
    """Return the checkpoint entries that can be loaded into ``model``.

    The checkpoint at ``ckpt`` is expected to contain a ``"state_dict"``
    mapping. An entry is kept only when its key exists in
    ``model.state_dict()`` and the tensor sizes match; every other entry is
    reported on stdout and skipped.

    Args:
        model: Object exposing ``state_dict()`` (e.g. a ``torch.nn.Module``).
        ckpt: Path (or file-like object) accepted by ``torch.load``.

    Returns:
        dict: The subset of the checkpoint's ``"state_dict"`` whose keys and
        tensor sizes match the model. Tensors are mapped to the CPU.

    Raises:
        KeyError: If the loaded checkpoint has no ``"state_dict"`` entry.
    """
    print("==> Attemp to reload from %s" % ckpt)
    # SECURITY NOTE: torch.load unpickles the file; only load checkpoints
    # that come from a trusted source.
    # map_location="cpu" lets a checkpoint saved from a GPU be read on a
    # CPU-only host; load_state_dict() copies values onto the model's device.
    state_dict = torch.load(ckpt, map_location="cpu")["state_dict"]
    current_state_dict = model.state_dict()

    new_state_dict = {}
    for k, tensor in state_dict.items():
        if (
            k in current_state_dict
            and current_state_dict[k].size() == tensor.size()
        ):
            new_state_dict[k] = tensor
        else:
            print("==> WARNING: Skipping %s" % k)

    print(
        "%s out of %s keys are matched"
        % (len(new_state_dict), len(state_dict))
    )
    return new_state_dict
def get_param_num(model):
    """Return the total element count over everything in ``model.parameters()``."""
    total = 0
    for tensor in model.parameters():
        total += tensor.numel()
    return total
def get_vocoder(config, device):
    """Build a HiFi-GAN generator from ``HIFIGAN_16K_64`` and move it to ``device``.

    Args:
        config: Accepted for interface compatibility but not used; the
            hyperparameters always come from ``HIFIGAN_16K_64``.
        device: Target passed to the generator's ``.to()``.

    Returns:
        The ``hifigan.Generator`` instance, after ``eval()`` and
        ``remove_weight_norm()`` have been called on it. No checkpoint is
        loaded by this function.
    """
    hparams = hifigan.AttrDict(HIFIGAN_16K_64)
    generator = hifigan.Generator(hparams)

    generator.eval()
    generator.remove_weight_norm()
    generator.to(device)

    return generator
def vocoder_infer(mels, vocoder, lengths=None):
    """Run ``vocoder`` on ``mels`` and return the result as int16 samples.

    Args:
        mels: Input passed straight to ``vocoder``.
        vocoder: Callable returning a tensor whose dimension 1 is squeezed
            away (assumes output shaped like (batch, 1, samples) — TODO
            confirm against the generator).
        lengths: Optional upper bound applied to the last axis as
            ``wavs[:, :lengths]``; ``None`` keeps every sample.

    Returns:
        numpy.ndarray: int16 array of the scaled waveform(s).
    """
    with torch.no_grad():
        wavs = vocoder(mels).squeeze(1)
        scaled = wavs.cpu().numpy() * 32768
        # A sample at (or above) +1.0 scales to 32768, which does not fit in
        # int16 and would wrap to a large negative value. Saturate instead.
        wavs = np.clip(scaled, -32768, 32767).astype("int16")

        if lengths is not None:
            wavs = wavs[:, :lengths]

    return wavs