Agarwal committed on
Commit
bf9ef4a
1 Parent(s): ffa6dc5

added data

Browse files
data/mlp.py ADDED
@@ -0,0 +1,242 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import torch.optim as optim
5
+ import numpy as np
6
+ import matplotlib.pyplot as plt
7
+ from matplotlib.lines import Line2D
8
+ import math
9
+
10
def get_lr(optimizer):
    """Return the learning rate of the optimizer's first parameter group.

    Only the first entry of ``optimizer.param_groups`` is consulted.
    ``None`` is returned when there are no parameter groups at all.
    """
    first_group = next(iter(optimizer.param_groups), None)
    if first_group is None:
        return None
    return first_group['lr']
13
+
14
class MLP(nn.Module):
    """Fully connected network with summed skip connections and input re-injection.

    Every hidden layer is a ``nn.Linear`` followed by the activation. Before
    the activation is applied, the activated outputs of *all* earlier hidden
    layers are added to the current pre-activation, so the hidden widths must
    be broadcast-compatible (in practice: equal). At the layer indices listed
    in ``insert_in`` the network input is concatenated in front of the linear
    output; those layers are built ``f_i`` units narrower so that the width
    after concatenation equals the requested entry of ``f``.

    Args:
        f_i: Number of input features seen by the first linear layer. With
            ``freq_encoding=True`` this is the width *after* encoding, i.e.
            one more than the number of columns passed to ``forward``.
        f_o: Number of output features.
        act_fn: Activation class (not instance); instantiated without arguments.
        f: Hidden layer widths, one entry per hidden layer. ``None`` means no
            hidden layers.
        insert_in: Indices of hidden layers at which the input is re-injected.
            ``None`` selects the historical default ``[4]``.
        freq_encoding: If true, the last input column ``y`` is replaced by
            ``cos(2*pi*y)`` and ``sin(2*pi*y)``, placed in front of the
            remaining columns.

    Note:
        Earlier revisions derived "pi" from ``acos(0)``, which is pi/2, so the
        encoding effectively used ``cos(pi*y)`` / ``sin(pi*y)``. Weights that
        were trained with ``freq_encoding=True`` under that behaviour will see
        different input features with this version.
    """

    def __init__(self, f_i: int, f_o: int, act_fn: object = nn.SELU, f=None, insert_in=None, freq_encoding=False):

        super().__init__()

        # None sentinels instead of mutable defaults: insert_in is stored on the
        # instance, so a shared default list could be mutated across instances.
        f = [] if f is None else f
        insert_in = [4] if insert_in is None else insert_in

        self.insert_in = insert_in
        self.layers = nn.ModuleList()
        self.act = act_fn()
        self.freq_encoding = freq_encoding

        f_in = f_i
        for f_cntr, f_oo in enumerate(f):
            inject = f_cntr in insert_in
            if inject:
                # Leave room for the f_i input features concatenated in forward()
                f_oo -= f_i
            self.layers.append(nn.Linear(f_in, f_oo))
            f_in = f_oo + f_i if inject else f_oo

        # Output layer (no activation is applied after it)
        self.layers.append(nn.Linear(f_in, f_o))

    def forward(self, x):
        """Run the network on ``x``, indexed as ``(row, feature)``."""
        if self.freq_encoding:
            yc = x[:, -1:]
            phase = 2. * math.pi * yc
            x = torch.cat((torch.cos(phase),
                           torch.sin(phase),
                           x[:, :-1]), dim=-1)

        # Keep the (possibly encoded) input for re-injection at insert_in layers
        inp = x
        res = []

        for m_ind, m in enumerate(self.layers[:-1]):
            x = m(x)

            if m_ind in self.insert_in:
                x = torch.cat((inp, x), dim=1)

            # Add the activated outputs of every earlier hidden layer
            for r in res:
                x = x + r

            x = self.act(x)
            res.append(x)

        return self.layers[-1](x)
