dkoshman committed on
Commit
2a394f6
1 Parent(s): fb8db0f

working backend

Browse files
Files changed (4) hide show
  1. data_generator.py +4 -3
  2. data_preprocessing.py +46 -23
  3. model.py +25 -4
  4. train.py +9 -14
data_generator.py CHANGED
@@ -1,5 +1,3 @@
1
- from train import DATA_DIR, LATEX_PATH
2
-
3
  import json
4
  from multiprocessing import Pool
5
  import os
@@ -9,6 +7,9 @@ import subprocess
9
  import random
10
  import tqdm
11
 
 
 
 
12
 
13
  class DotDict(dict):
14
  """dot.notation access to dictionary attributes"""
@@ -168,7 +169,7 @@ def generate_data(examples_count) -> None:
168
  :examples_count: - how many latex - image examples to generate
169
  """
170
 
171
- filenames = set(f"{i:0{len(str(examples_count - 1))}d}" for i in range(examples_count)),
172
  directory = os.path.abspath(DATA_DIR)
173
  latex_path = os.path.abspath(LATEX_PATH)
174
  with open(latex_path) as file:
 
 
 
1
  import json
2
  from multiprocessing import Pool
3
  import os
 
7
  import random
8
  import tqdm
9
 
10
+ DATA_DIR = 'data'
11
+ LATEX_PATH = 'resources/latex.json'
12
+
13
 
14
  class DotDict(dict):
15
  """dot.notation access to dictionary attributes"""
 
169
  :examples_count: - how many latex - image examples to generate
170
  """
171
 
172
+ filenames = set(f"{i:0{len(str(examples_count - 1))}d}" for i in range(examples_count))
173
  directory = os.path.abspath(DATA_DIR)
174
  latex_path = os.path.abspath(LATEX_PATH)
175
  with open(latex_path) as file:
data_preprocessing.py CHANGED
@@ -1,4 +1,4 @@
1
- from train import DATASET_PATH, DATA_DIR, BATCH_SIZE, TEX_VOCAB_SIZE
2
 
3
  import einops
4
  import os
@@ -9,9 +9,14 @@ import torchvision
9
  import torchvision.transforms as T
10
  from torch.utils.data import Dataset, DataLoader
11
  import tqdm
12
- from typing import Optional
13
  import re
14
 
 
 
 
 
 
15
 
16
  class TexImageDataset(Dataset):
17
  """Image and tex dataset."""
@@ -89,7 +94,7 @@ class BatchCollator(object):
89
  class StandardizeImageTransform(object):
90
  """Pad and crop image to a given size, grayscale and invert"""
91
 
