Agarwal committed · bf9ef4a · Parent(s): ffa6dc5

added data
Files changed:
- data/mlp.py +242 -0
- data/train_profiles_mlp.py +152 -0
- data/x_pointwise.pkl +3 -0
- data/y_pointwise.pkl +3 -0
data/mlp.py
ADDED
@@ -0,0 +1,242 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
import math


def get_lr(optimizer):
    # report the learning rate of the first parameter group
    for param_group in optimizer.param_groups:
        return param_group['lr']


class MLP(nn.Module):
    def __init__(self, f_i: int, f_o: int, act_fn: object = nn.SELU, f=[], insert_in=[4], freq_encoding=False):
        super().__init__()

        self.insert_in = insert_in          # layer indices after which the input is concatenated back in
        self.layers = nn.ModuleList()
        self.act = act_fn()
        self.freq_encoding = freq_encoding

        f_in = f_i
        for f_cntr, f_oo in enumerate(f):
            if f_cntr in insert_in:
                # shrink this layer so its width is back to f_oo once the input is concatenated
                f_oo -= f_i
            self.layers.append(nn.Linear(f_in, f_oo))
            if f_cntr in insert_in:
                f_in = f_oo + f_i
            else:
                f_in = f_oo

        self.layers.append(nn.Linear(f_in, f_o))

    def forward(self, x):
        res = []
        if self.freq_encoding:
            # replace the last input column by its (cos, sin) encoding;
            # note this widens the input by one column
            yc = x[:, -1:]
            pi = torch.acos(torch.zeros(1)).item()
            x = torch.cat((torch.cos(2. * pi * yc),
                           torch.sin(2. * pi * yc),
                           x[:, :-1]), dim=-1)

        inp = x

        for m_ind, m in enumerate(self.layers[:-1]):
            x = m(x)

            if m_ind in self.insert_in:
                x = torch.cat((inp, x), dim=1)

            # dense residual connections to all previous activations
            for r in res:
                x = x + r

            x = self.act(x)
            res.append(x)

        x = self.layers[-1](x)

        return x

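A minimal smoke test for this class (editorial sketch, not part of the commit; the shapes follow the f_i=4, f_o=1 setup used in train_profiles_mlp.py):

    # hypothetical: 3 hidden layers of width 128, input re-concatenated after layer 2
    net = MLP(f_i=4, f_o=1, f=[128] * 3, insert_in=[2]).double()
    x = torch.rand(8, 4, dtype=torch.float64)  # batch of 8 pointwise samples
    print(net(x).shape)                        # torch.Size([8, 1])
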
def one_epoch_mlp(mlp, epoch, loader, optimizer, device, is_train=False):
    running_loss = 0.
    counter = 0  # count batches from zero so the reported mean is exact
    loss_fn = torch.nn.L1Loss()

    torch.set_grad_enabled(is_train)

    for i, data in enumerate(loader):
        optimizer.zero_grad()

        x = data[0].to(device)
        y = data[1].to(device)
        T = mlp(x)

        loss = loss_fn(T, y)  # L1Loss(input, target)

        if is_train:
            loss.backward()
            optimizer.step()

        running_loss += loss.item()
        counter += 1

    return running_loss / max(counter, 1)

def one_epoch_mlp_lbfgs(mlp, epoch, loader, optimizer, device, is_train=False):
    running_loss = 0.
    counter = 0
    loss_fn = torch.nn.L1Loss()

    torch.set_grad_enabled(is_train)

    for i, data in enumerate(loader):
        x = data[0].to(device)
        y = data[1].to(device)

        def closure():
            # L-BFGS re-evaluates the objective several times per step
            optimizer.zero_grad()
            T = mlp(x)
            loss = loss_fn(T, y)
            loss.backward()
            return loss

        if is_train:
            optimizer.step(closure)

        # evaluate once more to log the post-step loss
        T = mlp(x)
        loss = loss_fn(T, y)

        running_loss += loss.item()
        counter += 1

    return running_loss / max(counter, 1)

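A hedged sketch of driving this epoch runner with torch.optim.LBFGS, which is what the closure is for (the hyperparameters are illustrative, and `loader` is assumed to be the dict built in train_profiles_mlp.py):

    # hypothetical: L-BFGS needs a closure that recomputes the loss
    lbfgs = torch.optim.LBFGS(mlp.parameters(), lr=1.0, max_iter=20)  # illustrative settings
    avg_loss = one_epoch_mlp_lbfgs(mlp, 0, loader["train"], lbfgs, device, is_train=True)
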
def exists(val):
    return val is not None


def cast_tuple(val, repeat=1):
    return val if isinstance(val, tuple) else ((val,) * repeat)


# sin activation
class Sine(nn.Module):
    def __init__(self, w0=1.):
        super().__init__()
        self.w0 = w0

    def forward(self, x):
        return torch.sin(self.w0 * x)


# siren layer
class Siren(nn.Module):
    def __init__(
        self,
        dim_in,
        dim_out,
        w0=1.,
        c=6.,
        is_first=False,
        use_bias=True,
        activation=None,
        dropout=0.
    ):
        super().__init__()
        self.dim_in = dim_in
        self.is_first = is_first

        weight = torch.zeros(dim_out, dim_in)
        bias = torch.zeros(dim_out) if use_bias else None
        self.init_(weight, bias, c=c, w0=w0)

        self.weight = nn.Parameter(weight)
        self.bias = nn.Parameter(bias) if use_bias else None
        self.activation = Sine(w0) if activation is None else activation
        self.dropout = nn.Dropout(dropout)

    def init_(self, weight, bias, c, w0):
        # SIREN initialization: uniform in +/- 1/dim for the first layer,
        # +/- sqrt(c/dim)/w0 for the rest (Sitzmann et al., 2020)
        dim = self.dim_in

        w_std = (1 / dim) if self.is_first else (math.sqrt(c / dim) / w0)
        weight.uniform_(-w_std, w_std)

        if exists(bias):
            bias.uniform_(-w_std, w_std)

    def forward(self, x):
        out = F.linear(x, self.weight, self.bias)
        out = self.activation(out)
        out = self.dropout(out)
        return out

class SirenMLP(nn.Module):

    def __init__(self,
                 n_in,
                 n_out,
                 n_hidden,
                 device,
                 num_layers,
                 w0=30.,
                 w0_initial=30.,
                 use_bias=True,
                 final_activation=None,
                 dropout=0.,
                 context_params=None
                 ):
        super().__init__()

        self.device = device
        self.num_layers = num_layers
        self.n_hidden = n_hidden
        self.dropout = nn.Dropout(dropout)

        # siren layers
        self.siren_layers = nn.ModuleList()

        for ind in range(num_layers):
            is_first = ind == 0
            layer_w0 = w0_initial if is_first else w0
            layer_dim_in = n_in if is_first else n_hidden

            layer = Siren(
                dim_in=layer_dim_in,
                dim_out=n_hidden,
                w0=layer_w0,
                use_bias=use_bias,
                is_first=is_first,
                dropout=dropout
            )
            self.siren_layers.append(layer)

        final_activation = nn.Identity() if not exists(final_activation) else final_activation
        self.siren_layers.append(Siren(dim_in=n_hidden, dim_out=n_out, w0=w0,
                                       use_bias=use_bias, activation=final_activation))

    def forward(self, x):
        res = []
        for k in range(len(self.siren_layers) - 1):
            x = self.siren_layers[k](x)
            # average with every previous activation (residual smoothing)
            for r in res:
                x = 0.5 * (x + r)
            res.append(x)
            x = self.dropout(x)

        x = self.siren_layers[-1](x)

        return x
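A corresponding sketch for the SIREN variant (again editorial; the widths are assumptions):

    # hypothetical: a 4-input, 1-output SIREN with 3 hidden layers of width 128
    siren = SirenMLP(n_in=4, n_out=1, n_hidden=128, device="cpu", num_layers=3)
    x = torch.rand(8, 4)     # SIREN inputs are typically normalized to a small range
    print(siren(x).shape)    # torch.Size([8, 1])
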
data/train_profiles_mlp.py
ADDED
@@ -0,0 +1,152 @@
import glob, os, sys

import numpy as np
import matplotlib.pyplot as plt

import torch
from mlp import *
import argparse
from datasetio import *
from torch.utils.data import TensorDataset, DataLoader  # DataLoader is used below

import copy
import pickle
import time

# In[ ]:


data_dir = "/plp_scr1/agar_sh/data/TPH/"
nn_dir = "/plp_user/agar_sh/PBML/pytorch/TPH/MLP/trained_networks/"


# In[ ]:


run_cell = True
if run_cell:
    parser = argparse.ArgumentParser(description='Train mlp')
    parser.add_argument("-gpu", "--gpu_number", type=int, help="specify gpu number")
    parser.add_argument("-a", "--act_fn", type=str, help="activation function")
    parser.add_argument("-l", "--num_layers", type=int, help="number of hidden layers")
    parser.add_argument("-f", "--f_h", type=int, help="hidden layer width")

    args = parser.parse_args()

    gpu_number = args.gpu_number
    act_fn = args.act_fn
    num_layers = args.num_layers
    f_h = [args.f_h]

else:
    # notebook defaults when the argparse branch is disabled
    f_h = [128]
    num_layers = 3
    act_fn = "selu"
    gpu_number = 6

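With these flags, a run would be launched along the lines of `python train_profiles_mlp.py -gpu 0 -a selu -l 3 -f 128` (values illustrative).
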
# In[ ]:


f_nn = "mlp_profile_pointwise_" + str(f_h) + "_" + str(num_layers) + "_" + act_fn

if not os.path.isdir(nn_dir + f_nn):
    os.mkdir(nn_dir + f_nn)

device = torch.device("cuda:" + str(gpu_number)) if torch.cuda.is_available() else torch.device("cpu")

epoch = 0
start_lr = 1e-3
milestones = [10, 20, 30, 40, 50, 60]  # epochs at which the learning rate is halved
epochs = 70
best_vloss = 1e+16
batch_size = 32

f_i = 4  # input features per point
f_o = 1  # output features per point

# act_fn only names the output directory; MLP uses its nn.SELU default
mlp = MLP(f_i=f_i, f_o=f_o, f=f_h*num_layers, insert_in=[num_layers-1]).double().to(device)
mlp = torch.compile(mlp)  # torch.compile returns the compiled module, so keep the result
print(mlp)

nn_dir = nn_dir + f_nn + "/"
with open(nn_dir + "mlp.txt", 'w') as writer:
    writer.write('Epoch, train loss, val loss, learning rate \n')

# In[ ]:


dataset = {}
loader = {}
batches = {}
pre = "/plp_user/agar_sh/PBML/pytorch/TPH/MLP/profiles/"
with open(pre + 'x_pointwise.pkl', 'rb') as file:
    x_pointwise = pickle.load(file)
with open(pre + 'y_pointwise.pkl', 'rb') as file:
    y_pointwise = pickle.load(file)

for an in ["train", "cv"]:

    '''
    disabled: oversample points whose last input coordinate exceeds 0.8
    inds = x_pointwise[an][:,-1]>0.8
    x_ = x_pointwise[an][inds,:]
    y_ = y_pointwise[an][inds,:]

    repeats = int(np.ceil(x_pointwise[an].shape[0]/(batch_size*inds.shape[0])))

    x_pointwise[an] = np.concatenate((x_pointwise[an],
                                      np.repeat(x_, repeats, axis=0)
                                      ), axis=0)
    y_pointwise[an] = np.concatenate((y_pointwise[an],
                                      np.repeat(y_, repeats, axis=0)
                                      ), axis=0)
    '''

    dataset[an] = TensorDataset(torch.tensor(x_pointwise[an], dtype=torch.float64),
                                torch.tensor(y_pointwise[an], dtype=torch.float64))
    batches[an] = int(len(dataset[an])/batch_size)
    loader[an] = DataLoader(dataset[an], batch_size=batch_size, shuffle=True)
    print(an, len(dataset[an]), batches[an])

# In[ ]:


optimizer = torch.optim.Adam([
    {"params": mlp.parameters(),
     "lr": start_lr,
     "weight_decay": 5e-4
     }
])

scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestones, gamma=0.5)


# In[ ]:


for epoch in range(epochs):
    t0 = time.time()
    mlp.train(True)
    avg_loss = one_epoch_mlp(mlp, epoch, loader["train"], optimizer, device, is_train=True)

    mlp.eval()
    avg_vloss = one_epoch_mlp(mlp, epoch, loader["cv"], optimizer, device, is_train=False)

    print("-------------------------------------------")
    print(epoch, "train: ", avg_loss, get_lr(optimizer))
    print(epoch, "cv: ", avg_vloss)
    print("took: " + str(time.time() - t0))
    print("-------------------------------------------")

    # checkpoint on the best validation loss
    if avg_vloss < best_vloss:
        best_vloss = avg_vloss
        torch.save(mlp.state_dict(), nn_dir + "mlp.pt")

    with open(nn_dir + "mlp.txt", "a") as writer:
        writer.write(str(epoch) + "," + str(avg_loss)
                     + "," + str(avg_vloss) + "," + str(get_lr(optimizer)) + "\n")

    scheduler.step()

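For later inference, the saved checkpoint would be restored along these lines (editorial sketch; note that a state_dict saved from a torch.compile'd module carries an `_orig_mod.` prefix in its keys):

    # hypothetical reload of the best checkpoint
    net = MLP(f_i=f_i, f_o=f_o, f=f_h*num_layers, insert_in=[num_layers-1]).double()
    state = torch.load(nn_dir + "mlp.pt", map_location="cpu")
    state = {k.removeprefix("_orig_mod."): v for k, v in state.items()}  # strip compile wrapper prefix
    net.load_state_dict(state)
    net.eval()
    with torch.no_grad():
        pred = net(torch.rand(16, f_i, dtype=torch.float64))
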
data/x_pointwise.pkl
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:277519100cb00cb7e059fd620358ac67ffe85a8cbb1daf0565f028ed68a76000
size 1138964
data/y_pointwise.pkl
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:50b39995e959275be47e658fd483bc169137705bb3ca8859272150767f0ca867
size 284930
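Both pickles land as Git LFS pointers; after a `git lfs pull`, a quick sanity check might look like this (the "train"/"cv" keys and the 4-column input layout are inferred from train_profiles_mlp.py):

    # hypothetical inspection of the pickled data splits
    import pickle

    with open("data/x_pointwise.pkl", "rb") as f:
        x_pointwise = pickle.load(f)
    for split in ("train", "cv"):
        print(split, x_pointwise[split].shape)  # expected (n_points, 4)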