64
+
65
def one_epoch_mlp(mlp, epoch, loader, optimizer, device, is_train=False):
    """Run one pass over ``loader`` and return the mean per-batch L1 loss.

    Args:
        mlp: Model called as ``mlp(x)``.
        epoch: Current epoch index. Unused; kept so existing callers keep working.
        loader: Iterable of batches where ``batch[0]`` is the input and
            ``batch[1]`` the target.
        optimizer: Optimizer stepped once per batch when ``is_train`` is true.
        device: Device the batch tensors are moved to.
        is_train: If true, back-propagate and update the parameters;
            otherwise only evaluate, with autograd disabled.

    Returns:
        The sum of the batch losses divided by the number of batches, or
        ``0.0`` when ``loader`` yields no batches.
    """
    loss_fn = torch.nn.L1Loss()
    running_loss = 0.
    n_batches = 0

    # Scope the grad mode to this call. Calling set_grad_enabled as a plain
    # function would leave autograd globally disabled after an evaluation pass.
    with torch.set_grad_enabled(is_train):
        for data in loader:
            x = data[0].to(device)
            y = data[1].to(device)

            if is_train:
                optimizer.zero_grad()

            # L1 is symmetric in its arguments; pass (prediction, target) in
            # the order the loss API documents.
            loss = loss_fn(mlp(x), y)

            if is_train:
                loss.backward()
                optimizer.step()

            running_loss += loss.item()
            n_batches += 1

    # Divide by the true batch count (the count used to start at 1, which
    # under-reported the mean by a factor n / (n + 1)).
    return running_loss / n_batches if n_batches else 0.
95
+
96
def one_epoch_mlp_lbfgs(mlp, epoch, loader, optimizer, device, is_train=False):
    """Run one pass over ``loader`` with a closure-based optimizer (e.g. L-BFGS).

    Args:
        mlp: Model called as ``mlp(x)``.
        epoch: Current epoch index. Unused; kept so existing callers keep working.
        loader: Iterable of batches where ``batch[0]`` is the input and
            ``batch[1]`` the target.
        optimizer: Optimizer whose ``step`` accepts a closure that re-evaluates
            the loss and its gradients.
        device: Device the batch tensors are moved to.
        is_train: If true, take one ``optimizer.step(closure)`` per batch.

    Returns:
        The mean per-batch L1 loss, measured after the optimizer step when
        training, or ``0.0`` when ``loader`` yields no batches.
    """
    loss_fn = torch.nn.L1Loss()
    running_loss = 0.
    n_batches = 0

    # Scope the grad mode to this call so an evaluation pass does not leave
    # autograd globally disabled for the caller.
    with torch.set_grad_enabled(is_train):
        for data in loader:
            x = data[0].to(device)
            y = data[1].to(device)

            if is_train:
                def closure():
                    # The optimizer may call this several times per step
                    optimizer.zero_grad()
                    loss = loss_fn(mlp(x), y)
                    loss.backward()
                    return loss

                optimizer.step(closure)

            # Loss reported for this batch; no graph is needed for it
            with torch.no_grad():
                loss = loss_fn(mlp(x), y)

            running_loss += loss.item()
            n_batches += 1

    # Divide by the true batch count (the count used to start at 1, which
    # under-reported the mean by a factor n / (n + 1)).
    return running_loss / n_batches if n_batches else 0.
128
+
129
def exists(val):
    """Tell whether ``val`` is anything other than ``None``.

    Falsy values such as ``0``, ``""`` or ``[]`` still count as existing.
    """
    if val is None:
        return False
    return True
131
+
132
def cast_tuple(val, repeat = 1):
    """Return ``val`` as a tuple.

    A value that already is a tuple is returned unchanged (``repeat`` is then
    ignored); anything else is wrapped into a tuple holding ``repeat`` copies
    of it.
    """
    if isinstance(val, tuple):
        return val
    return (val,) * repeat
