dkoshman committed on
Commit: 57273ba
1 Parent(s): 4f4785c
Files changed (4):
  1. data_preprocessing.py +6 -4
  2. model.py +5 -3
  3. train.py +7 -10
  4. utils.py +1 -1
data_preprocessing.py CHANGED
@@ -15,7 +15,7 @@ import re
 TEX_VOCAB_SIZE = 300
 IMAGE_WIDTH = 1024
 IMAGE_HEIGHT = 128
-BATCH_SIZE = 8
+BATCH_SIZE = 16
 NUM_WORKERS = 4
 PERSISTENT_WORKERS = True  # whether to shut down workers at the end of epoch
 PIN_MEMORY = False  # probably causes cuda oom error if True
@@ -146,10 +146,10 @@ class ExtractEquationFromTexTransform(object):
         return equation
 
 
-def generate_tex_tokenizer(dataset, vocab_size):
+def generate_tex_tokenizer(dataloader, vocab_size):
     """Returns a tokenizer trained on texs from given dataset"""
 
-    texs = list(tqdm.tqdm((item['tex'] for item in dataset), "Training tokenizer", total=len(dataset)))
+    texs = list(tqdm.tqdm((batch['tex'] for batch in dataloader), "Training tokenizer", total=len(dataloader)))
 
     os.environ['TOKENIZERS_PARALLELISM'] = 'false'
     tokenizer = tokenizers.Tokenizer(tokenizers.models.BPE(unk_token="[UNK]"))
@@ -197,7 +197,9 @@ class LatexImageDataModule(pl.LightningDataModule):
         self.val_dataset = torch.utils.data.Subset(self.val_dataset, val_indices)
         self.test_dataset = torch.utils.data.Subset(self.test_dataset, test_indices)
 
-        self.tex_tokenizer = generate_tex_tokenizer(self.train_dataset, vocab_size=TEX_VOCAB_SIZE)
+        self.tex_tokenizer = generate_tex_tokenizer(
+            DataLoader(self.train_dataset, batch_size=32, num_workers=16),
+            vocab_size=TEX_VOCAB_SIZE)
         self.collate_fn = BatchCollator(self.tex_tokenizer)
 
     @staticmethod
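Note: `generate_tex_tokenizer` now iterates a multi-worker `DataLoader` instead of the dataset itself, so with the default collate function each `batch['tex']` is a list of strings rather than a single string. For context, here is a minimal, self-contained sketch of how a BPE tokenizer like the one built here can be trained on LaTeX strings with the `tokenizers` library. The helper name, the `Whitespace` pre-tokenizer, and the trainer settings and special-token list are illustrative assumptions; the diff only shows the `BPE(unk_token="[UNK]")` construction (and `[PAD]` is known from `train.py`).

```python
import os

import tokenizers


def train_bpe_tokenizer(texs, vocab_size=300):
    """Train a BPE tokenizer on an iterable of LaTeX strings (sketch, not the repo's exact code)."""
    os.environ['TOKENIZERS_PARALLELISM'] = 'false'
    tokenizer = tokenizers.Tokenizer(tokenizers.models.BPE(unk_token="[UNK]"))
    # Assumed setup: whitespace pre-tokenization and these special tokens are guesses,
    # only the BPE model line above is visible in the commit.
    tokenizer.pre_tokenizer = tokenizers.pre_tokenizers.Whitespace()
    trainer = tokenizers.trainers.BpeTrainer(
        vocab_size=vocab_size,
        special_tokens=["[UNK]", "[PAD]", "[BOS]", "[EOS]"],
    )
    # train_from_iterator accepts an iterator yielding str or List[str],
    # so either per-item texs or per-batch lists from a DataLoader work.
    tokenizer.train_from_iterator(texs, trainer=trainer)
    return tokenizer


# Usage: train on a couple of equations and inspect the tokenization.
tok = train_bpe_tokenizer([r"\frac{a}{b}", r"e^{i\pi} + 1 = 0"], vocab_size=300)
print(tok.encode(r"\frac{x}{y}").tokens)
```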
model.py CHANGED
@@ -111,8 +111,7 @@ class Transformer(pl.LightningModule):
             pad_idx: int,
             dim_feedforward: int = 512,
             dropout: float = .1,
-            learning_rate=1e-3,
-            tex_tokenizer=None
+            learning_rate: float = 1e-3
     ):
         super().__init__()
 
@@ -133,7 +132,6 @@ class Transformer(pl.LightningModule):
         self.loss_fn = torch.nn.CrossEntropyLoss(ignore_index=pad_idx)
         self.learning_rate = learning_rate
         self.save_hyperparameters()
-        self.tex_tokenizer = tex_tokenizer
 
     def forward(self, src, tgt, src_mask=None, tgt_mask=None, memory_mask=None, src_padding_mask=None,
                 tgt_padding_mask=None):
@@ -185,6 +183,10 @@ class Transformer(pl.LightningModule):
         scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=50, T_mult=1)
         return [optimizer], [scheduler]
 
+    # def configure_optimizers(self):
+    #     optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
+    #     return optimizer
+
 
 class _TransformerTuner(Transformer):
     """
train.py CHANGED
@@ -4,10 +4,10 @@ from model import Transformer, _TransformerTuner
 from utils import LogImageTexCallback
 
 import argparse
+from pytorch_lightning.callbacks import LearningRateMonitor
 from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger
 from pytorch_lightning import Trainer, seed_everything
 import torch
-import wandb
 
 DATASET_PATH = "resources/dataset.pt"
 TRAINER_DIR = "resources/pl_trainer_checkpoints"
@@ -58,13 +58,15 @@ def main():
         torch.save(datamodule, DATASET_PATH)
     else:
         datamodule = torch.load(DATASET_PATH)
-
     # TODO: log images, accuracy?, update python, write own transformer, add checkpoints, lr scheduler,
     # determine when trainer doesnt hang(when single gpu,ddp, num_workers=0)
 
     if args.log:
         logger = WandbLogger(f"img2tex", log_model=True)
-        callbacks = [LogImageTexCallback(logger, datamodule.tex_tokenizer)]
+        callbacks = [
+            LogImageTexCallback(logger, datamodule.tex_tokenizer),
+            LearningRateMonitor(logging_interval='step')
+        ]
     else:
         logger = None
         callbacks = []
@@ -88,15 +90,10 @@ def main():
         tgt_vocab_size=datamodule.tex_tokenizer.get_vocab_size(),
         pad_idx=datamodule.tex_tokenizer.token_to_id("[PAD]"),
         dim_feedforward=512,
-        dropout=0.1
+        dropout=0.1,
+        learning_rate=1e-3
     )
 
-    # dl = datamodule.train_dataloader()
-    # b = next(iter(dl))
-    # image=b['images'][0]
-    # tex = decode(transformer, datamodule.tex_tokenizer, image)
-    # print(tex)
-
     # if args.new_dataset:
     #     datamodule.batch_size = 1
     #     transformer_for_tuning = TransformerTuner(**transformer.hparams).cuda()
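For context, a hedged sketch of how the logger and the two callbacks might be wired into a `Trainer`, mirroring the `if args.log:` branch above. The actual `Trainer` arguments in `train.py` are outside this diff; `build_trainer` and `max_epochs` are illustrative names, not the repo's code.

```python
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import LearningRateMonitor
from pytorch_lightning.loggers import WandbLogger

from utils import LogImageTexCallback  # callback defined in this repo's utils.py


def build_trainer(log: bool, tex_tokenizer, max_epochs: int = 10) -> Trainer:
    """Assemble a Trainer with optional W&B logging (sketch of the branch in main())."""
    if log:
        logger = WandbLogger("img2tex", log_model=True)
        callbacks = [
            LogImageTexCallback(logger, tex_tokenizer),
            # Logs the learning rate produced by the scheduler at every optimizer step.
            LearningRateMonitor(logging_interval='step'),
        ]
    else:
        logger = None
        callbacks = []
    return Trainer(logger=logger, callbacks=callbacks, max_epochs=max_epochs)
```

LearningRateMonitor only works when a logger is attached, which is why it is added solely in the logging branch.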
utils.py CHANGED
@@ -18,4 +18,4 @@ class LogImageTexCallback(Callback):
         tex_predicted = decode(transformer, self.tex_tokenizer, image)
         image = self.tensor_to_PIL(image)
         tex_true = self.tex_tokenizer.decode(list(batch['tex_ids'][0].to('cpu', torch.int)), skip_special_tokens=True)
-        self.logger.log_image(key="samples", images=[image], caption=[f"True {tex_true}\n Predicted{tex_predicted}"])
+        self.logger.log_image(key="samples", images=[image], caption=[f"True: {tex_true}\n Predicted: {tex_predicted}"])
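The change above only adds the missing colons and spacing in the caption. A standalone way to sanity-check that formatting with `WandbLogger.log_image`, using a blank placeholder image and offline mode so no W&B login is needed (both are stand-ins for the callback's real inputs):

```python
from PIL import Image
from pytorch_lightning.loggers import WandbLogger

logger = WandbLogger(name="img2tex", offline=True)  # offline run, nothing is uploaded
image = Image.new("L", (1024, 128), color=255)      # blank placeholder sized like the dataset images
tex_true, tex_predicted = r"\frac{a}{b}", r"\frac{a}{c}"
logger.log_image(key="samples", images=[image],
                 caption=[f"True: {tex_true}\n Predicted: {tex_predicted}"])
```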