|
import os |
|
import torch |
|
from collections import OrderedDict |
|
|
|
|
|
def extract(ckpt):
    """Return the entries of ``ckpt["model"]`` that are not ``enc_q`` entries.

    Args:
        ckpt: Mapping with a ``"model"`` entry that maps names to values.

    Returns:
        An ``OrderedDict`` with a single ``"weight"`` key. Its value is a plain
        dict holding every item of ``ckpt["model"]`` whose name does not
        contain the substring ``"enc_q"``, in the original order.
    """
    state = ckpt["model"]
    # Substring match: any name mentioning "enc_q" anywhere is dropped.
    kept = {
        param_name: value
        for param_name, value in state.items()
        if "enc_q" not in param_name
    }
    result = OrderedDict()
    result["weight"] = kept
    return result
|
|
|
|
|
def model_blender(name, path1, path2, ratio):
    """Blend two model checkpoints by linear interpolation of their weights.

    Each merged tensor is ``ratio * w1 + (1 - ratio) * w2``, computed in
    float32 and stored as float16. The ``config``, ``sr``, ``f0`` and
    ``version`` entries are copied from the first checkpoint, and the result
    is written to ``logs/<name>.pth`` (relative to the current directory).

    Args:
        name: Base file name (without extension) of the merged checkpoint.
        path1: Path of the first checkpoint; its metadata is reused.
        path2: Path of the second checkpoint.
        ratio: Weight applied to the first checkpoint; the second checkpoint
            is weighted by ``1 - ratio``.

    Returns:
        ``(message, output_path)`` on success. If the two checkpoints do not
        contain the same set of weight names, a failure string is returned
        instead. If any exception occurs it is printed and the exception
        object itself is returned (it is not re-raised).
    """
    try:
        message = f"Model {path1} and {path2} are merged with alpha {ratio}."

        # NOTE(review): torch.load may unpickle arbitrary objects depending on
        # the torch version/defaults -- only load checkpoints you trust.
        ckpt1 = torch.load(path1, map_location="cpu")
        ckpt2 = torch.load(path2, map_location="cpu")

        # Metadata of the merged model is taken from the first checkpoint.
        cfg = ckpt1["config"]
        # "sr" is read leniently so checkpoints lacking it still merge, as
        # they did before this field was propagated correctly.
        cfg_sr = ckpt1.get("sr")
        cfg_f0 = ckpt1["f0"]
        cfg_version = ckpt1["version"]

        # extract() wraps its result as {"weight": {...}}, so it has to be
        # unwrapped to obtain the same flat name -> tensor mapping that an
        # already-extracted checkpoint stores under "weight".
        weights1 = extract(ckpt1)["weight"] if "model" in ckpt1 else ckpt1["weight"]
        weights2 = extract(ckpt2)["weight"] if "model" in ckpt2 else ckpt2["weight"]

        if sorted(weights1) != sorted(weights2):
            return "Fail to merge the models. The model architectures are not the same."

        opt = OrderedDict()
        opt["weight"] = {}
        for key, tensor1 in weights1.items():
            tensor2 = weights2[key]
            if key == "emb_g.weight" and tensor1.shape != tensor2.shape:
                # Only the rows present in both tensors can be blended.
                rows = min(tensor1.shape[0], tensor2.shape[0])
                tensor1, tensor2 = tensor1[:rows], tensor2[:rows]
            opt["weight"][key] = (
                ratio * tensor1.float() + (1 - ratio) * tensor2.float()
            ).half()

        opt["config"] = cfg
        opt["sr"] = cfg_sr
        opt["f0"] = cfg_f0
        opt["version"] = cfg_version
        opt["info"] = message

        output_path = os.path.join("logs", f"{name}.pth")
        # torch.save does not create missing directories.
        os.makedirs(os.path.dirname(output_path), exist_ok=True)
        torch.save(opt, output_path)
        print(message)
        return message, output_path
    except Exception as error:
        # Callers receive the exception object rather than a raised error.
        print(error)
        return error
|
|