134
+
135
+ # sin activation
136
class Sine(nn.Module):
    """Sine activation with a frequency factor: ``sin(w0 * x)``."""

    def __init__(self, w0 = 1.):
        super().__init__()
        # Factor applied to the input before taking the sine
        self.w0 = w0

    def forward(self, x):
        scaled = x * self.w0
        return scaled.sin()
142
+
143
+ # siren layer
144
class Siren(nn.Module):
    """Single SIREN layer: linear map, activation, then dropout.

    The weight (and the bias, when present) is drawn uniformly from
    ``[-b, b]`` where ``b = 1 / dim_in`` for the first layer of a network and
    ``b = sqrt(c / dim_in) / w0`` otherwise.

    Args:
        dim_in: Number of input features.
        dim_out: Number of output features.
        w0: Frequency factor; used for the initialisation bound and for the
            default ``Sine`` activation.
        c: Constant inside the square root of the initialisation bound.
        is_first: Selects the first-layer initialisation bound.
        use_bias: Whether the layer has a bias term.
        activation: Module applied after the linear map; ``Sine(w0)`` if None.
        dropout: Dropout probability applied after the activation.
    """

    def __init__(
        self,
        dim_in,
        dim_out,
        w0 = 1.,
        c = 6.,
        is_first = False,
        use_bias = True,
        activation = None,
        dropout = 0.
    ):
        super().__init__()
        self.dim_in = dim_in
        self.is_first = is_first

        # Initialise plain tensors first, then wrap them as parameters
        weight = torch.zeros(dim_out, dim_in)
        bias = None
        if use_bias:
            bias = torch.zeros(dim_out)
        self.init_(weight, bias, c = c, w0 = w0)

        self.weight = nn.Parameter(weight)
        self.bias = None if bias is None else nn.Parameter(bias)
        if activation is None:
            activation = Sine(w0)
        self.activation = activation
        self.dropout = nn.Dropout(dropout)

    def init_(self, weight, bias, c, w0):
        """Fill ``weight`` (and ``bias`` unless it is None) in place."""
        fan_in = self.dim_in

        if self.is_first:
            bound = 1 / fan_in
        else:
            bound = math.sqrt(c / fan_in) / w0

        weight.uniform_(-bound, bound)
        if bias is not None:
            bias.uniform_(-bound, bound)

    def forward(self, x):
        pre_activation = F.linear(x, self.weight, self.bias)
        return self.dropout(self.activation(pre_activation))
183
+
184
class SirenMLP(nn.Module):
    """Stack of ``Siren`` layers with averaged skip connections.

    ``num_layers`` hidden ``Siren`` layers of width ``n_hidden`` are followed
    by one output ``Siren`` layer of width ``n_out``. In ``forward`` the output
    of each hidden layer is successively averaged (factor 0.5) with the stored
    outputs of all earlier hidden layers, and dropout is applied before the
    next layer.

    Args:
        n_in: Number of input features.
        n_out: Number of output features.
        n_hidden: Width of every hidden layer.
        device: Stored on the instance; the class itself does not move any
            tensor to it.
        num_layers: Number of hidden layers.
        w0: Frequency factor of all layers except the first.
        w0_initial: Frequency factor of the first layer.
        use_bias: Whether the layers have bias terms.
        final_activation: Activation of the output layer; ``nn.Identity()``
            if None.
        dropout: Dropout probability, used inside the hidden layers and again
            between layers.
        context_params: Accepted but not used.
    """

    def __init__(self,
                 n_in,
                 n_out,
                 n_hidden,
                 device,
                 num_layers,
                 w0 = 30.,
                 w0_initial = 30.,
                 use_bias = True,
                 final_activation = None,
                 dropout = 0.,
                 context_params = None
                 ):
        super().__init__()

        self.device = device
        self.num_layers = num_layers
        self.n_hidden = n_hidden
        self.dropout = nn.Dropout(dropout)

        # Hidden layers: only the first one reads n_in features and uses w0_initial
        stack = []
        for depth in range(num_layers):
            first = depth == 0
            stack.append(Siren(
                dim_in = n_in if first else n_hidden,
                dim_out = n_hidden,
                w0 = w0_initial if first else w0,
                use_bias = use_bias,
                is_first = first,
                dropout = dropout
            ))

        # Output layer
        if final_activation is None:
            final_activation = nn.Identity()
        stack.append(Siren(dim_in = n_hidden, dim_out = n_out, w0 = w0,
                           use_bias = use_bias, activation = final_activation))

        self.siren_layers = nn.ModuleList(stack)

    def forward(self, x):
        *hidden_layers, output_layer = self.siren_layers

        earlier = []
        for layer in hidden_layers:
            x = layer(x)
            # Blend with every stored earlier output, oldest first
            for prev in earlier:
                x = 0.5*(x+prev)
            earlier.append(x)
            x = self.dropout(x)

        return output_layer(x)
data/train_profiles_mlp.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import glob, os, sys
2
+
3
+ import numpy as np
4
+ import matplotlib.pyplot as plt
5
+
6
+ import torch
7
+ from mlp import *
8
+ import argparse
9
+ from datasetio import *
10
+ from torch.utils.data import TensorDataset
11
+
12
+ import copy
13
+ import pickle
14
+ import time
15
+
16
+ # In[ ]:
17
+
18
+
19
# Root directories for the input data and for the trained networks
data_dir = "/plp_scr1/agar_sh/data/TPH/"
nn_dir = "/plp_user/agar_sh/PBML/pytorch/TPH/MLP/trained_networks/"


# In[ ]:


# Set run_cell to False to skip command-line parsing and use the hard-coded
# settings in the else branch instead.
run_cell = True
if run_cell:
    parser = argparse.ArgumentParser(description='Train mlp')
    # The GPU index stays optional: it is only used when CUDA is available
    parser.add_argument("-gpu", "--gpu_number", type=int, help="specify gpu number")
    # The remaining options are needed to name the run directory and to build
    # the network; requiring them gives a clear usage error instead of a
    # TypeError further down the script.
    parser.add_argument("-a", "--act_fn", type=str, required=True, help ="activation function")
    parser.add_argument("-l", "--num_layers", type=int, required=True, help ="number of hidden layers")
    parser.add_argument("-f", "--f_h", type=int, required=True, help ="filters")

    args = parser.parse_args()

    gpu_number = args.gpu_number
    act_fn = args.act_fn
    num_layers = args.num_layers
    # One-element list; it is repeated num_layers times when the model is built
    f_h = [args.f_h]

else:
    f_h = [128]
    num_layers = 3
    act_fn = "selu"
    gpu_number = 6
46
+
47
+
48
+ # In[ ]:
49
+
50
+
51
# Name of this run's output directory, derived from the hyperparameters
f_nn = "mlp_profile_pointwise_" + str(f_h) + "_" + str(num_layers) + "_" + act_fn

# makedirs(exist_ok=True) also creates missing parent directories and does not
# fail if the directory appears between a check and the creation.
os.makedirs(nn_dir + f_nn, exist_ok=True)

device = torch.device("cuda:" + str(gpu_number)) if torch.cuda.is_available() else torch.device("cpu")

# Training hyperparameters
epoch = 0
start_lr = 1e-3
milestones = [10, 20, 30, 40, 50, 60]
epochs = 70
best_vloss = 1e+16
batch_size = 32

# Network input / output widths
f_i = 4
f_o = 1

# NOTE(review): act_fn only ends up in the directory name; it is not passed to
# MLP, so the model is built with MLP's default activation - confirm intent.
mlp = MLP(f_i=f_i, f_o=f_o, f=f_h*num_layers, insert_in=[num_layers-1]).double().to(device)
# NOTE(review): torch.compile returns the compiled module and the result is
# discarded here, so the training below runs the uncompiled model. Assigning
# the result would change the key names in the saved state_dict, so it is left
# as is - confirm intent.
torch.compile(mlp)
print(mlp)

# From here on nn_dir points at this run's directory
nn_dir = nn_dir + f_nn + "/"
with open(nn_dir + "mlp.txt", 'w') as writer:
    writer.write('Epoch, train loss, val loss, learning rate \n')
75
+
76
+
77
+ # In[ ]:
78
+
79
+
80
# Per-split containers, keyed by split name ("train" / "cv")
dataset, loader, batches = {}, {}, {}

pre = "/plp_user/agar_sh/PBML/pytorch/TPH/MLP/profiles/"
with open(pre + 'x_pointwise.pkl', 'rb') as pkl_file:
    x_pointwise = pickle.load(pkl_file)
with open(pre + 'y_pointwise.pkl', 'rb') as pkl_file:
    y_pointwise = pickle.load(pkl_file)

for an in ["train", "cv"]:

    # Disabled experiment, kept for reference: append repeated copies of the
    # rows whose last input column exceeds 0.8.
    #
    # inds = x_pointwise[an][:,-1]>0.8
    # x_ = x_pointwise[an][inds,:]
    # y_ = y_pointwise[an][inds,:]
    #
    # repeats = int(np.ceil(x_pointwise[an].shape[0]/(batch_size*inds.shape[0])))
    #
    # x_pointwise[an] = np.concatenate((x_pointwise[an],
    #                                   np.repeat(x_, repeats, axis=0)
    #                                   ), axis=0)
    # y_pointwise[an] = np.concatenate((y_pointwise[an],
    #                                   np.repeat(y_, repeats, axis=0)
    #                                   ), axis=0)

    split_x = torch.tensor(x_pointwise[an], dtype=torch.float64)
    split_y = torch.tensor(y_pointwise[an], dtype=torch.float64)
    dataset[an] = TensorDataset(split_x, split_y)

    n_samples = len(dataset[an])
    # Number of full batches (informational only; the loader keeps the last partial batch)
    batches[an] = int(n_samples/batch_size)
    loader[an] = DataLoader(dataset[an], batch_size=batch_size, shuffle=True)
    print(an, n_samples, batches[an])
111
+
112
+
113
+ # In[ ]:
114
+
115
+
116
# Adam over all model parameters, given as a single parameter group with its
# own learning rate and weight decay
optimizer = torch.optim.Adam([dict(params=mlp.parameters(), lr=start_lr, weight_decay=5e-4)])

# Multiply the learning rate by 0.5 at every milestone epoch
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestones, gamma=0.5)
124
+
125
+
126
+ # In[ ]:
127
+
128
+
129
# One training pass and one validation pass per epoch; the weights with the
# lowest validation loss so far are kept on disk.
for epoch in range(epochs):
    t0 = time.time()

    mlp.train()
    avg_loss = one_epoch_mlp(mlp, epoch, loader["train"], optimizer, device, is_train=True)

    mlp.eval()
    avg_vloss = one_epoch_mlp(mlp, epoch, loader["cv"], optimizer, device, is_train=False)

    print("-------------------------------------------")
    print(epoch, "train: ", avg_loss, get_lr(optimizer))
    print(epoch, "cv: ", avg_vloss)
    print(f"took: {time.time() - t0}")
    print("-------------------------------------------")

    if avg_vloss < best_vloss:
        best_vloss = avg_vloss
        torch.save(mlp.state_dict(), nn_dir + "mlp.pt")

    # Append one CSV row per epoch: epoch, train loss, val loss, learning rate
    log_fields = (epoch, avg_loss, avg_vloss, get_lr(optimizer))
    with open(nn_dir + "mlp.txt", "a") as writer:
        writer.write(",".join(str(field) for field in log_fields) + "\n")

    # Advance the learning-rate schedule once per epoch
    scheduler.step()
152
+
data/x_pointwise.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:277519100cb00cb7e059fd620358ac67ffe85a8cbb1daf0565f028ed68a76000
3
+ size 1138964
data/y_pointwise.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:50b39995e959275be47e658fd483bc169137705bb3ca8859272150767f0ca867
3
+ size 284930