# Spaces: Running  (web-page residue picked up during extraction; not part of the script)
import argparse
import collections
import os

import torch
def load_model(checkpoint_path):
    """Load the ``model_g`` entry from a checkpoint file.

    Args:
        checkpoint_path: Path to a checkpoint readable by ``torch.load`` whose
            top-level object is a mapping with a ``"model_g"`` key.

    Returns:
        The object stored under ``"model_g"`` in the checkpoint, with tensors
        mapped to CPU.

    Raises:
        FileNotFoundError: If ``checkpoint_path`` is not an existing file.
        KeyError: If the checkpoint has no ``"model_g"`` entry.
    """
    # Explicit raise rather than ``assert`` so the check survives ``python -O``.
    if not os.path.isfile(checkpoint_path):
        raise FileNotFoundError(f"checkpoint not found: {checkpoint_path}")
    # NOTE: torch.load deserializes via pickle; only load trusted checkpoints.
    checkpoint_dict = torch.load(checkpoint_path, map_location="cpu")
    return checkpoint_dict["model_g"]
def save_model(state_dict, checkpoint_path):
    """Write ``state_dict`` to ``checkpoint_path``, wrapped under the ``model_g`` key."""
    payload = dict(model_g=state_dict)
    torch.save(payload, checkpoint_path)
def average_model(model_list):
    """Return the entry-wise mean of several state dicts.

    Args:
        model_list: Non-empty sequence of mappings. The keys (and their order)
            are taken from the first mapping; every other mapping must
            provide each of those keys.

    Returns:
        A ``collections.OrderedDict`` mapping each key of the first model to
        the sum of that entry over all models divided by the number of models.

    Raises:
        ValueError: If ``model_list`` is empty.
        KeyError: If a later model lacks a key present in the first model.
    """
    if not model_list:
        raise ValueError("model_list must contain at least one model")
    model_count = float(len(model_list))
    model_average = collections.OrderedDict()
    for key in model_list[0]:
        # sum() starts from 0 and adds left to right, one entry per model.
        key_sum = sum(model[key] for model in model_list)
        model_average[key] = torch.div(key_sum, model_count)
    return model_average
# Example usage of average_model (s1 and s2 are loaded state dicts):
# ss_list = []
# ss_list.append(s1)
# ss_list.append(s2)
# ss_merge = average_model(ss_list)
def merge_model(model1, model2, rate):
    """Blend two state dicts as ``rate * model1 + (1 - rate) * model2``.

    Keys and their order come from ``model1``; ``model2`` is indexed with each
    of those keys. The result is a ``collections.OrderedDict``.
    """
    return collections.OrderedDict(
        (name, rate * weight1 + (1 - rate) * model2[name])
        for name, weight1 in model1.items()
    )
if __name__ == '__main__':
    # Command-line entry point: blend two checkpoints and write the result.
    parser = argparse.ArgumentParser()
    parser.add_argument('-m1', '--model1', type=str, required=True,
                        help="path to the first checkpoint")
    parser.add_argument('-m2', '--model2', type=str, required=True,
                        help="path to the second checkpoint")
    parser.add_argument('-r1', '--rate', type=float, required=True,
                        help="weight of model1, strictly between 0 and 1; "
                             "model2 is weighted by 1 - rate")
    # The output path used to be hard-coded; the default keeps the old name.
    parser.add_argument('-o', '--output', type=str,
                        default="sovits5.0_merge.pth",
                        help="path of the merged checkpoint to write")
    args = parser.parse_args()

    print(args.model1)
    print(args.model2)
    print(args.rate)

    # parser.error instead of ``assert``: the check survives ``python -O`` and
    # the user gets a usage message (exit status 2) rather than a traceback.
    if not 0 < args.rate < 1:
        parser.error(f"{args.rate} should be in range (0, 1)")

    s1 = load_model(args.model1)
    s2 = load_model(args.model2)
    merge = merge_model(s1, s2, args.rate)
    save_model(merge, args.output)