ZeroCool94 commited on
Commit
cedaa57
1 Parent(s): 3eb2147

Upload merge.py

Browse files
Files changed (1) hide show
  1. merge.py +80 -0
merge.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import sys
3
+ import tkinter as tk
4
+
5
+
6
+ window = tk.Tk()
7
+ window.title(string="Model Merger")
8
+ tk.Label(text = "Model Merger",font=("Arial",25)).pack()
9
+ tk.Label(text = "GUI by antrobot1234").pack()
10
+
11
+ frame1 = tk.Frame()
12
+ frame2 = tk.Frame()
13
+ frame3 = tk.Frame()
14
+
15
+ frameSlider = tk.Frame()
16
+ frameButton = tk.Frame()
17
+
18
+ tk.Label(frame1,text = "File 1:").pack(side="left")
19
+ file1text = tk.Entry(frame1,width=40)
20
+ file1text.pack(side="left")
21
+
22
+ tk.Label(frame2,text = "File 2:").pack(side="left")
23
+ file2text = tk.Entry(frame2,width=40)
24
+ file2text.pack(side="left")
25
+
26
+ tk.Label(frame3,text = "File Out:").pack(side="left")
27
+ fileOtext = tk.Entry(frame3,width=38)
28
+ fileOtext.pack(side="left")
29
+
30
+ tk.Label(frameSlider,text = "Weight of file 1").pack(side="left")
31
+ scale = tk.Scale(frameSlider,from_=0, to=100,orient="horizontal",tickinterval=10,length=450)
32
+ scale.pack(side="left")
33
+
34
+
35
+
36
+ goButton = tk.Button(frameButton,text="RUN",height=2,width=20,bg="green")
37
+ def merge(file1,file2,out,a):
38
+ alpha = (a)/100
39
+ if not(file1.endswith(".ckpt")):
40
+ file1 += ".ckpt"
41
+ if not(file2.endswith(".ckpt")):
42
+ file2 += ".ckpt"
43
+ if not(out.endswith(".ckpt")):
44
+ out += ".ckpt"
45
+ #Load Models
46
+ model_0 = torch.load(file1)
47
+ model_1 = torch.load(file2)
48
+ theta_0 = model_0['state_dict']
49
+ theta_1 = model_1['state_dict']
50
+
51
+ for key in theta_0.keys():
52
+ if 'model' in key and key in theta_1:
53
+ theta_0[key] = (alpha) * theta_0[key] + (1-alpha) * theta_1[key]
54
+
55
+ goButton.config(bg="red",text="RUNNING...\n(STAGE 2)")
56
+ window.update()
57
+
58
+ for key in theta_1.keys():
59
+ if 'model' in key and key not in theta_0:
60
+ theta_0[key] = theta_1[key]
61
+ torch.save(model_0, out)
62
+
63
+
64
+ def handleClick(event):
65
+ goButton.config(bg="red",text="RUNNING...\n(STAGE 1)")
66
+ window.update()
67
+ merge(file1text.get(),file2text.get(),fileOtext.get(),scale.get())
68
+ goButton.config(bg="green",text="RUN")
69
+
70
+ goButton.pack()
71
+ goButton.bind("<Button-1>",handleClick)
72
+
73
+
74
+ frame1.pack()
75
+ frame2.pack()
76
+ frame3.pack()
77
+ frameSlider.pack()
78
+ frameButton.pack()
79
+
80
+ window.mainloop()