dkoshman commited on
Commit
4f4785c
1 Parent(s): c2ef1c6

added callback on hook, decoder, image logger, tried tuning

Browse files
Files changed (4) hide show
  1. data_preprocessing.py +1 -1
  2. model.py +58 -4
  3. train.py +49 -15
  4. utils.py +21 -0
data_preprocessing.py CHANGED
@@ -213,7 +213,7 @@ class LatexImageDataModule(pl.LightningDataModule):
213
  pin_memory=PIN_MEMORY, num_workers=NUM_WORKERS, persistent_workers=PERSISTENT_WORKERS)
214
 
215
  def val_dataloader(self):
216
- return DataLoader(self.val_dataset, batch_size=self.batch_size, collate_fn=self.collate_fn,
217
  pin_memory=PIN_MEMORY, num_workers=NUM_WORKERS, persistent_workers=PERSISTENT_WORKERS)
218
 
219
  def test_dataloader(self):
 
213
  pin_memory=PIN_MEMORY, num_workers=NUM_WORKERS, persistent_workers=PERSISTENT_WORKERS)
214
 
215
  def val_dataloader(self):
216
+ return DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=True, collate_fn=self.collate_fn,
217
  pin_memory=PIN_MEMORY, num_workers=NUM_WORKERS, persistent_workers=PERSISTENT_WORKERS)
218
 
219
  def test_dataloader(self):
model.py CHANGED
@@ -111,7 +111,9 @@ class Transformer(pl.LightningModule):
111
  pad_idx: int,
112
  dim_feedforward: int = 512,
113
  dropout: float = .1,
114
- learning_rate=1e-4):
 
 
115
  super().__init__()
116
 
117
  self.transformer = nn.Transformer(d_model=emb_size,
@@ -130,8 +132,11 @@ class Transformer(pl.LightningModule):
130
  self.tgt_tok_emb = TexEmbedding(emb_size, tgt_vocab_size, dropout=dropout)
131
  self.loss_fn = torch.nn.CrossEntropyLoss(ignore_index=pad_idx)
132
  self.learning_rate = learning_rate
 
 
133
 
134
- def forward(self, src, tgt, src_mask, tgt_mask, memory_mask, src_padding_mask, tgt_padding_mask):
 
135
  src = self.src_tok_emb(src)
136
  tgt = self.tgt_tok_emb(tgt)
137
 
@@ -176,5 +181,54 @@ class Transformer(pl.LightningModule):
176
  return loss
177
 
178
  def configure_optimizers(self):
179
- # TODO write scheduler
180
- return torch.optim.Adam(self.parameters(), lr=self.learning_rate)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
111
  pad_idx: int,
112
  dim_feedforward: int = 512,
113
  dropout: float = .1,
114
+ learning_rate=1e-3,
115
+ tex_tokenizer=None
116
+ ):
117
  super().__init__()
118
 
119
  self.transformer = nn.Transformer(d_model=emb_size,
 
132
  self.tgt_tok_emb = TexEmbedding(emb_size, tgt_vocab_size, dropout=dropout)
133
  self.loss_fn = torch.nn.CrossEntropyLoss(ignore_index=pad_idx)
134
  self.learning_rate = learning_rate
135
+ self.save_hyperparameters()
136
+ self.tex_tokenizer = tex_tokenizer
137
 
138
+ def forward(self, src, tgt, src_mask=None, tgt_mask=None, memory_mask=None, src_padding_mask=None,
139
+ tgt_padding_mask=None):
140
  src = self.src_tok_emb(src)
141
  tgt = self.tgt_tok_emb(tgt)
142
 
 
181
  return loss
182
 
183
  def configure_optimizers(self):
184
+ optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
185
+ scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=50, T_mult=1)
186
+ return [optimizer], [scheduler]
187
+
188
+
189
+ class _TransformerTuner(Transformer):
190
+ """
191
+ When using trainer.tune, batches from dataloader get passed directly to forward,
192
+ so this subclass takes care of that
193
+ """
194
+
195
+ def forward(self, batch, batch_idx):
196
+ src = batch['images']
197
+ tgt = batch['tex_ids']
198
+ tgt_input = tgt[:, :-1]
199
+ tgt_output = tgt[:, 1:]
200
+ src_mask = None
201
+ tgt_mask = self.transformer.generate_square_subsequent_mask(tgt_input.shape[1]).to(self.device,
202
+ torch.ByteTensor.dtype)
203
+ memory_mask = None
204
+ src_padding_mask = None
205
+ tgt_padding_mask = batch['tex_attention_masks'][:, :-1]
206
+ tgt_padding_mask = tgt_padding_mask.masked_fill(
207
+ tgt_padding_mask == 0, float('-inf')
208
+ ).masked_fill(
209
+ tgt_padding_mask == 1, 0
210
+ )
211
+
212
+ src = self.src_tok_emb(src)
213
+ tgt_input = self.tgt_tok_emb(tgt_input)
214
+ outs = self.transformer(src, tgt_input, src_mask, tgt_mask, memory_mask, src_padding_mask, tgt_padding_mask)
215
+ outs = self.generator(outs)
216
+
217
+ loss = self.loss_fn(einops.rearrange(outs, 'b n prob -> b prob n'), tgt_output.long())
218
+ return loss
219
+
220
+ def validation_step(self, batch, batch_idx):
221
+ return self(batch, batch_idx)
222
+
223
+
224
+ @torch.inference_mode()
225
+ def decode(transformer, tex_tokenizer, image):
226
+ tex_ids = [tex_tokenizer.token_to_id("[CLS]")]
227
+ while tex_ids[-1] != tex_tokenizer.token_to_id("[SEP]") and len(tex_ids) < 30:
228
+ src = einops.rearrange(image, "c h w -> () c h w")
229
+ tgt = torch.tensor([tex_ids], device=transformer.device, dtype=torch.float32)
230
+ outs = transformer(src, tgt)
231
+ next_id = outs[:, -1].argmax(dim=1).item()
232
+ tex_ids.append(next_id)
233
+ tex = tex_tokenizer.decode(tex_ids, skip_special_tokens=True)
234
+ return tex
train.py CHANGED
@@ -1,26 +1,32 @@
1
  from data_generator import generate_data
