# File size: 2,091 Bytes
# cedaa57
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import torch
import sys
import tkinter as tk


# ---- Root window -----------------------------------------------------------
window = tk.Tk()
window.title(string="Model Merger")

# Banner and credit line, packed straight into the root window so they sit
# above the form rows (those are packed at the end of the script).
tk.Label(text="Model Merger", font=("Arial", 25)).pack()
tk.Label(text="GUI by antrobot1234").pack()

# ---- Form rows -------------------------------------------------------------
# One frame per row: three path rows, the weight slider and the RUN button.
# They are only created here; they get packed after the button is bound.
frame1, frame2, frame3, frameSlider, frameButton = (tk.Frame() for _ in range(5))


def _add_labeled_entry(row, caption, chars):
    """Pack a *caption* label and a new Entry *chars* wide side by side into *row*.

    Returns the Entry so its text can be read later.
    """
    tk.Label(row, text=caption).pack(side="left")
    field = tk.Entry(row, width=chars)
    field.pack(side="left")
    return field


file1text = _add_labeled_entry(frame1, "File 1:", 40)
file2text = _add_labeled_entry(frame2, "File 2:", 40)
fileOtext = _add_labeled_entry(frame3, "File Out:", 38)

# Horizontal 0-100 slider; its value is what the click handler passes to
# merge() as the weight of file 1.
tk.Label(frameSlider, text="Weight of file 1").pack(side="left")
scale = tk.Scale(frameSlider, from_=0, to=100, orient="horizontal",
                 tickinterval=10, length=450)
scale.pack(side="left")

# The RUN button; it is packed and bound to its click handler further down.
goButton = tk.Button(frameButton, text="RUN", height=2, width=20, bg="green")
def _with_ckpt_suffix(path):
    """Return *path*, appending ``.ckpt`` unless it already ends with it."""
    return path if path.endswith(".ckpt") else path + ".ckpt"


def _show_stage_two():
    """Switch the RUN button to its stage-2 label and repaint the window."""
    goButton.config(bg="red",text="RUNNING...\n(STAGE 2)")
    window.update()


def merge(file1,file2,out,a,map_location="cpu",on_stage2=None):
    """Blend two checkpoints' ``state_dict`` weights and save the result.

    For every ``state_dict`` key that contains ``'model'`` and exists in both
    checkpoints, the output holds ``alpha * w1 + (1 - alpha) * w2`` with
    ``alpha = a / 100``. Keys containing ``'model'`` that only *file2* has are
    copied over unchanged. All other entries of *file1*'s checkpoint are kept
    as-is; keys of *file2* without ``'model'`` in them are dropped. The
    (modified) checkpoint object loaded from *file1* is what gets saved.

    Args:
        file1: Path of the first checkpoint; ``.ckpt`` is appended if missing.
        file2: Path of the second checkpoint; ``.ckpt`` is appended if missing.
        out: Output path; ``.ckpt`` is appended if missing.
        a: Weight of *file1* as a percentage (the GUI slider yields 0-100).
        map_location: Forwarded to ``torch.load``. Defaults to ``"cpu"`` so
            GPU-saved checkpoints also load on CPU-only machines and two full
            models are not kept in GPU memory; pass ``None`` to restore the
            devices recorded in the files.
        on_stage2: Zero-argument callable invoked between the blending pass
            and the copy pass. ``None`` (default) updates the GUI's RUN button.

    Raises:
        KeyError: if either checkpoint has no ``'state_dict'`` entry.
    """
    alpha = a / 100
    file1 = _with_ckpt_suffix(file1)
    file2 = _with_ckpt_suffix(file2)
    out = _with_ckpt_suffix(out)
    if on_stage2 is None:
        on_stage2 = _show_stage_two

    # NOTE(security): torch.load is pickle-based; depending on the torch
    # version / weights_only setting it can execute code embedded in the
    # file. Only merge checkpoints from sources you trust.
    model_0 = torch.load(file1, map_location=map_location)
    model_1 = torch.load(file2, map_location=map_location)
    theta_0 = model_0['state_dict']
    theta_1 = model_1['state_dict']

    # Stage 1: weighted average of the weights both models share
    # (written back into theta_0, i.e. into model_0).
    for key in theta_0:
        if 'model' in key and key in theta_1:
            theta_0[key] = alpha * theta_0[key] + (1 - alpha) * theta_1[key]

    on_stage2()

    # Stage 2: carry over model weights that only the second checkpoint has.
    for key, value in theta_1.items():
        if 'model' in key and key not in theta_0:
            theta_0[key] = value
    torch.save(model_0, out)
    

def handleClick(event):
    """Run a merge with the values currently entered in the form.

    Bound to ``<Button-1>`` on ``goButton``; *event* is the Tk event object and
    is not used. The button shows "RUNNING..." while the merge works and is
    put back to "RUN" afterwards, even when the merge raises.
    """
    # window.update() (here and inside merge) dispatches clicks that queued up
    # while the merge was blocking the event loop; ignore those rather than
    # starting a second, nested merge.
    if getattr(handleClick, "_running", False):
        return
    handleClick._running = True
    try:
        goButton.config(bg="red",text="RUNNING...\n(STAGE 1)")
        # Repaint now: the merge blocks the Tk event loop until it returns.
        window.update()
        merge(file1text.get(),file2text.get(),fileOtext.get(),scale.get())
    finally:
        # Reset even on failure (missing file, no 'state_dict' key, ...) so
        # the button does not stay stuck on "RUNNING..."; the exception still
        # propagates to Tkinter's callback exception handler.
        handleClick._running = False
        goButton.config(bg="green",text="RUN")
    
# Show the RUN button and run handleClick on a left-mouse press.
goButton.pack()
goButton.bind("<Button-1>",handleClick)

# Stack the form rows top to bottom under the banner, in this order.
for row in (frame1, frame2, frame3, frameSlider, frameButton):
    row.pack()

# Hand control to the Tk event loop.
window.mainloop()