Sucial's picture
Upload model_fusion.py
7496225 verified
raw
history blame contribute delete
422 Bytes
import torch
# Checkpoints to blend, in the same order as FUSION_WEIGHTS.
CHECKPOINT_PATHS = ('model_1.ckpt', 'model_2.ckpt', 'model_3.ckpt')
# Blend factor applied to each checkpoint's tensors (first model dominates).
FUSION_WEIGHTS = (0.5, 0.25, 0.25)
# Destination file for the blended checkpoint.
OUTPUT_PATH = 'fused_model.ckpt'


def fuse_state_dicts(state_dicts, weights):
    """Return the per-key weighted sum of several checkpoint mappings.

    The keys of the first mapping define the keys of the result; every other
    mapping must contain those keys as well. Keys that appear only in later
    mappings are ignored. Input mappings are not modified.

    Args:
        state_dicts: Sequence of mappings from key to a value supporting
            ``float * value`` and ``value + value`` (presumably tensors from
            a model state dict -- confirm against the checkpoints in use).
        weights: One blend factor per mapping, in the same order.

    Returns:
        A new dict mapping each key of the first mapping to
        ``weights[0] * state_dicts[0][key] + weights[1] * state_dicts[1][key] + ...``.

    Raises:
        ValueError: If no mappings are given, or the number of weights does
            not match the number of mappings.
        KeyError: If a key of the first mapping is missing from another one.
    """
    state_dicts = list(state_dicts)
    weights = list(weights)
    if not state_dicts:
        raise ValueError('at least one state dict is required')
    if len(state_dicts) != len(weights):
        raise ValueError(
            f'got {len(state_dicts)} state dicts but {len(weights)} weights'
        )

    reference = state_dicts[0]
    fused = {}
    for key in reference:
        missing = [
            index for index, state in enumerate(state_dicts) if key not in state
        ]
        if missing:
            raise KeyError(
                f'key {key!r} is missing from state dict(s) at index {missing}'
            )
        # Accumulate left to right so the result matches the expression
        # w0 * a + w1 * b + w2 * c exactly.
        blended = weights[0] * reference[key]
        for weight, state in zip(weights[1:], state_dicts[1:]):
            blended = blended + weight * state[key]
        fused[key] = blended
    return fused


def main():
    """Load the source checkpoints, blend them, and write the fused file."""
    # NOTE: torch.load unpickles the file contents; only load checkpoints
    # that come from a trusted source.
    models = [torch.load(path, map_location='cpu') for path in CHECKPOINT_PATHS]
    # Combine the models
    fused_weights = fuse_state_dicts(models, FUSION_WEIGHTS)
    # Save the fused model
    torch.save(fused_weights, OUTPUT_PATH)


if __name__ == '__main__':
    main()