2
  from data_preprocessing import LatexImageDataModule, IMAGE_WIDTH, IMAGE_HEIGHT
3
- from model import Transformer
 
4
 
5
  import argparse
6
- from pytorch_lightning.loggers import WandbLogger
7
  from pytorch_lightning import Trainer, seed_everything
8
  import torch
 
9
 
10
- DATASET_PATH = 'resources/dataset.pt'
 
 
 
 
11
 
12
 
13
  def parse_args():
14
  parser = argparse.ArgumentParser()
15
  parser.add_argument(
16
- "-m", "-max-epochs", help="limit the number of training epochs", type=int, dest='max_epochs'
17
  )
18
  parser.add_argument(
19
  "-n", "-new-dataset", help="clear old dataset and generate provided number of new examples", type=int,
20
  dest="new_dataset"
21
  )
22
  parser.add_argument(
23
- "-g", "-gpus", help=f"number of gpus to train on in range 0..{torch.cuda.device_count()}",
24
  type=int, dest="gpus", choices=list(range(torch.cuda.device_count())),
25
  )
26
  parser.add_argument(
@@ -31,6 +37,10 @@ def parse_args():
31
  "-d", "-deterministic", help="whether to seed all rngs for reproducibility, default False", default=False,
32
  action="store_true", dest="deterministic"
33
  )
 
 
 
 
34
 
35
  args = parser.parse_args()
36
  return args
@@ -52,17 +62,21 @@ def main():
52
  # TODO: log images, accuracy?, update python, write own transformer, add checkpoints, lr scheduler,
53
  # determine when trainer doesnt hang(when single gpu,ddp, num_workers=0)
54
 
55
- logger = WandbLogger(f"img2tex", version='0') if args.log else False
 
 
 
 
 
56
 
57
  trainer = Trainer(max_epochs=args.max_epochs,
58
- accelerator='gpu' if args.gpus else 'cpu',
59
  gpus=args.gpus,
60
  logger=logger,
61
- strategy='ddp',
62
- auto_scale_batch_size="power",
63
- auto_lr_find=True,
64
- auto_select_gpus=True,
65
- enable_progress_bar=True
66
  )
67
 
68
  transformer = Transformer(num_encoder_layers=3,
@@ -77,11 +91,31 @@ def main():
77
  dropout=0.1
78
  )
79
 
80
- trainer.tune(transformer, datamodule=datamodule)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
  trainer.fit(transformer, datamodule=datamodule)
82
  trainer.test(datamodule=datamodule)
83
- trainer.save_checkpoint("best_model.ckpt")
84
 
85
 
86
- if __name__ == '__main__':
87
  main()
 
1
  from data_generator import generate_data
2
  from data_preprocessing import LatexImageDataModule, IMAGE_WIDTH, IMAGE_HEIGHT
3
+ from model import Transformer, _TransformerTuner
4
+ from utils import LogImageTexCallback
5
 
6
  import argparse
7
+ from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger
8
  from pytorch_lightning import Trainer, seed_everything
9
  import torch
10
+ import wandb
11
 
