dkoshman commited on
Commit
c308f77
1 Parent(s): 41a34cd

moved constants to separate file, organized tokenizer

Browse files
Files changed (6) hide show
  1. constants.py +11 -0
  2. data_generator.py +2 -5
  3. data_preprocessing.py +14 -36
  4. model.py +0 -2
  5. train.py +11 -8
  6. utils.py +5 -4
constants.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ PDFLATEX = "/external2/dkkoshman/venv/texlive/2022/bin/x86_64-linux/pdflatex"
2
+ GHOSTSCRIPT = "/external2/dkkoshman/venv/local/gs/bin/gs"
3
+
4
+ DATA_DIR = "data"
5
+ LATEX_PATH = "resources/latex.json"
6
+ TRAINER_DIR = "resources/trainer"
7
+ TOKENIZER_PATH = "resources/tokenizer.pt"
8
+
9
+ NUM_DATALOADER_WORKERS = 4
10
+ PERSISTENT_WORKERS = True # whether to keep dataloader workers alive between epochs (True = do not shut them down)
11
+ PIN_MEMORY = False # pinning memory probably causes CUDA OOM errors if True
data_generator.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  import json
2
  from multiprocessing import Pool
3
  import os
@@ -7,11 +9,6 @@ import subprocess
7
  import random
8
  import tqdm
9
 
10
- DATA_DIR = "data"
11
- LATEX_PATH = "resources/latex.json"
12
- PDFLATEX = "/external2/dkkoshman/venv/texlive/2022/bin/x86_64-linux/pdflatex"
13
- GHOSTSCRIPT = "/external2/dkkoshman/venv/local/gs/bin/gs"
14
-
15
 
16
  def generate_equation(latex, size, max_depth):
17
  """
 
1
+ from constants import DATA_DIR, LATEX_PATH, PDFLATEX, GHOSTSCRIPT
2
+
3
  import json
4
  from multiprocessing import Pool
5
  import os
 
9
  import random
10
  import tqdm
11
 
 
 
 
 
 
12
 
13
  def generate_equation(latex, size, max_depth):
14
  """
data_preprocessing.py CHANGED
@@ -1,4 +1,4 @@
1
- from data_generator import DATA_DIR
2
 
3
  import einops
4
  import os
@@ -9,14 +9,8 @@ import torchvision
9
  import torchvision.transforms as T
10
  from torch.utils.data import Dataset, DataLoader
11
  import tqdm
12
- import random
13
  import re
14
 
15
- TOKENIZER_PATH = "resources/tokenizer.pt"
16
- NUM_WORKERS = 4
17
- PERSISTENT_WORKERS = True # whether to shut down workers at the end of epoch
18
- PIN_MEMORY = False # probably causes cuda oom error if True
19
-
20
 
21
  class TexImageDataset(Dataset):
22
  """Image and tex dataset."""
@@ -147,18 +141,12 @@ def generate_tex_tokenizer(dataloader):
147
  class LatexImageDataModule(pl.LightningDataModule):
148
  def __init__(self, image_width, image_height, batch_size, random_magnitude):
149
  super().__init__()
150
- image_transform = RandomizeImageTransform(image_width, image_height, random_magnitude)
151
- tex_transform = ExtractEquationFromTexTransform()
152
-
153
- self.train_dataset = TexImageDataset(DATA_DIR, image_transform, tex_transform)
154
- self.val_dataset = TexImageDataset(DATA_DIR, image_transform, tex_transform)
155
- self.test_dataset = TexImageDataset(DATA_DIR, image_transform, tex_transform)
156
-
157
- train_indices, val_indices, test_indices = self.train_val_test_split(len(self.train_dataset))
158
- self.train_dataset = torch.utils.data.Subset(self.train_dataset, train_indices)
159
- self.val_dataset = torch.utils.data.Subset(self.val_dataset, val_indices)
160
- self.test_dataset = torch.utils.data.Subset(self.test_dataset, test_indices)
161
 
 
 
 
 
 
162
  self.batch_size = batch_size
