|
import os |
|
import torch |
|
import hashlib |
|
import datetime |
|
from collections import OrderedDict |
|
|
|
|
|
def replace_keys_in_dict(d, old_key_part, new_key_part):
    """Return a copy of *d* with *old_key_part* replaced by *new_key_part* in each key.

    Values that are dicts (including OrderedDicts) are rewritten the same way,
    recursively; every other value is carried over unchanged (same object).
    The returned mapping is an ``OrderedDict`` when *d* is one, otherwise a
    plain ``dict``. If two keys collapse to the same new key, the later value
    wins while the first key's position is kept (normal dict assignment).

    Args:
        d: Mapping whose keys support ``str.replace``.
        old_key_part: Substring to look for in every key.
        new_key_part: Replacement substring.

    Returns:
        A new mapping; *d* itself is not modified.
    """
    factory = OrderedDict if isinstance(d, OrderedDict) else dict

    def _convert(item):
        # Only descend into nested mappings; leave leaves (tensors, lists, ...) alone.
        if isinstance(item, dict):
            return replace_keys_in_dict(item, old_key_part, new_key_part)
        return item

    # The key is rewritten before its value is converted, exactly as before.
    return factory(
        (key.replace(old_key_part, new_key_part), _convert(item))
        for key, item in d.items()
    )
|
|
|
|
|
def extract_model(ckpt, sr, if_f0, name, model_dir, epoch, step, version, hps):
    """Write an inference checkpoint built from training weights to *model_dir*.

    The saved object is an ``OrderedDict`` with the keys ``weight``,
    ``config``, ``epoch``, ``step``, ``sr``, ``f0``, ``version``,
    ``creation_date`` and ``model_hash``. Weight names produced by
    ``torch.nn.utils.parametrizations.weight_norm``
    (``.parametrizations.weight.original0`` / ``.original1``) are renamed to
    the legacy ``.weight_g`` / ``.weight_v`` spelling before saving.

    This is a best-effort save: any exception is reported with ``print`` and
    not re-raised, and the function returns ``None`` in all cases.

    Args:
        ckpt: Mapping of parameter name -> tensor. Entries whose name contains
            ``"enc_q"`` are dropped; the rest are cast to half precision and
            moved to the CPU.
        sr: Stored under ``"sr"``.
        if_f0: Stored under ``"f0"``.
        name: Used only to name the temporary file written next to the target.
        model_dir: Destination *file* path of the checkpoint (despite the
            name, it is the file that gets written, not a directory).
        epoch: Stored under ``"epoch"``; also part of the temp name and hash.
        step: Stored under ``"step"``; also part of the temp name and hash.
        version: Stored under ``"version"``.
        hps: Hyper-parameter object; ``hps.data.filter_length``,
            ``hps.data.sampling_rate`` and several ``hps.model.*`` attributes
            are read to build ``"config"``.
    """
    tmp_path = None
    try:
        pth_file = f"{name}_{epoch}e_{step}s.pth"
        # model_dir is a file path (it is the final os.replace target), so the
        # temporary file must live beside it. Joining onto model_dir itself
        # would point *inside* a file and make the save always fail.
        target_dir = os.path.dirname(model_dir)
        if target_dir:
            os.makedirs(target_dir, exist_ok=True)
        tmp_path = os.path.join(target_dir, f"{pth_file}_old_version.pth")

        opt = OrderedDict(
            weight={
                # .cpu() keeps the on-disk tensors device-independent, which
                # the old reload with map_location="cpu" used to achieve.
                key: value.half().cpu()
                for key, value in ckpt.items()
                if "enc_q" not in key
            }
        )

        opt["config"] = [
            hps.data.filter_length // 2 + 1,  # presumably spectrogram bins (n_fft // 2 + 1)
            32,  # hard-coded; presumably the inference segment size -- confirm
            hps.model.inter_channels,
            hps.model.hidden_channels,
            hps.model.filter_channels,
            hps.model.n_heads,
            hps.model.n_layers,
            hps.model.kernel_size,
            hps.model.p_dropout,
            hps.model.resblock,
            hps.model.resblock_kernel_sizes,
            hps.model.resblock_dilation_sizes,
            hps.model.upsample_rates,
            hps.model.upsample_initial_channel,
            hps.model.upsample_kernel_sizes,
            hps.model.spk_embed_dim,
            hps.model.gin_channels,
            hps.data.sampling_rate,
        ]

        opt["epoch"] = epoch
        opt["step"] = step
        opt["sr"] = sr
        opt["f0"] = if_f0
        opt["version"] = version

        # Use a single timestamp so creation_date and the hash input agree.
        timestamp = datetime.datetime.now().isoformat()
        opt["creation_date"] = timestamp
        hash_input = f"{str(ckpt)} {epoch} {step} {timestamp}"
        opt["model_hash"] = hashlib.sha256(hash_input.encode()).hexdigest()

        # Rename weight-norm parametrization keys to the legacy names in
        # memory, instead of saving, reloading and re-saving the checkpoint.
        legacy_opt = replace_keys_in_dict(
            replace_keys_in_dict(
                opt, ".parametrizations.weight.original1", ".weight_v"
            ),
            ".parametrizations.weight.original0",
            ".weight_g",
        )

        torch.save(legacy_opt, tmp_path)
        # os.replace overwrites an existing target on both POSIX and Windows
        # (os.rename fails on Windows if the target exists), so no separate
        # os.remove is needed.
        os.replace(tmp_path, model_dir)
        print(f"Saved model '{model_dir}' (epoch {epoch} and step {step})")

    except Exception as error:
        # Deliberately best-effort: report and keep the caller running.
        print(f"Failed to save model '{model_dir}': {error}")
        if tmp_path is not None and os.path.exists(tmp_path):
            try:
                os.remove(tmp_path)
            except OSError as cleanup_error:
                print(
                    f"Could not remove temporary file '{tmp_path}': {cleanup_error}"
                )
|
|