Sucial commited on
Commit
7496225
1 Parent(s): 47c0989

Upload model_fusion.py

Browse files
Files changed (1) hide show
  1. scripts/model_fusion.py +13 -0
scripts/model_fusion.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ model_1 = torch.load('model_1.ckpt', map_location='cpu')
4
+ model_2 = torch.load('model_2.ckpt', map_location='cpu')
5
+ model_3 = torch.load('model_3.ckpt', map_location='cpu')
6
+
7
+ # Combine the models
8
+ fused_weights = {}
9
+ for key in model_1.keys():
10
+ fused_weights[key] = 0.5 * model_1[key] + 0.25 * model_2[key] + 0.25 * model_3[key]
11
+
12
+ # Save the fused model
13
+ torch.save(fused_weights, 'fused_model.ckpt')