163
  self.save_hyperparameters()
164
 
@@ -167,27 +155,17 @@ class LatexImageDataModule(pl.LightningDataModule):
167
  print(f"Vocabulary size: {tokenizer.get_vocab_size()}")
168
  torch.save(tokenizer, TOKENIZER_PATH)
169
 
170
- def setup(self, stage=None):
171
- self.tex_tokenizer = torch.load(TOKENIZER_PATH)
172
- self.collate_fn = BatchCollator(self.tex_tokenizer)
173
-
174
- @staticmethod
175
- def train_val_test_split(size, train_fraction=.8, val_fraction=.1):
176
- indices = list(range(size))
177
- random.shuffle(indices)
178
- train_split = int(size * train_fraction)
179
- val_split = train_split + int(size * val_fraction)
180
- return indices[:train_split], indices[train_split: val_split], indices[val_split:]
181
 
182
  def train_dataloader(self):
183
- return DataLoader(self.train_dataset, batch_size=self.batch_size, collate_fn=self.collate_fn,
184
- pin_memory=PIN_MEMORY, num_workers=NUM_WORKERS, persistent_workers=PERSISTENT_WORKERS,
185
- shuffle=True)
186
 
187
  def val_dataloader(self):
188
- return DataLoader(self.val_dataset, batch_size=self.batch_size, collate_fn=self.collate_fn,
189
- pin_memory=PIN_MEMORY, num_workers=NUM_WORKERS, persistent_workers=PERSISTENT_WORKERS)
190
 
191
  def test_dataloader(self):
192
- return DataLoader(self.test_dataset, batch_size=self.batch_size, collate_fn=self.collate_fn,
193
- pin_memory=PIN_MEMORY, num_workers=NUM_WORKERS, persistent_workers=PERSISTENT_WORKERS)
 
1
+ from constants import DATA_DIR, TOKENIZER_PATH, NUM_DATALOADER_WORKERS, PERSISTENT_WORKERS, PIN_MEMORY
2
 
3
  import einops
4
  import os
 
9
  import torchvision.transforms as T
10
  from torch.utils.data import Dataset, DataLoader
11
  import tqdm
 
12
  import re
13
 
 
 
 
 
 
14
 
15
  class TexImageDataset(Dataset):
16
  """Image and tex dataset."""
 
141
  class LatexImageDataModule(pl.LightningDataModule):
142
  def __init__(self, image_width, image_height, batch_size, random_magnitude):
143
  super().__init__()
 
 
 
 
 
 
 
 
 
 
 
144
 