92
- def __init__(self, width=1024, height=128):
93
  self.standardize = T.Compose((
94
  T.Resize(height),
95
  T.Grayscale(),
@@ -106,7 +111,7 @@ class StandardizeImageTransform(object):
106
  class RandomizeImageTransform(object):
107
  """Standardize image and randomly augment"""
108
 
109
- def __init__(self, width=1024, height=128, random_magnitude=5):
110
  self.transform = T.Compose((
111
  T.ColorJitter(brightness=random_magnitude / 10),
112
  T.Resize(height),
@@ -138,10 +143,10 @@ class ExtractEquationFromTexTransform(object):
138
  return equation
139
 
140
 
141
- def generate_tex_tokenizer(dataset: TexImageDataset, vocab_size=300):
142
  """Returns a tokenizer trained on texs from given dataset"""
143
 
144
- texs = list(tqdm.tqdm((item['tex'] for item in dataset), "Training tokenizer"))
145
 
146
  os.environ['TOKENIZERS_PARALLELISM'] = 'false'
147
  tokenizer = tokenizers.Tokenizer(tokenizers.models.BPE(unk_token="[UNK]"))
@@ -164,31 +169,49 @@ def generate_tex_tokenizer(dataset: TexImageDataset, vocab_size=300):
164
 
165
 
166
  class LatexImageDataModule(pl.LightningDataModule):
167
- def prepare_data(self) -> None:
168
- # download or something
169
- ...
170
-
171
- def setup(self, stage: Optional[str] = None) -> None:
172
- tex_transform = ExtractEquationFromTexTransform()
173
- dataset = TexImageDataset(DATA_DIR, tex_transform=tex_transform)
174
 
175
- self.train_dataset, self.val_dataset, self.test_dataset = torch.utils.data.random_split(
176
- dataset,
177
- [len(dataset) - 2 * len(dataset) // 10, len(dataset) // 10, len(dataset) // 10]
 
178
  )
179
- self.train_dataset.image_transform = RandomizeImageTransform()
180
- self.val_dataset.image_transform = StandardizeImageTransform()
181
- self.test_dataset.image_transform = StandardizeImageTransform()
182
- # image_normalize = generate_normalize_transform(self.train_dataset), compose?
 
 
 
 
 
 
 
 
 
 
183
 
184
  self.tex_tokenizer = generate_tex_tokenizer(self.train_dataset, vocab_size=TEX_VOCAB_SIZE)
185
  self.collate_fn = BatchCollator(self.tex_tokenizer)
186
 
 
 
 
 
 
 
 
 
187
  def train_dataloader(self):
188
- return DataLoader(self.train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=self.collate_fn)
 
189
 
190
  def val_dataloader(self):
191
- return DataLoader(self.val_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=self.collate_fn)
 
192
 
193
  def test_dataloader(self):
194
- return DataLoader(self.test_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=self.collate_fn)
 
 
1
+ from data_generator import DATA_DIR
2
 
3
  import einops
4
  import os
 
9
  import torchvision.transforms as T
10
  from torch.utils.data import Dataset, DataLoader
11
  import tqdm
12
+ import random
13
  import re
14
 
15
+ TEX_VOCAB_SIZE = 300
16
+ BATCH_SIZE = 16
17
+ IMAGE_WIDTH = 1024
18
+ IMAGE_HEIGHT = 128
19
+
20
 
21
  class TexImageDataset(Dataset):
22
  """Image and tex dataset."""
 
94
  class StandardizeImageTransform(object):
95
  """Pad and crop image to a given size, grayscale and invert"""
96
 
97
+ def __init__(self, width=IMAGE_WIDTH, height=IMAGE_HEIGHT):
98
  self.standardize = T.Compose((
99
  T.Resize(height),
100
  T.Grayscale(),
 
111
  class RandomizeImageTransform(object):
112
  """Standardize image and randomly augment"""
113
 
114
+ def __init__(self, width=IMAGE_WIDTH, height=IMAGE_HEIGHT, random_magnitude=5):
115
  self.transform = T.Compose((
116
  T.ColorJitter(brightness=random_magnitude / 10),
117
  T.Resize(height),
 
143
  return equation
144
 
145
 
146
+ def generate_tex_tokenizer(dataset, vocab_size):
147
  """Returns a tokenizer trained on texs from given dataset"""
148
 
149
+ texs = list(tqdm.tqdm((item['tex'] for item in dataset), "Training tokenizer", total=len(dataset)))
150
 
151
  os.environ['TOKENIZERS_PARALLELISM'] = 'false'
152
  tokenizer = tokenizers.Tokenizer(tokenizers.models.BPE(unk_token="[UNK]"))
 
169
 
170
 
171
  class LatexImageDataModule(pl.LightningDataModule):
172
+ def __init__(self):
173
+ super().__init__()
174
+ torch.manual_seed(0)
 
 
 
 
175
 
176
+ self.train_dataset = TexImageDataset(
177
+ root_dir=DATA_DIR,
178
+ image_transform=RandomizeImageTransform(),
179
+ tex_transform=ExtractEquationFromTexTransform()
180
  )
181
+ self.val_dataset = TexImageDataset(
182
+ root_dir=DATA_DIR,
183
+ image_transform=StandardizeImageTransform(),
184
+ tex_transform=ExtractEquationFromTexTransform()
185
+ )
186
+ self.test_dataset = TexImageDataset(
187
+ root_dir=DATA_DIR,
188
+ image_transform=StandardizeImageTransform(),
189
+ tex_transform=ExtractEquationFromTexTransform()
190
+ )
191
+ train_indices, val_indices, test_indices = self.train_val_test_split(len(self.train_dataset))
192
+ self.train_dataset = torch.utils.data.Subset(self.train_dataset, train_indices)
193
+ self.val_dataset = torch.utils.data.Subset(self.val_dataset, val_indices)
194
+ self.test_dataset = torch.utils.data.Subset(self.test_dataset, test_indices)
195
 
196
  self.tex_tokenizer = generate_tex_tokenizer(self.train_dataset, vocab_size=TEX_VOCAB_SIZE)
197
  self.collate_fn = BatchCollator(self.tex_tokenizer)
198
 
199
+ @staticmethod
200
+ def train_val_test_split(size, train_fraction=.8, val_fraction=.1):
201
+ indices = list(range(size))
202
+ random.shuffle(indices)
203
+ train_split = int(size * train_fraction)
204
+ val_split = train_split + int(size * val_fraction)
205
+ return indices[:train_split], indices[train_split: val_split], indices[val_split:]
206
+
207
  def train_dataloader(self):
208
+ return DataLoader(self.train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=self.collate_fn,
209
+ num_workers=8, pin_memory=True, persistent_workers=False)
210
 
211
  def val_dataloader(self):
212
+ return DataLoader(self.val_dataset, batch_size=BATCH_SIZE, collate_fn=self.collate_fn, num_workers=8,
213
+ pin_memory=True, persistent_workers=False)
214
 
215
  def test_dataloader(self):
216
+ return DataLoader(self.test_dataset, batch_size=BATCH_SIZE, collate_fn=self.collate_fn, num_workers=8,
217
+ pin_memory=True, persistent_workers=False)
model.py CHANGED
@@ -2,7 +2,6 @@ from einops.layers.torch import Rearrange
2
  import einops
3
  import math
4
  import pytorch_lightning as pl
5
- from pytorch_lightning.utilities.types import TRAIN_DATALOADERS
6
  import torch.nn as nn
7
  import torch
8
 
@@ -101,9 +100,6 @@ class ImageEncoder(nn.Module):
101
 
102
 
103
  class Transformer(pl.LightningModule):
104
- def train_dataloader(self) -> TRAIN_DATALOADERS:
105
- pass
106
-
107
  def __init__(self,
108
  num_encoder_layers: int,
109
  num_decoder_layers: int,
@@ -139,6 +135,20 @@ class Transformer(pl.LightningModule):
139
  src_padding_mask, tgt_padding_mask)
140
  return self.generator(outs)
141
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
142
  def training_step(self, batch, batch_idx):
143
  src = batch['images']
144
  tgt = batch['tex_ids']
@@ -154,5 +164,16 @@ class Transformer(pl.LightningModule):
154
  self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
155
  return loss
156
 
 
 
 
 
 
 
 
 
 
 
157
  def configure_optimizers(self):
 
158
  return torch.optim.Adam(self.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)
 
2
  import einops
3
  import math
4
  import pytorch_lightning as pl
 
5
  import torch.nn as nn
6
  import torch
7
 
 
100
 
101
 
102
  class Transformer(pl.LightningModule):
 
 
 
103
  def __init__(self,
104
  num_encoder_layers: int,
105
  num_decoder_layers: int,
 
135
  src_padding_mask, tgt_padding_mask)
136
  return self.generator(outs)
137
 
138
+ def general_step(self, batch):
139
+ src = batch['images']
140
+ tgt = batch['tex_ids']
141
+ tgt_input = tgt[:, :-1]
142
+ tgt_output = tgt[:, 1:]
143
+ src_mask = None
144
+ tgt_mask = self.transformer.generate_square_subsequent_mask(tgt_input.shape[1]).to(self.device,
145
+ torch.ByteTensor.dtype)
146
+ src_padding_mask = None
147
+ tgt_padding_mask = batch['tex_attention_masks'][:, :-1]
148
+ outs = self(src, tgt_input, src_mask, tgt_mask, src_padding_mask, tgt_padding_mask)
149
+ loss = self.loss_fn(einops.rearrange(outs, 'b n prob -> b prob n'), tgt_output.long())
150
+ return loss
151
+
152
  def training_step(self, batch, batch_idx):
153
  src = batch['images']
154
  tgt = batch['tex_ids']
 
164
  self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
165
  return loss
166
 
167
+ def validation_step(self, batch, batch_idx):
168
+ loss = self.general_step(batch)
169
+ self.log("val_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
170
+ return loss
171
+
172
+ def test_step(self, batch, batch_idx):
173
+ loss = self.general_step(batch)
174
+ self.log("test_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
175
+ return loss
176
+
177
  def configure_optimizers(self):
178
+ # TODO write scheduler
179
  return torch.optim.Adam(self.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)
train.py CHANGED
@@ -1,5 +1,5 @@
1
  from data_generator import generate_data
2
- from data_preprocessing import LatexImageDataModule
3
  from model import Transformer
4
 
5
  import argparse
@@ -7,31 +7,26 @@ import pytorch_lightning as pl
7
  from pytorch_lightning.loggers import WandbLogger
8
  import torch
9
 
10
- DATA_DIR = 'data'
11
- LATEX_PATH = 'resources/latex.json'
12
- DATASET_PATH = 'resources/dataset'
13
- IMAGE_WIDTH = 1024
14
- IMAGE_HEIGHT = 128
15
- TEX_VOCAB_SIZE = 300
16
- BATCH_SIZE = 16
17
 
18
 
19
  def main():
20
- torch.manual_seed(0)
21
-
22
  parser = argparse.ArgumentParser("Trainer")
23
- parser.add_argument("-generate-new", help="number of new files to generate", type=int)
 
 
 
24
  args = parser.parse_args()
25
 
26
- if args.generate_new is not None:
27
- generate_data(args.generate_new)
28
  datamodule = LatexImageDataModule()
29
  torch.save(datamodule, DATASET_PATH)
30
  else:
31
  datamodule = torch.load(DATASET_PATH)
32
 
33
  wandb_logger = WandbLogger()
34
- trainer = pl.Trainer(max_epochs=2, accelerator='gpu', gpus=1, logger=wandb_logger)
35
  transformer = Transformer(
36
  num_encoder_layers=3,
37
  num_decoder_layers=3,
 
1
  from data_generator import generate_data
2
+ from data_preprocessing import LatexImageDataModule, IMAGE_WIDTH, IMAGE_HEIGHT
3
  from model import Transformer
4
 
5
  import argparse
 
7
  from pytorch_lightning.loggers import WandbLogger
8
  import torch
9
 
10
+ DATASET_PATH = 'resources/dataset.pt'
 
 
 
 
 
 
11
 
12
 
13
  def main():
 
 
14
  parser = argparse.ArgumentParser("Trainer")
15
+ parser.add_argument("-n", "-new-dataset", help="clear old dataset and generate provided number of new examples",
16
+ type=int, dest="new_dataset")
17
+ parser.add_argument("-g", "-gpus", help="list of gpu ids to train on", type=int, nargs='+', dest="gpus",
18
+ choices=list(range(torch.cuda.device_count())), default=[0])
19
  args = parser.parse_args()
20
 
21
+ if args.new_dataset is not None:
22
+ generate_data(args.new_dataset)
23
  datamodule = LatexImageDataModule()
24
  torch.save(datamodule, DATASET_PATH)
25
  else:
26
  datamodule = torch.load(DATASET_PATH)
27
 
28
  wandb_logger = WandbLogger()
29
+ trainer = pl.Trainer(max_epochs=2, accelerator='gpu', gpus=args.gpus, logger=wandb_logger, strategy='ddp_spawn')
30
  transformer = Transformer(
31
  num_encoder_layers=3,
32
  num_decoder_layers=3,