# sovits-test / svc_merge.py
# atsushieee's picture
# Upload folder using huggingface_hub
# 9791162
import os
import torch
import argparse
import collections
def load_model(checkpoint_path):
    """Load a checkpoint file and return its ``"model_g"`` entry.

    Args:
        checkpoint_path: Path to a checkpoint previously written with
            ``torch.save``; it must point to an existing regular file.

    Returns:
        The object stored under the ``"model_g"`` key of the checkpoint
        (presumably a generator state dict -- confirm against the trainer).

    Raises:
        AssertionError: If ``checkpoint_path`` is not an existing file.
        KeyError: If the checkpoint has no ``"model_g"`` entry.
    """
    assert os.path.isfile(checkpoint_path)
    # Map every tensor to CPU so no GPU is needed to read the checkpoint.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    return checkpoint["model_g"]
def save_model(state_dict, checkpoint_path):
    """Write ``state_dict`` to ``checkpoint_path`` under the ``"model_g"`` key.

    The file layout mirrors what ``load_model`` reads back: a dict with a
    single ``"model_g"`` entry.

    Args:
        state_dict: Object to store (typically a merged state dict).
        checkpoint_path: Destination passed straight to ``torch.save``.
    """
    checkpoint = {'model_g': state_dict}
    torch.save(checkpoint, checkpoint_path)
def average_model(model_list):
    """Return the element-wise mean of several state dicts.

    The key set and key order are taken from the first entry of
    ``model_list``; every other entry must provide the same keys.

    Args:
        model_list: Non-empty sequence of mappings from parameter name to
            a value that supports ``+`` and ``torch.div`` (e.g. tensors).

    Returns:
        An ``OrderedDict`` mapping each key of the first model to the sum of
        that key across all models divided by the number of models.

    Raises:
        IndexError: If ``model_list`` is empty.
        KeyError: If a later model lacks a key present in the first one.
    """
    divisor = float(len(model_list))
    averaged = collections.OrderedDict()
    for name in model_list[0].keys():
        # Builtin sum starts from 0 and adds left to right, one model at a time.
        total = sum(state[name] for state in model_list)
        averaged[name] = torch.div(total, divisor)
    return averaged
# Example usage of average_model() with two loaded state dicts s1 and s2:
# ss_list = []
# ss_list.append(s1)
# ss_list.append(s2)
# ss_merge = average_model(ss_list)
def merge_model(model1, model2, rate):
    """Linearly blend two state dicts key by key.

    For every key of ``model1`` the result holds
    ``rate * model1[key] + (1 - rate) * model2[key]``. Keys that exist only
    in ``model2`` are not carried over.

    Args:
        model1: Mapping whose keys (and their order) define the output.
        model2: Mapping that must contain every key of ``model1``.
        rate: Weight applied to ``model1``; ``model2`` gets ``1 - rate``.

    Returns:
        An ``OrderedDict`` with the blended values.

    Raises:
        KeyError: If ``model2`` lacks a key present in ``model1``.
    """
    return collections.OrderedDict(
        (name, rate * model1[name] + (1 - rate) * model2[name])
        for name in model1.keys()
    )
if __name__ == '__main__':
    # Command-line entry point: blend two checkpoints and write the result.
    parser = argparse.ArgumentParser(
        description="Merge the 'model_g' weights of two checkpoints by "
                    "linear interpolation.")
    parser.add_argument('-m1', '--model1', type=str, required=True,
                        help="path to the first checkpoint")
    parser.add_argument('-m2', '--model2', type=str, required=True,
                        help="path to the second checkpoint")
    parser.add_argument('-r1', '--rate', type=float, required=True,
                        help="weight of model1, strictly between 0 and 1; "
                             "model2 is weighted by 1 - rate")
    # The output path used to be hard-coded; the default keeps the old name.
    parser.add_argument('-o', '--output', type=str,
                        default="sovits5.0_merge.pth",
                        help="path of the merged checkpoint to write")
    args = parser.parse_args()

    print(args.model1)
    print(args.model2)
    print(args.rate)

    # Validate explicitly instead of with `assert`, which is removed when
    # Python runs with -O and would let an out-of-range rate through.
    if not 0 < args.rate < 1:
        parser.error(f"{args.rate} should be in range (0, 1)")

    s1 = load_model(args.model1)
    s2 = load_model(args.model2)
    merge = merge_model(s1, s2, args.rate)
    save_model(merge, args.output)