145
+ dataset = TexImageDataset(root_dir=DATA_DIR,
146
+ image_transform=RandomizeImageTransform(image_width, image_height, random_magnitude),
147
+ tex_transform=ExtractEquationFromTexTransform())
148
+ self.train_dataset, self.val_dataset, self.test_dataset = torch.utils.data.random_split(
149
+ dataset, [len(dataset) * 18 // 20, len(dataset) // 20, len(dataset) // 20])
150
  self.batch_size = batch_size
151
  self.save_hyperparameters()
152
 
 
155
  print(f"Vocabulary size: {tokenizer.get_vocab_size()}")
156
  torch.save(tokenizer, TOKENIZER_PATH)
157
 
158
+ def _shared_dataloader(self, dataset, **kwargs):
159
+ tex_tokenizer = torch.load(TOKENIZER_PATH)
160
+ collate_fn = BatchCollator(tex_tokenizer)
161
+ return DataLoader(dataset, batch_size=self.batch_size, collate_fn=collate_fn, pin_memory=PIN_MEMORY,
162
+ num_workers=NUM_DATALOADER_WORKERS, persistent_workers=PERSISTENT_WORKERS, **kwargs)
 
 
 
 
 
 
163
 
164
  def train_dataloader(self):
165
+ return self._shared_dataloader(self.train_dataset, shuffle=True)
 
 
166
 
167
  def val_dataloader(self):
168
+ return self._shared_dataloader(self.val_dataset)
 
169
 
170
  def test_dataloader(self):
171
+ return self._shared_dataloader(self.test_dataset)
 
model.py CHANGED
@@ -30,7 +30,6 @@ class AddPositionalEncoding(nn.Module):
30
  def forward(self, batch):
31
  seq_len = batch.size(1)
32
  positional_encodings = self.positional_encodings[:seq_len, :]
33
- # implicit batch broadcasting
34
  return batch + positional_encodings
35
 
36
 
@@ -125,7 +124,6 @@ class Transformer(pl.LightningModule):
125
  self.src_tok_emb = ImageEmbedding(d_model, image_width, image_height, patch_size=16, dropout=dropout)
126
  self.tgt_tok_emb = TexEmbedding(d_model, tgt_vocab_size, dropout=dropout)
127
  self.generator = nn.Linear(d_model, tgt_vocab_size)
128
- # Make embedding and generator share weight because they do the same thing
129
  self.tgt_tok_emb.embedding.weight = self.generator.weight
130
  self.loss_fn = torch.nn.CrossEntropyLoss(ignore_index=pad_idx, label_smoothing=.1)
131
  self.save_hyperparameters()
 
30
  def forward(self, batch):
31
  seq_len = batch.size(1)
32
  positional_encodings = self.positional_encodings[:seq_len, :]
 
33
  return batch + positional_encodings
34
 
35
 
 
124
  self.src_tok_emb = ImageEmbedding(d_model, image_width, image_height, patch_size=16, dropout=dropout)
125
  self.tgt_tok_emb = TexEmbedding(d_model, tgt_vocab_size, dropout=dropout)
126
  self.generator = nn.Linear(d_model, tgt_vocab_size)
 
127
  self.tgt_tok_emb.embedding.weight = self.generator.weight
128
  self.loss_fn = torch.nn.CrossEntropyLoss(ignore_index=pad_idx, label_smoothing=.1)
129
  self.save_hyperparameters()
train.py CHANGED
@@ -1,16 +1,15 @@
 
1
  from data_preprocessing import LatexImageDataModule
2
  from model import Transformer
3
  from utils import LogImageTexCallback
4
 
5
  import argparse
6
  import os
7
- from pytorch_lightning.callbacks import LearningRateMonitor
8
  from pytorch_lightning.loggers import WandbLogger
9
  from pytorch_lightning import Trainer
10
  import torch
11
 
12
- TRAINER_DIR = "resources/pl_trainer_checkpoints"
13
-
14
 
15
  # TODO: update python; maybe the model doesn't train because of the ignored special index in CrossEntropyLoss?
16
  # crop image, adjust brightness, make tex tokens always decodable,
@@ -47,11 +46,15 @@ def main():
47
  datamodule = LatexImageDataModule(image_width=args.width, image_height=args.height,
48
  batch_size=args.batch_size, random_magnitude=args.random_magnitude)
49
  datamodule.prepare_data()
50
- datamodule.setup()
51
  if args.log:
52
  logger = WandbLogger(f"img2tex", log_model=True)
53
- callbacks = [LogImageTexCallback(logger, datamodule.tex_tokenizer),
54
- LearningRateMonitor(logging_interval='step')]
 
 
 
 
55
  else:
56
  logger = None
57
  callbacks = []
@@ -63,7 +66,8 @@ def main():
63
  strategy="ddp",
64
  enable_progress_bar=True,
65
  default_root_dir=TRAINER_DIR,
66
- callbacks=callbacks)
 
67
 
68
  transformer = Transformer(num_encoder_layers=args.transformer_args['num_encoder_layers'],
69
  num_decoder_layers=args.transformer_args['num_decoder_layers'],
@@ -77,7 +81,6 @@ def main():
77
  pad_idx=datamodule.tex_tokenizer.token_to_id("[PAD]"))
78
 
79
  trainer.fit(transformer, datamodule=datamodule)
80
- trainer.test(datamodule=datamodule, ckpt_path='best')
81
  trainer.save_checkpoint(os.path.join(TRAINER_DIR, "best_model.ckpt"))
82
 
83
 
 
1
+ from constants import TRAINER_DIR
2
  from data_preprocessing import LatexImageDataModule
3
  from model import Transformer
4
  from utils import LogImageTexCallback
5
 
6
  import argparse
7
  import os
8
+ from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
9
  from pytorch_lightning.loggers import WandbLogger
10
  from pytorch_lightning import Trainer
11
  import torch
12
 
 
 
13
 
14
  # TODO: update python; maybe the model doesn't train because of the ignored special index in CrossEntropyLoss?
15
  # crop image, adjust brightness, make tex tokens always decodable,
 
46
  datamodule = LatexImageDataModule(image_width=args.width, image_height=args.height,
47
  batch_size=args.batch_size, random_magnitude=args.random_magnitude)
48
  datamodule.prepare_data()
49
+
50
  if args.log:
51
  logger = WandbLogger(f"img2tex", log_model=True)
52
+ callbacks = [LogImageTexCallback(logger),
53
+ LearningRateMonitor(logging_interval='step'),
54
+ ModelCheckpoint(save_top_k=10,
55
+ monitor="val_loss",
56
+ mode="min",
57
+ filename="img2tex-{epoch:02d}-{val_loss:.2f}")]
58
  else:
59
  logger = None
60
  callbacks = []
 
66
  strategy="ddp",
67
  enable_progress_bar=True,
68
  default_root_dir=TRAINER_DIR,
69
+ callbacks=callbacks,
70
+ check_val_every_n_epoch=5)
71
 
