Spaces:
Runtime error
Runtime error
dkoshman
committed on
Commit
•
c308f77
1
Parent(s):
41a34cd
moved constants to separate file, organized tokenizer
Browse files
- constants.py +11 -0
- data_generator.py +2 -5
- data_preprocessing.py +14 -36
- model.py +0 -2
- train.py +11 -8
- utils.py +5 -4
constants.py
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
PDFLATEX = "/external2/dkkoshman/venv/texlive/2022/bin/x86_64-linux/pdflatex"
|
2 |
+
GHOSTSCRIPT = "/external2/dkkoshman/venv/local/gs/bin/gs"
|
3 |
+
|
4 |
+
DATA_DIR = "data"
|
5 |
+
LATEX_PATH = "resources/latex.json"
|
6 |
+
TRAINER_DIR = "resources/trainer"
|
7 |
+
TOKENIZER_PATH = "resources/tokenizer.pt"
|
8 |
+
|
9 |
+
NUM_DATALOADER_WORKERS = 4
|
10 |
+
PERSISTENT_WORKERS = True # whether to shut down workers at the end of epoch
|
11 |
+
PIN_MEMORY = False # probably causes cuda oom error if True
|
data_generator.py
CHANGED
@@ -1,3 +1,5 @@
|
|
|
|
|
|
1 |
import json
|
2 |
from multiprocessing import Pool
|
3 |
import os
|
@@ -7,11 +9,6 @@ import subprocess
|
|
7 |
import random
|
8 |
import tqdm
|
9 |
|
10 |
-
DATA_DIR = "data"
|
11 |
-
LATEX_PATH = "resources/latex.json"
|
12 |
-
PDFLATEX = "/external2/dkkoshman/venv/texlive/2022/bin/x86_64-linux/pdflatex"
|
13 |
-
GHOSTSCRIPT = "/external2/dkkoshman/venv/local/gs/bin/gs"
|
14 |
-
|
15 |
|
16 |
def generate_equation(latex, size, max_depth):
|
17 |
"""
|
|
|
1 |
+
from constants import DATA_DIR, LATEX_PATH, PDFLATEX, GHOSTSCRIPT
|
2 |
+
|
3 |
import json
|
4 |
from multiprocessing import Pool
|
5 |
import os
|
|
|
9 |
import random
|
10 |
import tqdm
|
11 |
|
|
|
|
|
|
|
|
|
|
|
12 |
|
13 |
def generate_equation(latex, size, max_depth):
|
14 |
"""
|
data_preprocessing.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1 |
-
from
|
2 |
|
3 |
import einops
|
4 |
import os
|
@@ -9,14 +9,8 @@ import torchvision
|
|
9 |
import torchvision.transforms as T
|
10 |
from torch.utils.data import Dataset, DataLoader
|
11 |
import tqdm
|
12 |
-
import random
|
13 |
import re
|
14 |
|
15 |
-
TOKENIZER_PATH = "resources/tokenizer.pt"
|
16 |
-
NUM_WORKERS = 4
|
17 |
-
PERSISTENT_WORKERS = True # whether to shut down workers at the end of epoch
|
18 |
-
PIN_MEMORY = False # probably causes cuda oom error if True
|
19 |
-
|
20 |
|
21 |
class TexImageDataset(Dataset):
|
22 |
"""Image and tex dataset."""
|
@@ -147,18 +141,12 @@ def generate_tex_tokenizer(dataloader):
|
|
147 |
class LatexImageDataModule(pl.LightningDataModule):
|
148 |
def __init__(self, image_width, image_height, batch_size, random_magnitude):
|
149 |
super().__init__()
|
150 |
-
image_transform = RandomizeImageTransform(image_width, image_height, random_magnitude)
|
151 |
-
tex_transform = ExtractEquationFromTexTransform()
|
152 |
-
|
153 |
-
self.train_dataset = TexImageDataset(DATA_DIR, image_transform, tex_transform)
|
154 |
-
self.val_dataset = TexImageDataset(DATA_DIR, image_transform, tex_transform)
|
155 |
-
self.test_dataset = TexImageDataset(DATA_DIR, image_transform, tex_transform)
|
156 |
-
|
157 |
-
train_indices, val_indices, test_indices = self.train_val_test_split(len(self.train_dataset))
|
158 |
-
self.train_dataset = torch.utils.data.Subset(self.train_dataset, train_indices)
|
159 |
-
self.val_dataset = torch.utils.data.Subset(self.val_dataset, val_indices)
|
160 |
-
self.test_dataset = torch.utils.data.Subset(self.test_dataset, test_indices)
|
161 |
|
|
|
|
|
|
|
|
|
|
|
162 |
self.batch_size = batch_size
|
163 |
self.save_hyperparameters()
|
164 |
|
@@ -167,27 +155,17 @@ class LatexImageDataModule(pl.LightningDataModule):
|
|
167 |
print(f"Vocabulary size: {tokenizer.get_vocab_size()}")
|
168 |
torch.save(tokenizer, TOKENIZER_PATH)
|
169 |
|
170 |
-
def
|
171 |
-
|
172 |
-
|
173 |
-
|
174 |
-
|
175 |
-
def train_val_test_split(size, train_fraction=.8, val_fraction=.1):
|
176 |
-
indices = list(range(size))
|
177 |
-
random.shuffle(indices)
|
178 |
-
train_split = int(size * train_fraction)
|
179 |
-
val_split = train_split + int(size * val_fraction)
|
180 |
-
return indices[:train_split], indices[train_split: val_split], indices[val_split:]
|
181 |
|
182 |
def train_dataloader(self):
|
183 |
-
return
|
184 |
-
pin_memory=PIN_MEMORY, num_workers=NUM_WORKERS, persistent_workers=PERSISTENT_WORKERS,
|
185 |
-
shuffle=True)
|
186 |
|
187 |
def val_dataloader(self):
|
188 |
-
return
|
189 |
-
pin_memory=PIN_MEMORY, num_workers=NUM_WORKERS, persistent_workers=PERSISTENT_WORKERS)
|
190 |
|
191 |
def test_dataloader(self):
|
192 |
-
return
|
193 |
-
pin_memory=PIN_MEMORY, num_workers=NUM_WORKERS, persistent_workers=PERSISTENT_WORKERS)
|
|
|
1 |
+
from constants import DATA_DIR, TOKENIZER_PATH, NUM_DATALOADER_WORKERS, PERSISTENT_WORKERS, PIN_MEMORY
|
2 |
|
3 |
import einops
|
4 |
import os
|
|
|
9 |
import torchvision.transforms as T
|
10 |
from torch.utils.data import Dataset, DataLoader
|
11 |
import tqdm
|
|
|
12 |
import re
|
13 |
|
|
|
|
|
|
|
|
|
|
|
14 |
|
15 |
class TexImageDataset(Dataset):
|
16 |
"""Image and tex dataset."""
|
|
|
141 |
class LatexImageDataModule(pl.LightningDataModule):
|
142 |
def __init__(self, image_width, image_height, batch_size, random_magnitude):
|
143 |
super().__init__()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
144 |
|
145 |
+
dataset = TexImageDataset(root_dir=DATA_DIR,
|
146 |
+
image_transform=RandomizeImageTransform(image_width, image_height, random_magnitude),
|
147 |
+
tex_transform=ExtractEquationFromTexTransform())
|
148 |
+
self.train_dataset, self.val_dataset, self.test_dataset = torch.utils.data.random_split(
|
149 |
+
dataset, [len(dataset) * 18 // 20, len(dataset) // 20, len(dataset) // 20])
|
150 |
self.batch_size = batch_size
|
151 |
self.save_hyperparameters()
|
152 |
|
|
|
155 |
print(f"Vocabulary size: {tokenizer.get_vocab_size()}")
|
156 |
torch.save(tokenizer, TOKENIZER_PATH)
|
157 |
|
158 |
+
def _shared_dataloader(self, dataset, **kwargs):
|
159 |
+
tex_tokenizer = torch.load(TOKENIZER_PATH)
|
160 |
+
collate_fn = BatchCollator(tex_tokenizer)
|
161 |
+
return DataLoader(dataset, batch_size=self.batch_size, collate_fn=collate_fn, pin_memory=PIN_MEMORY,
|
162 |
+
num_workers=NUM_DATALOADER_WORKERS, persistent_workers=PERSISTENT_WORKERS, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
163 |
|
164 |
def train_dataloader(self):
|
165 |
+
return self._shared_dataloader(self.train_dataset, shuffle=True)
|
|
|
|
|
166 |
|
167 |
def val_dataloader(self):
|
168 |
+
return self._shared_dataloader(self.val_dataset)
|
|
|
169 |
|
170 |
def test_dataloader(self):
|
171 |
+
return self._shared_dataloader(self.test_dataset)
|
|
model.py
CHANGED
@@ -30,7 +30,6 @@ class AddPositionalEncoding(nn.Module):
|
|
30 |
def forward(self, batch):
|
31 |
seq_len = batch.size(1)
|
32 |
positional_encodings = self.positional_encodings[:seq_len, :]
|
33 |
-
# implicit batch broadcasting
|
34 |
return batch + positional_encodings
|
35 |
|
36 |
|
@@ -125,7 +124,6 @@ class Transformer(pl.LightningModule):
|
|
125 |
self.src_tok_emb = ImageEmbedding(d_model, image_width, image_height, patch_size=16, dropout=dropout)
|
126 |
self.tgt_tok_emb = TexEmbedding(d_model, tgt_vocab_size, dropout=dropout)
|
127 |
self.generator = nn.Linear(d_model, tgt_vocab_size)
|
128 |
-
# Make embedding and generator share weight because they do the same thing
|
129 |
self.tgt_tok_emb.embedding.weight = self.generator.weight
|
130 |
self.loss_fn = torch.nn.CrossEntropyLoss(ignore_index=pad_idx, label_smoothing=.1)
|
131 |
self.save_hyperparameters()
|
|
|
30 |
def forward(self, batch):
|
31 |
seq_len = batch.size(1)
|
32 |
positional_encodings = self.positional_encodings[:seq_len, :]
|
|
|
33 |
return batch + positional_encodings
|
34 |
|
35 |
|
|
|
124 |
self.src_tok_emb = ImageEmbedding(d_model, image_width, image_height, patch_size=16, dropout=dropout)
|
125 |
self.tgt_tok_emb = TexEmbedding(d_model, tgt_vocab_size, dropout=dropout)
|
126 |
self.generator = nn.Linear(d_model, tgt_vocab_size)
|
|
|
127 |
self.tgt_tok_emb.embedding.weight = self.generator.weight
|
128 |
self.loss_fn = torch.nn.CrossEntropyLoss(ignore_index=pad_idx, label_smoothing=.1)
|
129 |
self.save_hyperparameters()
|
train.py
CHANGED
@@ -1,16 +1,15 @@
|
|
|
|
1 |
from data_preprocessing import LatexImageDataModule
|
2 |
from model import Transformer
|
3 |
from utils import LogImageTexCallback
|
4 |
|
5 |
import argparse
|
6 |
import os
|
7 |
-
from pytorch_lightning.callbacks import LearningRateMonitor
|
8 |
from pytorch_lightning.loggers import WandbLogger
|
9 |
from pytorch_lightning import Trainer
|
10 |
import torch
|
11 |
|
12 |
-
TRAINER_DIR = "resources/pl_trainer_checkpoints"
|
13 |
-
|
14 |
|
15 |
# TODO: update python, maybe model doesnt train bc of ignore special index in CrossEntropyLoss?
|
16 |
# crop image, adjust brightness, make tex tokens always decodable,
|
@@ -47,11 +46,15 @@ def main():
|
|
47 |
datamodule = LatexImageDataModule(image_width=args.width, image_height=args.height,
|
48 |
batch_size=args.batch_size, random_magnitude=args.random_magnitude)
|
49 |
datamodule.prepare_data()
|
50 |
-
|
51 |
if args.log:
|
52 |
logger = WandbLogger(f"img2tex", log_model=True)
|
53 |
-
callbacks = [LogImageTexCallback(logger
|
54 |
-
LearningRateMonitor(logging_interval='step')
|
|
|
|
|
|
|
|
|
55 |
else:
|
56 |
logger = None
|
57 |
callbacks = []
|
@@ -63,7 +66,8 @@ def main():
|
|
63 |
strategy="ddp",
|
64 |
enable_progress_bar=True,
|
65 |
default_root_dir=TRAINER_DIR,
|
66 |
-
callbacks=callbacks
|
|
|
67 |
|
68 |
transformer = Transformer(num_encoder_layers=args.transformer_args['num_encoder_layers'],
|
69 |
num_decoder_layers=args.transformer_args['num_decoder_layers'],
|
@@ -77,7 +81,6 @@ def main():
|
|
77 |
pad_idx=datamodule.tex_tokenizer.token_to_id("[PAD]"))
|
78 |
|
79 |
trainer.fit(transformer, datamodule=datamodule)
|
80 |
-
trainer.test(datamodule=datamodule, ckpt_path='best')
|
81 |
trainer.save_checkpoint(os.path.join(TRAINER_DIR, "best_model.ckpt"))
|
82 |
|
83 |
|
|
|
1 |
+
from constants import TRAINER_DIR
|
2 |
from data_preprocessing import LatexImageDataModule
|
3 |
from model import Transformer
|
4 |
from utils import LogImageTexCallback
|
5 |
|
6 |
import argparse
|
7 |
import os
|
8 |
+
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
|
9 |
from pytorch_lightning.loggers import WandbLogger
|
10 |
from pytorch_lightning import Trainer
|
11 |
import torch
|
12 |
|
|
|
|
|
13 |
|
14 |
# TODO: update python, maybe model doesnt train bc of ignore special index in CrossEntropyLoss?
|
15 |
# crop image, adjust brightness, make tex tokens always decodable,
|
|
|
46 |
datamodule = LatexImageDataModule(image_width=args.width, image_height=args.height,
|
47 |
batch_size=args.batch_size, random_magnitude=args.random_magnitude)
|
48 |
datamodule.prepare_data()
|
49 |
+
|
50 |
if args.log:
|
51 |
logger = WandbLogger(f"img2tex", log_model=True)
|
52 |
+
callbacks = [LogImageTexCallback(logger),
|
53 |
+
LearningRateMonitor(logging_interval='step'),
|
54 |
+
ModelCheckpoint(save_top_k=10,
|
55 |
+
monitor="val_loss",
|
56 |
+
mode="min",
|
57 |
+
filename="img2tex-{epoch:02d}-{val_loss:.2f}")]
|
58 |
else:
|
59 |
logger = None
|
60 |
callbacks = []
|
|
|
66 |
strategy="ddp",
|
67 |
enable_progress_bar=True,
|
68 |
default_root_dir=TRAINER_DIR,
|
69 |
+
callbacks=callbacks,
|
70 |
+
check_val_every_n_epoch=5)
|
71 |
|
72 |
transformer = Transformer(num_encoder_layers=args.transformer_args['num_encoder_layers'],
|
73 |
num_decoder_layers=args.transformer_args['num_decoder_layers'],
|
|
|
81 |
pad_idx=datamodule.tex_tokenizer.token_to_id("[PAD]"))
|
82 |
|
83 |
trainer.fit(transformer, datamodule=datamodule)
|
|
|
84 |
trainer.save_checkpoint(os.path.join(TRAINER_DIR, "best_model.ckpt"))
|
85 |
|
86 |
|
utils.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1 |
-
from
|
2 |
|
3 |
import einops
|
4 |
import random
|
@@ -8,9 +8,9 @@ from torchvision import transforms
|
|
8 |
|
9 |
|
10 |
class LogImageTexCallback(Callback):
|
11 |
-
def __init__(self, logger
|
12 |
self.logger = logger
|
13 |
-
self.tex_tokenizer =
|
14 |
self.tensor_to_PIL = transforms.ToPILImage()
|
15 |
|
16 |
def on_validation_batch_start(self, trainer, transformer, batch, batch_idx, dataloader_idx):
|
@@ -102,7 +102,8 @@ class LogImageTexCallback(Callback):
|
|
102 |
|
103 |
|
104 |
@torch.inference_mode()
|
105 |
-
def decode(transformer,
|
|
|
106 |
tex_ids = [tex_tokenizer.token_to_id("[CLS]")]
|
107 |
src = einops.rearrange(image, "c h w -> () c h w")
|
108 |
while tex_ids[-1] != tex_tokenizer.token_to_id("[SEP]") and len(tex_ids) < 30:
|
|
|
1 |
+
from constants import TOKENIZER_PATH
|
2 |
|
3 |
import einops
|
4 |
import random
|
|
|
8 |
|
9 |
|
10 |
class LogImageTexCallback(Callback):
|
11 |
+
def __init__(self, logger):
|
12 |
self.logger = logger
|
13 |
+
self.tex_tokenizer = torch.load(TOKENIZER_PATH)
|
14 |
self.tensor_to_PIL = transforms.ToPILImage()
|
15 |
|
16 |
def on_validation_batch_start(self, trainer, transformer, batch, batch_idx, dataloader_idx):
|
|
|
102 |
|
103 |
|
104 |
@torch.inference_mode()
|
105 |
+
def decode(transformer, image):
|
106 |
+
tex_tokenizer = torch.load(TOKENIZER_PATH)
|
107 |
tex_ids = [tex_tokenizer.token_to_id("[CLS]")]
|
108 |
src = einops.rearrange(image, "c h w -> () c h w")
|
109 |
while tex_ids[-1] != tex_tokenizer.token_to_id("[SEP]") and len(tex_ids) < 30:
|