12
+ DATASET_PATH = "resources/dataset.pt"
13
+ TRAINER_DIR = "resources/pl_trainer_checkpoints"
14
+ TUNER_DIR = "resources/pl_tuner_checkpoints"
15
+ TRAINER_STRATEGY = "ddp"
16
+ BEST_MODEL_CHECKPOINT = "best_model.ckpt"
17
 
18
 
19
  def parse_args():
20
  parser = argparse.ArgumentParser()
21
  parser.add_argument(
22
+ "-m", "-max-epochs", help="limit the number of training epochs", type=int, dest="max_epochs"
23
  )
24
  parser.add_argument(
25
  "-n", "-new-dataset", help="clear old dataset and generate provided number of new examples", type=int,
26
  dest="new_dataset"
27
  )
28
  parser.add_argument(
29
+ "-g", "-gpus", metavar="GPUS", help="ids of gpus to train on, if not provided then trains on cpu", nargs="+",
30
  type=int, dest="gpus", choices=list(range(torch.cuda.device_count())),
31
  )
32
  parser.add_argument(
 
37
  "-d", "-deterministic", help="whether to seed all rngs for reproducibility, default False", default=False,
38
  action="store_true", dest="deterministic"
39
  )
40
+ # parser.add_argument(
41
+ # "-t", "-tune", help="whether to tune model for batch size before training, default False", default=False,
42
+ # action="store_true", dest="tune"
43
+ # )
44
 
45
  args = parser.parse_args()
46
  return args
 
62
  # TODO: log images, accuracy?, update python, write own transformer, add checkpoints, lr scheduler,
63
  # determine when trainer doesnt hang(when single gpu,ddp, num_workers=0)
64
 
65
+ if args.log:
66
+ logger = WandbLogger(f"img2tex", log_model=True)
67
+ callbacks = [LogImageTexCallback(logger, datamodule.tex_tokenizer)]
68
+ else:
69
+ logger = None
70
+ callbacks = []
71
 
72
  trainer = Trainer(max_epochs=args.max_epochs,
73
+ accelerator="cpu" if args.gpus is None else "gpu",
74
  gpus=args.gpus,
75
  logger=logger,
76
+ strategy=TRAINER_STRATEGY,
77
+ enable_progress_bar=True,
78
+ default_root_dir=TRAINER_DIR,
79
+ callbacks=callbacks,
 
80
  )
81
 
82
  transformer = Transformer(num_encoder_layers=3,
 
91
  dropout=0.1
92
  )
93
 
94
+ # dl = datamodule.train_dataloader()
95
+ # b = next(iter(dl))
96
+ # image=b['images'][0]
97
+ # tex = decode(transformer, datamodule.tex_tokenizer, image)
98
+ # print(tex)
99
+
100
+ # if args.new_dataset:
101
+ # datamodule.batch_size = 1
102
+ # transformer_for_tuning = TransformerTuner(**transformer.hparams).cuda()
103
+ # tuner = Trainer(accelerator="gpu" if args.gpus else "cpu",
104
+ # gpus=args.gpus,
105
+ # strategy=TRAINER_STRATEGY,
106
+ # enable_progress_bar=True,
107
+ # enable_checkpointing=False,
108
+ # auto_scale_batch_size=True,
109
+ # num_sanity_val_steps=0,
110
+ # logger=False
111
+ # )
112
+ # tuner.tune(transformer_for_tuning, datamodule=datamodule)
113
+ # torch.save(datamodule, DATASET_PATH)
114
+
115
  trainer.fit(transformer, datamodule=datamodule)
116
  trainer.test(datamodule=datamodule)
117
+ trainer.save_checkpoint(BEST_MODEL_CHECKPOINT)
118
 
119
 
120
+ if __name__ == "__main__":
121
  main()
utils.py CHANGED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from pytorch_lightning.callbacks import Callback
3
+ from model import decode
4
+
5
+ from torchvision import transforms
6
+
7
+
8
+ class LogImageTexCallback(Callback):
9
+ def __init__(self, logger, tex_tokenizer):
10
+ self.logger = logger
11
+ self.tex_tokenizer = tex_tokenizer
12
+ self.tensor_to_PIL = transforms.ToPILImage()
13
+
14
+ def on_validation_batch_start(self, trainer, transformer, batch, batch_idx, dataloader_idx):
15
+ if batch_idx != 0 or dataloader_idx != 0:
16
+ return
17
+ image = batch['images'][0]
18
+ tex_predicted = decode(transformer, self.tex_tokenizer, image)
19
+ image = self.tensor_to_PIL(image)
20
+ tex_true = self.tex_tokenizer.decode(list(batch['tex_ids'][0].to('cpu', torch.int)), skip_special_tokens=True)
21
+ self.logger.log_image(key="samples", images=[image], caption=[f"True {tex_true}\n Predicted{tex_predicted}"])