72
  transformer = Transformer(num_encoder_layers=args.transformer_args['num_encoder_layers'],
73
  num_decoder_layers=args.transformer_args['num_decoder_layers'],
 
81
  pad_idx=datamodule.tex_tokenizer.token_to_id("[PAD]"))
82
 
83
  trainer.fit(transformer, datamodule=datamodule)
 
84
  trainer.save_checkpoint(os.path.join(TRAINER_DIR, "best_model.ckpt"))
85
 
86
 
utils.py CHANGED
@@ -1,4 +1,4 @@
1
- from model import Transformer
2
 
3
  import einops
4
  import random
@@ -8,9 +8,9 @@ from torchvision import transforms
8
 
9
 
10
  class LogImageTexCallback(Callback):
11
- def __init__(self, logger, tex_tokenizer):
12
  self.logger = logger
13
- self.tex_tokenizer = tex_tokenizer
14
  self.tensor_to_PIL = transforms.ToPILImage()
15
 
16
  def on_validation_batch_start(self, trainer, transformer, batch, batch_idx, dataloader_idx):
@@ -102,7 +102,8 @@ class LogImageTexCallback(Callback):
102
 
103
 
104
  @torch.inference_mode()
105
- def decode(transformer, tex_tokenizer, image):
 
106
  tex_ids = [tex_tokenizer.token_to_id("[CLS]")]
107
  src = einops.rearrange(image, "c h w -> () c h w")
108
  while tex_ids[-1] != tex_tokenizer.token_to_id("[SEP]") and len(tex_ids) < 30:
 
1
+ from constants import TOKENIZER_PATH
2
 
3
  import einops
4
  import random
 
8
 
9
 
10
  class LogImageTexCallback(Callback):
11
+ def __init__(self, logger):
12
  self.logger = logger
13
+ self.tex_tokenizer = torch.load(TOKENIZER_PATH)
14
  self.tensor_to_PIL = transforms.ToPILImage()
15
 
16
  def on_validation_batch_start(self, trainer, transformer, batch, batch_idx, dataloader_idx):
 
102
 
103
 
104
  @torch.inference_mode()
105
+ def decode(transformer, image):
106
+ tex_tokenizer = torch.load(TOKENIZER_PATH)
107
  tex_ids = [tex_tokenizer.token_to_id("[CLS]")]
108
  src = einops.rearrange(image, "c h w -> () c h w")
109
  while tex_ids[-1] != tex_tokenizer.token_to_id("[SEP]") and len(tex_ids) < 30: