jwalanthi committed
Commit 99ad741
1 Parent(s): 35cf171

actual fix hopefully

Files changed (2)
  1. app.py +1 -4
  2. model.py +398 -0
app.py CHANGED
@@ -1,14 +1,11 @@
 import gradio as gr
-# import torch
-# import lightning
+import torch
 from minicons import cwe
 import pandas as pd
 import numpy as np
 
 from model import FeatureNormPredictor
 
-import sys
-sys.path.insert(0, '/home/jjr4354/semantic-features')
 
 def predict (word, sentence, lm_name, layer, norm):
     if word not in sentence: return "invalid input: word not in sentence"
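The rest of predict is truncated in this view. Purely as an illustrative sketch (not the committed implementation), the FeatureNormPredictor class defined in the model.py added below could be loaded from a Lightning checkpoint and applied to a contextual embedding along these lines; the checkpoint path and the embedding size are placeholders:

    import torch
    from model import FeatureNormPredictor

    # Hypothetical checkpoint produced by model.py; load_from_checkpoint is the classmethod
    # inherited from lightning.LightningModule and relies on the save_hyperparameters()
    # call in FeatureNormPredictor.__init__.
    predictor = FeatureNormPredictor.load_from_checkpoint("checkpoints/bert_layer8_buchanan.ckpt")
    predictor.eval()

    # Stand-in for the contextual embedding of the target word
    # (e.g., extracted elsewhere with minicons cwe).
    embedding = torch.randn(768)

    with torch.no_grad():
        # One predicted value per feature listed in the .txt label file saved at training time.
        norm_vector = predictor(embedding)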
model.py ADDED
@@ -0,0 +1,398 @@
+import torch
+import lightning
+from torch.utils.data import Dataset
+from typing import Any, Dict
+import argparse
+from pydantic import BaseModel
+from get_dataset_dictionaries import get_dict_pair
+import os
+import shutil
+
+import optuna
+from optuna.integration import PyTorchLightningPruningCallback
+from functools import partial
+
+class FFNModule(torch.nn.Module):
+    """
+    A PyTorch module that regresses from a hidden state representation of a word
+    to its continuous linguistic feature norm vector.
+
+    It is a FFN with the general structure of:
+    input -> (linear -> nonlinearity -> dropout) x (num_layers - 1) -> linear -> output
+    """
+    def __init__(
+        self,
+        input_size: int,
+        output_size: int,
+        hidden_size: int,
+        num_layers: int,
+        dropout: float,
+    ):
+        super(FFNModule, self).__init__()
+
+        layers = []
+        for _ in range(num_layers - 1):
+            layers.append(torch.nn.Linear(input_size, hidden_size))
+            layers.append(torch.nn.ReLU())
+            layers.append(torch.nn.Dropout(dropout))
+            # changes input size to hidden size after the first layer
+            input_size = hidden_size
+        layers.append(torch.nn.Linear(hidden_size, output_size))
+        self.network = torch.nn.Sequential(*layers)
+
+    def forward(self, x):
+        return self.network(x)
+
+class FFNParams(BaseModel):
+    input_size: int
+    output_size: int
+    hidden_size: int
+    num_layers: int
+    dropout: float
+
+class TrainingParams(BaseModel):
+    num_epochs: int
+    batch_size: int
+    learning_rate: float
+    weight_decay: float
+
+class FeatureNormPredictor(lightning.LightningModule):
+    def __init__(self, ffn_params: FFNParams, training_params: TrainingParams):
+        super().__init__()
+        self.save_hyperparameters()
+        self.ffn_params = ffn_params
+        self.training_params = training_params
+        self.model = FFNModule(**ffn_params.model_dump())
+        self.loss_function = torch.nn.MSELoss()
+        self.training_params = training_params
+
+    def training_step(self, batch, batch_idx):
+        x, y = batch
+        outputs = self.model(x)
+        loss = self.loss_function(outputs, y)
+        self.log("train_loss", loss)
+        return loss
+
+    def validation_step(self, batch, batch_idx):
+        x, y = batch
+        outputs = self.model(x)
+        loss = self.loss_function(outputs, y)
+        self.log("val_loss", loss, on_epoch=True, prog_bar=True)
+        return loss
+
+    def test_step(self, batch, batch_idx):
+        return self.model(batch)
+
+    def predict(self, batch):
+        return self.model(batch)
+
+    def __call__(self, input):
+        return self.model(input)
+
+    def configure_optimizers(self):
+        optimizer = torch.optim.Adam(
+            self.parameters(),
+            lr=self.training_params.learning_rate,
+            weight_decay=self.training_params.weight_decay,
+        )
+        return optimizer
+
+    def save_model(self, path: str):
+        torch.save(self.model.state_dict(), path)
+
+    def load_model(self, path: str):
+        self.model.load_state_dict(torch.load(path))
+
+
+class HiddenStateFeatureNormDataset(Dataset):
+    def __init__(
+        self,
+        input_embeddings: Dict[str, torch.Tensor],
+        feature_norms: Dict[str, torch.Tensor],
+    ):
+
+        # Invariant: input_embeddings and feature_norms have exactly the same keys.
+        # This should be guaranteed by the train/test split and upstream data processing.
+        assert(input_embeddings.keys() == feature_norms.keys())
+
+        self.words = list(input_embeddings.keys())
+        self.input_embeddings = torch.stack([
+            input_embeddings[word] for word in self.words
+        ])
+        self.feature_norms = torch.stack([
+            feature_norms[word] for word in self.words
+        ])
+
+    def __len__(self):
+        return len(self.words)
+
+    def __getitem__(self, idx):
+        return self.input_embeddings[idx], self.feature_norms[idx]
+
+# this is used when not optimizing
+def train(args: Dict[str, Any]):
+
+    # input_embeddings = torch.load(args.input_embeddings)
+    # feature_norms = torch.load(args.feature_norms)
+    # words = list(input_embeddings.keys())
+
+    input_embeddings, feature_norms, norm_list = get_dict_pair(
+        args.norm,
+        args.embedding_dir,
+        args.lm_layer,
+        translated=False if args.raw_buchanan else True,
+        normalized=True if args.normal_buchanan else False
+    )
+    norms_file = open(args.save_dir+"/"+args.save_model_name+'.txt', 'w')
+    norms_file.write("\n".join(norm_list))
+    norms_file.close()
+
+    words = list(input_embeddings.keys())
+
+    model = FeatureNormPredictor(
+        FFNParams(
+            input_size=input_embeddings[words[0]].shape[0],
+            output_size=feature_norms[words[0]].shape[0],
+            hidden_size=args.hidden_size,
+            num_layers=args.num_layers,
+            dropout=args.dropout,
+        ),
+        TrainingParams(
+            num_epochs=args.num_epochs,
+            batch_size=args.batch_size,
+            learning_rate=args.learning_rate,
+            weight_decay=args.weight_decay,
+        ),
+    )
+
+    # train/val split
+    train_size = int(len(words) * 0.8)
+    valid_size = len(words) - train_size
+    train_words, validation_words = torch.utils.data.random_split(words, [train_size, valid_size])
+
+    # TODO: Methodology Decision: should we be normalizing the hidden states/feature norms?
+    train_embeddings = {word: input_embeddings[word] for word in train_words}
+    train_feature_norms = {word: feature_norms[word] for word in train_words}
+    validation_embeddings = {word: input_embeddings[word] for word in validation_words}
+    validation_feature_norms = {word: feature_norms[word] for word in validation_words}
+
+    train_dataset = HiddenStateFeatureNormDataset(train_embeddings, train_feature_norms)
+    train_dataloader = torch.utils.data.DataLoader(
+        train_dataset,
+        batch_size=args.batch_size,
+        shuffle=True,
+    )
+    validation_dataset = HiddenStateFeatureNormDataset(validation_embeddings, validation_feature_norms)
+    validation_dataloader = torch.utils.data.DataLoader(
+        validation_dataset,
+        batch_size=args.batch_size,
+        shuffle=True,
+    )
+
+    callbacks = [
+        lightning.pytorch.callbacks.ModelCheckpoint(
+            save_last=True,
+            dirpath=args.save_dir,
+            filename=args.save_model_name,
+        ),
+    ]
+    if args.early_stopping is not None:
+        callbacks.append(lightning.pytorch.callbacks.EarlyStopping(
+            monitor="val_loss",
+            patience=args.early_stopping,
+            mode='min',
+            min_delta=0.0
+        ))
+
+    # TODO Design Decision - other trainer args? Is device necessary?
+    # cpu is fine for the scale of this model - only a few layers and a few hundred words
+    trainer = lightning.Trainer(
+        max_epochs=args.num_epochs,
+        callbacks=callbacks,
+        accelerator="cpu",
+        log_every_n_steps=7
+    )
+
+    trainer.fit(model, train_dataloader, validation_dataloader)
+
+    trainer.validate(model, validation_dataloader)
+
+    return model
+
+# this is used when optimizing
+def objective(trial: optuna.trial.Trial, args: Dict[str, Any]) -> float:
+    # optimizing hidden size, batch size, and learning rate
+    input_embeddings, feature_norms, norm_list = get_dict_pair(
+        args.norm,
+        args.embedding_dir,
+        args.lm_layer,
+        translated=False if args.raw_buchanan else True,
+        normalized=True if args.normal_buchanan else False
+    )
+    norms_file = open(args.save_dir+"/"+args.save_model_name+'.txt', 'w')
+    norms_file.write("\n".join(norm_list))
+    norms_file.close()
+
+    words = list(input_embeddings.keys())
+    input_size = input_embeddings[words[0]].shape[0]
+    output_size = feature_norms[words[0]].shape[0]
+    min_size = min(output_size, input_size)
+    max_size = min(output_size, 2*input_size) if min_size == input_size else min(2*output_size, input_size)
+    hidden_size = trial.suggest_int("hidden_size", min_size, max_size, log=True)
+    batch_size = trial.suggest_int("batch_size", 16, 128, log=True)
+    learning_rate = trial.suggest_float("learning_rate", 1e-6, 1, log=True)
+
+    model = FeatureNormPredictor(
+        FFNParams(
+            input_size=input_size,
+            output_size=output_size,
+            hidden_size=hidden_size,
+            num_layers=args.num_layers,
+            dropout=args.dropout,
+        ),
+        TrainingParams(
+            num_epochs=args.num_epochs,
+            batch_size=batch_size,
+            learning_rate=learning_rate,
+            weight_decay=args.weight_decay,
+        ),
+    )
+
+    # train/val split
+    train_size = int(len(words) * 0.8)
+    valid_size = len(words) - train_size
+    train_words, validation_words = torch.utils.data.random_split(words, [train_size, valid_size])
+
+    train_embeddings = {word: input_embeddings[word] for word in train_words}
+    train_feature_norms = {word: feature_norms[word] for word in train_words}
+    validation_embeddings = {word: input_embeddings[word] for word in validation_words}
+    validation_feature_norms = {word: feature_norms[word] for word in validation_words}
+
+    train_dataset = HiddenStateFeatureNormDataset(train_embeddings, train_feature_norms)
+    train_dataloader = torch.utils.data.DataLoader(
+        train_dataset,
+        batch_size=args.batch_size,
+        shuffle=True,
+    )
+    validation_dataset = HiddenStateFeatureNormDataset(validation_embeddings, validation_feature_norms)
+    validation_dataloader = torch.utils.data.DataLoader(
+        validation_dataset,
+        batch_size=args.batch_size,
+        shuffle=True,
+    )
+
+    callbacks = [
+        # all trial models will be saved in a temporary directory
+        lightning.pytorch.callbacks.ModelCheckpoint(
+            save_last=True,
+            dirpath=os.path.join(args.save_dir, 'optuna_trials'),
+            filename="{}".format(trial.number)
+        ),
+    ]
+
+    if args.prune is not None:
+        callbacks.append(PyTorchLightningPruningCallback(
+            trial,
+            monitor='val_loss'
+        ))
+
+    if args.early_stopping is not None:
+        callbacks.append(lightning.pytorch.callbacks.EarlyStopping(
+            monitor="val_loss",
+            patience=args.early_stopping,
+            mode='min',
+            min_delta=0.0
+        ))
+    # note that if optimizing is chosen, vanilla early stopping is automatically not implemented
+    # TODO Design Decision - other trainer args? Is device necessary?
+    # cpu is fine for the scale of this model - only a few layers and a few hundred words
+    trainer = lightning.Trainer(
+        max_epochs=args.num_epochs,
+        callbacks=callbacks,
+        accelerator="cpu",
+        log_every_n_steps=7,
+        # enable_checkpointing=False
+    )
+
+    trainer.fit(model, train_dataloader, validation_dataloader)
+
+    trainer.validate(model, validation_dataloader)
+
+    return trainer.callback_metrics['val_loss'].item()
+
+if __name__ == "__main__":
+    # parse args
+    parser = argparse.ArgumentParser()
+    # TODO: Design Decision: Should we input paths to the pre-extracted layers, or the model/layer we want to generate them from?
+    # required inputs
+    parser.add_argument("--norm", type=str, required=True, help="feature norm set to use")
+    parser.add_argument("--embedding_dir", type=str, required=True, help="directory containing embeddings")
+    parser.add_argument("--lm_layer", type=int, required=True, help="layer of embeddings to use")
+    # if user selects optimize, hidden_size, batch_size and learning_rate will be optimized.
+    parser.add_argument("--optimize", action="store_true", help="optimize hyperparameters for training")
+    parser.add_argument("--prune", action="store_true", help="prune unpromising trials when optimizing")
+    # optional hyperparameter specs
+    parser.add_argument("--num_layers", type=int, default=2, help="number of layers in FFN")
+    parser.add_argument("--hidden_size", type=int, default=100, help="hidden size of FFN")
+    parser.add_argument("--dropout", type=float, default=0.1, help="dropout rate of FFN")
+    # set this to at least 100 if doing early stopping
+    parser.add_argument("--num_epochs", type=int, default=10, help="number of epochs to train for")
+    parser.add_argument("--batch_size", type=int, default=32, help="batch size for training")
+    parser.add_argument("--learning_rate", type=float, default=0.001, help="learning rate for training")
+    parser.add_argument("--weight_decay", type=float, default=0.0, help="weight decay for training")
+    parser.add_argument("--early_stopping", type=int, default=None, help="number of epochs to wait for early stopping")
+    # optional dataset specs, for buchanan really
+    parser.add_argument('--raw_buchanan', action="store_true", help="do not use translated values for buchanan")
+    parser.add_argument('--normal_buchanan', action="store_true", help="use normalized features for buchanan")
+    # required for output
+    parser.add_argument("--save_dir", type=str, required=True, help="directory to save model to")
+    parser.add_argument("--save_model_name", type=str, required=True, help="name of model to save")
+
+    args = parser.parse_args()
+
+    if args.early_stopping is not None:
+        args.num_epochs = max(50, args.num_epochs)
+
+    torch.manual_seed(10)
+
+    if args.optimize:
+        # call optimizer code here
+        print("optimizing for learning rate, batch size, and hidden size")
+        pruner = optuna.pruners.MedianPruner() if args.prune else optuna.pruners.NopPruner()
+        sampler = optuna.samplers.TPESampler(seed=10)
+
+        study = optuna.create_study(direction='minimize', pruner=pruner, sampler=sampler)
+        study.optimize(partial(objective, args=args), n_trials=100, timeout=600)
+
+        other_params = {
+            "num_layers": args.num_layers,
+            "num_epochs": args.num_epochs,
+            "dropout": args.dropout,
+            "weight_decay": args.weight_decay,
+        }
+
+        print("Number of finished trials: {}".format(len(study.trials)))
+
+        trial = study.best_trial
+        print("Best trial: "+str(trial.number))
+
+        print("  Validation Loss: {}".format(trial.value))
+
+        print("  Optimized Params: ")
+        for key, value in trial.params.items():
+            print("    {}: {}".format(key, value))
+
+        print("  User Defined Params: ")
+        for key, value in other_params.items():
+            print("    {}: {}".format(key, value))
+
+        print('saving best trial')
+        for filename in os.listdir(os.path.join(args.save_dir, 'optuna_trials')):
+            if filename == "{}.ckpt".format(trial.number):
+                shutil.move(os.path.join(args.save_dir, 'optuna_trials', filename), os.path.join(args.save_dir, "{}.ckpt".format(args.save_model_name)))
+        shutil.rmtree(os.path.join(args.save_dir, 'optuna_trials'))
+
+    else:
+        model = train(args)
+
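For orientation, a small sketch of the network that FFNModule builds, matching the structure described in its docstring (input -> (linear -> nonlinearity -> dropout) x (num_layers - 1) -> linear -> output). The sizes here are hypothetical, and the import assumes the repo's dependencies (lightning, optuna, pydantic, get_dataset_dictionaries) are installed so that model.py imports cleanly:

    import torch
    from model import FFNModule

    # Hypothetical sizes: a 768-d hidden state regressed onto a 2000-d feature norm vector.
    ffn = FFNModule(input_size=768, output_size=2000, hidden_size=300, num_layers=2, dropout=0.1)
    print(ffn.network)
    # Sequential(
    #   (0): Linear(in_features=768, out_features=300, bias=True)
    #   (1): ReLU()
    #   (2): Dropout(p=0.1, inplace=False)
    #   (3): Linear(in_features=300, out_features=2000, bias=True)
    # )

    out = ffn(torch.randn(4, 768))  # a batch of four embeddings
    print(out.shape)                # torch.Size([4, 2000])

When model.py is run with --optimize, hidden_size, batch_size, and learning_rate are instead suggested per Optuna trial, and the best trial's checkpoint is moved into save_dir under save_model_name.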