Upload model_fusion.py
Browse files- scripts/model_fusion.py +13 -0
scripts/model_fusion.py
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
model_1 = torch.load('model_1.ckpt', map_location='cpu')
|
4 |
+
model_2 = torch.load('model_2.ckpt', map_location='cpu')
|
5 |
+
model_3 = torch.load('model_3.ckpt', map_location='cpu')
|
6 |
+
|
7 |
+
# Combine the models
|
8 |
+
fused_weights = {}
|
9 |
+
for key in model_1.keys():
|
10 |
+
fused_weights[key] = 0.5 * model_1[key] + 0.25 * model_2[key] + 0.25 * model_3[key]
|
11 |
+
|
12 |
+
# Save the fused model
|
13 |
+
torch.save(fused_weights, 'fused_model.ckpt')
|