"""Fuse three checkpoints into one by taking a fixed weighted sum per key."""

import torch

# Checkpoint files paired with their blend weights (0.5 + 0.25 + 0.25 = 1.0).
# The first entry is the reference: only its keys appear in the fused output.
CHECKPOINT_WEIGHTS = (
    ('model_1.ckpt', 0.5),
    ('model_2.ckpt', 0.25),
    ('model_3.ckpt', 0.25),
)
OUTPUT_PATH = 'fused_model.ckpt'


def fuse_state_dicts(state_dicts, weights):
    """Return the key-wise weighted sum of several checkpoint mappings.

    Args:
        state_dicts: Iterable of mappings from key to a value that supports
            scalar multiplication and addition (e.g. ``torch.Tensor``).
            The first mapping defines which keys are fused; keys that only
            exist in later mappings are ignored.
        weights: Iterable of scalar weights, one per mapping, same order.

    Returns:
        A new dict with ``sum(weight_i * state_dicts[i][key])`` for every
        key of the first mapping.

    Raises:
        ValueError: If no mappings are given, or the number of weights does
            not match the number of mappings.
        KeyError: If a later mapping lacks a key of the first mapping.
    """
    state_dicts = list(state_dicts)
    weights = list(weights)
    if not state_dicts:
        raise ValueError('at least one state dict is required')
    if len(state_dicts) != len(weights):
        raise ValueError(
            f'got {len(state_dicts)} state dicts but {len(weights)} weights'
        )

    reference = state_dicts[0]

    # Validate up front so a missing key is reported with context instead of
    # surfacing as a bare KeyError halfway through the fusion loop.
    for position, other in enumerate(state_dicts[1:], start=2):
        missing = [key for key in reference if key not in other]
        if missing:
            raise KeyError(
                f'state dict #{position} is missing keys present in the '
                f'first one: {sorted(missing, key=str)}'
            )

    fused = {}
    for key in reference:
        # Accumulate left to right so floating-point results match the
        # plain ``w1 * a + w2 * b + w3 * c`` expression.
        total = weights[0] * reference[key]
        for weight, other in zip(weights[1:], state_dicts[1:]):
            total = total + weight * other[key]
        fused[key] = total
    return fused


def main():
    """Load the configured checkpoints, fuse them, and save the result."""
    # NOTE(security): torch.load relies on pickle; depending on the installed
    # torch version's ``weights_only`` default it can execute arbitrary code
    # from the file. Only run this on checkpoints from a trusted source.
    state_dicts = [
        torch.load(path, map_location='cpu') for path, _ in CHECKPOINT_WEIGHTS
    ]
    weights = [weight for _, weight in CHECKPOINT_WEIGHTS]

    # Save the fused model
    torch.save(fuse_state_dicts(state_dicts, weights), OUTPUT_PATH)


if __name__ == '__main__':
    main()