import datetime
import hashlib
import os
from collections import OrderedDict

import torch


def replace_keys_in_dict(d, old_key_part, new_key_part):
    """Return a copy of *d* with a substring replaced in every string key.

    Nested ``dict`` values are processed recursively. The result is an
    ``OrderedDict`` when *d* is one and a plain ``dict`` otherwise; insertion
    order is preserved. Values are not copied.

    Args:
        d: Mapping to rewrite; it is not modified.
        old_key_part: Substring to look for in each key.
        new_key_part: Replacement for every occurrence of *old_key_part*.

    Returns:
        A new mapping with the renamed keys.
    """
    updated_dict = OrderedDict() if isinstance(d, OrderedDict) else {}
    for key, value in d.items():
        # Non-string keys have no ``replace`` method; keep them untouched
        # instead of failing on them.
        if isinstance(key, str):
            new_key = key.replace(old_key_part, new_key_part)
        else:
            new_key = key
        if isinstance(value, dict):
            value = replace_keys_in_dict(value, old_key_part, new_key_part)
        updated_dict[new_key] = value
    return updated_dict


def extract_model(ckpt, sr, if_f0, name, model_dir, epoch, step, version, hps):
    """Write a half-precision, metadata-annotated copy of *ckpt* to *model_dir*.

    Entries whose key contains ``"enc_q"`` are dropped, the remaining values
    are converted with ``.half()``, and keys containing
    ``.parametrizations.weight.original0`` / ``original1`` are renamed to
    ``.weight_g`` / ``.weight_v`` before saving.

    The file is first written next to the destination and then moved into
    place, so a failed save does not leave a partial file at *model_dir*.

    Any exception is caught and printed; nothing is raised to the caller.

    Args:
        ckpt: Mapping of parameter names to objects exposing ``.half()``
            (presumably the tensors of a model state dict - confirm with
            the caller).
        sr: Stored verbatim under the ``"sr"`` key.
        if_f0: Stored verbatim under the ``"f0"`` key.
        name: Used only to build the temporary file name.
        model_dir: Destination *file* path (despite the name, it is the
            target of ``torch.save``). Its parent directory is created if
            missing.
        epoch: Stored under ``"epoch"``; also part of the hash input.
        step: Stored under ``"step"``; also part of the hash input.
        version: Stored verbatim under the ``"version"`` key.
        hps: Object exposing the ``data.*`` and ``model.*`` attributes read
            below to build the ``"config"`` list.
    """
    pth_file = f"{name}_{epoch}e_{step}s.pth"
    tmp_path = None
    try:
        # The temporary file must live beside the destination, not "inside"
        # it: model_dir is a file path, so joining onto it can never work.
        out_dir = os.path.dirname(model_dir)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        tmp_path = os.path.join(out_dir, f"{pth_file}_old_version.pth")

        opt = OrderedDict(
            weight={
                key: value.half() for key, value in ckpt.items() if "enc_q" not in key
            }
        )
        opt["config"] = [
            hps.data.filter_length // 2 + 1,
            32,  # NOTE(review): fixed constant; meaning not visible here
            hps.model.inter_channels,
            hps.model.hidden_channels,
            hps.model.filter_channels,
            hps.model.n_heads,
            hps.model.n_layers,
            hps.model.kernel_size,
            hps.model.p_dropout,
            hps.model.resblock,
            hps.model.resblock_kernel_sizes,
            hps.model.resblock_dilation_sizes,
            hps.model.upsample_rates,
            hps.model.upsample_initial_channel,
            hps.model.upsample_kernel_sizes,
            hps.model.spk_embed_dim,
            hps.model.gin_channels,
            hps.data.sampling_rate,
        ]
        opt["epoch"] = epoch
        opt["step"] = step
        opt["sr"] = sr
        opt["f0"] = if_f0
        opt["version"] = version

        # One timestamp for both the metadata field and the hash input.
        created = datetime.datetime.now().isoformat()
        opt["creation_date"] = created
        hash_input = f"{str(ckpt)} {epoch} {step} {created}"
        opt["model_hash"] = hashlib.sha256(hash_input.encode()).hexdigest()

        # Rename the parametrization keys in memory; no need to save the
        # file, load it back and save it again.
        converted = replace_keys_in_dict(
            replace_keys_in_dict(
                opt, ".parametrizations.weight.original1", ".weight_v"
            ),
            ".parametrizations.weight.original0",
            ".weight_g",
        )
        torch.save(converted, tmp_path)
        os.replace(tmp_path, model_dir)

        # Report success only once the file is really in place.
        print(f"Saved model '{model_dir}' (epoch {epoch} and step {step})")
    except Exception as error:
        # Best-effort contract: report the failure, never raise.
        print(error)
        if tmp_path is not None and os.path.exists(tmp_path):
            try:
                os.remove(tmp_path)
            except OSError:
                pass