dkoshman commited on
Commit
41a34cd
1 Parent(s): e932abd

dedicated generate.py script

Browse files
Files changed (6) hide show
  1. data_generator.py +100 -135
  2. data_preprocessing.py +43 -75
  3. generate.py +23 -0
  4. model.py +6 -10
  5. train.py +46 -64
  6. utils.py +75 -48
data_generator.py CHANGED
@@ -7,109 +7,76 @@ import subprocess
7
  import random
8
  import tqdm
9
 
10
- DATA_DIR = 'data'
11
- LATEX_PATH = 'resources/latex.json'
12
-
13
-
14
- class DotDict(dict):
15
- """dot.notation access to dictionary attributes"""
16
- __getattr__ = dict.get
17
- __setattr__ = dict.__setitem__
18
- __delattr__ = dict.__delitem__
19
-
20
- def __init__(self, *args, **kwargs):
21
- super().__init__(*args, **kwargs)
22
- if len(args) > 0 and isinstance(args[0], dict):
23
- for key, value in self.items():
24
- if isinstance(value, dict):
25
- self.__setitem__(key, DotDict(value))
26
-
27
-
28
- def _generate_equation(size_left, depth_left, latex, tokens):
29
- if size_left <= 0:
30
- return ""
31
-
32
- equation = ""
33
- pairs, scopes, special = latex.pairs, latex.scopes, latex.special
34
- weights = [3, depth_left > 0, depth_left > 0]
35
- group, = random.choices([tokens, pairs, scopes], weights=weights)
36
-
37
- if group is tokens:
38
- equation += ' '.join([
39
- random.choice(tokens),
40
- _generate_equation(size_left - 1, depth_left, latex, tokens)
41
- ])
42
- return equation
43
-
44
- post_scope_size = round(abs(random.gauss(0, size_left / 2)))
45
- size_left -= post_scope_size + 1
46
-
47
- if group is pairs:
48
- pair = random.choice(pairs)
49
- equation += ' '.join([
50
- pair[0],
51
- _generate_equation(size_left, depth_left - 1, latex, tokens),
52
- pair[1],
53
- _generate_equation(post_scope_size, depth_left, latex, tokens)
54
- ])
55
- return equation
56
 
57
- elif group is scopes:
58
- scope_type, scope_group = random.choice(list(scopes.items()))
59
- scope_operator = random.choice(scope_group)
60
- equation += scope_operator
61
 
62
- if scope_type == 'single':
63
- equation += ' '.join([
64
- special.left_bracket,
65
- _generate_equation(size_left, depth_left - 1, latex, tokens)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
  ])
67
-
68
- elif scope_type == 'double_no_delimiters':
69
- equation += ' '.join([
70
- special.left_bracket,
71
- _generate_equation(size_left // 2, depth_left - 1, latex, tokens),
72
- special.right_bracket + special.left_bracket,
73
- _generate_equation(size_left // 2, depth_left - 1, latex, tokens)
 
 
 
 
 
74
  ])
 
75
 
76
- elif scope_type == 'double_with_delimiters':
77
- equation += ' '.join([
78
- special.caret,
79
- special.left_bracket,
80
- _generate_equation(size_left // 2, depth_left - 1, latex, tokens),
81
- special.right_bracket,
82
- special.underscore,
83
- special.left_bracket,
84
- _generate_equation(size_left // 2, depth_left - 1, latex, tokens)
85
- ])
86
 
87
- equation += ' '.join([
88
- special.right_bracket,
89
- _generate_equation(post_scope_size, depth_left, latex, tokens)
90
- ])
91
- return equation
92
 
 
 
 
93
 
94
- def generate_equation(latex: DotDict, size, depth=3):
95
- """
96
- Generates a random latex equation
97
- -------
98
- params:
99
- :latex: -- dict with tokens to generate equation from
100
- :size: -- approximate size of equation
101
- :depth: -- max brackets and scope depth
102
- """
103
- tokens = [token for group in ['chars', 'greek', 'functions', 'operators', 'spaces']
104
- for token in latex[group]]
105
- equation = _generate_equation(size, depth, latex, tokens)
106
- return equation
107
 
 
108
 
109
- def generate_image(directory: str, latex: dict, filename: str, max_length=20, equation_depth=3,
110
- pdflatex: str = "/external2/dkkoshman/venv/texlive/2022/bin/x86_64-linux/pdflatex",
111
- ghostscript: str = "/external2/dkkoshman/venv/local/gs/bin/gs"
112
- ):
113
  """
114
  Generates a random tex file and corresponding image
115
  -------
@@ -117,41 +84,47 @@ def generate_image(directory: str, latex: dict, filename: str, max_length=20, eq
117
  :directory: -- dir where to save files
118
  :latex: -- dict with parameters to generate tex
119
  :filename: -- absolute filename for the generated files
120
- :max_length: -- max size of equation
121
- :equation_depth: -- max nested level of tex scopes
122
- :pdflatex: -- path to pdflatex
123
- :ghostscript: -- path to ghostscript
124
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
125
  filepath = os.path.join(directory, filename)
126
- equation_length = random.randint(max_length // 2, max_length)
127
- latex = DotDict(latex)
128
- template = string.Template(latex.template)
129
- font, font_options = random.choice(latex.fonts)
130
- font_option = random.choice([''] + font_options)
131
- fontsize = random.choice(latex.fontsizes)
132
- equation = generate_equation(latex, equation_length, depth=equation_depth)
133
- tex = template.substitute(font=font, font_option=font_option, fontsize=fontsize, equation=equation)
134
-
135
- with open(f"{filepath}.tex", mode='w') as file:
136
  file.write(tex)
137
 
138
  try:
139
  pdflatex_process = subprocess.run(
140
- f"{pdflatex} -output-directory={directory} {filepath}.tex".split(),
141
  stderr=subprocess.DEVNULL,
142
  stdout=subprocess.DEVNULL,
143
  timeout=1
144
  )
145
  except subprocess.TimeoutExpired:
146
- subprocess.run(f'rm {filepath}.tex'.split())
147
  return
148
 
149
  if pdflatex_process.returncode != 0:
150
- subprocess.run(f'rm {filepath}.tex'.split())
151
  return
152
 
153
  subprocess.run(
154
- f"{ghostscript} -sDEVICE=png16m -dTextAlphaBits=4 -r200 -dSAFER -dBATCH -dNOPAUSE -o {filepath}.png {filepath}.pdf".split(),
 
155
  stderr=subprocess.DEVNULL,
156
  stdout=subprocess.DEVNULL,
157
  )
@@ -161,41 +134,33 @@ def _generate_image_wrapper(args):
161
  return generate_image(*args)
162
 
163
 
164
- def generate_data(examples_count) -> None:
165
  """
166
  Clears a directory and generates a latex dataset in given directory
167
- -------
168
- params:
169
- :examples_count: - how many latex - image examples to generate
170
  """
171
 
172
- filenames = set(f"{i:0{len(str(examples_count - 1))}d}" for i in range(examples_count))
173
  directory = os.path.abspath(DATA_DIR)
174
- latex_path = os.path.abspath(LATEX_PATH)
175
- with open(latex_path) as file:
176
- latex = json.load(file)
177
 
178
- shutil.rmtree(directory)
179
- os.mkdir(directory)
180
 
181
- def _get_current_relevant_files():
182
- return set(os.path.join(directory, file) for file in os.listdir(directory)) | set(
183
- os.path.abspath(file) for file in os.listdir(os.getcwd()))
184
 
185
- files_before = _get_current_relevant_files()
186
  while filenames:
187
  with Pool() as pool:
188
  list(tqdm.tqdm(
189
- pool.imap(_generate_image_wrapper, ((directory, latex, filename) for filename in sorted(filenames))),
 
 
190
  "Generating images",
191
  total=len(filenames)
192
  ))
193
- existing = set(os.path.splitext(filename)[0] for filename in os.listdir(directory) if filename.endswith('.png'))
194
- filenames -= existing
195
-
196
- files_after = _get_current_relevant_files()
197
- files_to_delete = files_after - files_before
198
- files_to_delete = list(os.path.join(directory, file) for file in files_to_delete if
199
- not file.endswith('.png') and not file.endswith('.tex'))
200
- if files_to_delete:
201
- subprocess.run(['rm'] + files_to_delete)
 
7
  import random
8
  import tqdm
9
 
10
+ DATA_DIR = "data"
11
+ LATEX_PATH = "resources/latex.json"
12
+ PDFLATEX = "/external2/dkkoshman/venv/texlive/2022/bin/x86_64-linux/pdflatex"
13
+ GHOSTSCRIPT = "/external2/dkkoshman/venv/local/gs/bin/gs"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
 
 
 
 
15
 
16
def generate_equation(latex, size, max_depth):
    """
    Generates a random latex equation
    -------
    params:
    :latex: -- dict with tokens to generate equation from
    :size: -- approximate size of equation
    :max_depth: -- max brackets and scope depth
    """

    tokens, pairs, scopes = latex["tokens"], latex["pairs"], latex["scope_manipulators"]

    def _generate_equation_recursive(size_left=size, depth_used=0):
        if size_left <= 0:
            return ""

        equation = ""
        # Plain tokens are always allowed; pairs/scopes only while the
        # nesting budget (max_depth) has not been used up.
        group, = random.choices([tokens, pairs, scopes],
                                weights=[max_depth + 1, max_depth > depth_used, max_depth > depth_used])

        if group is tokens:
            equation += " ".join([
                random.choice(tokens),
                _generate_equation_recursive(size_left - 1, depth_used)
            ])
            return equation

        # Reserve a random part of the remaining budget for content that
        # follows the pair/scope at the current depth.
        post_scope_size = round(abs(random.gauss(0, size_left / 2)))
        size_left -= post_scope_size + 1

        if group is pairs:
            pair = random.choice(pairs)
            equation += " ".join([
                pair[0],
                _generate_equation_recursive(size_left, depth_used + 1),
                pair[1],
                _generate_equation_recursive(post_scope_size, depth_used)
            ])
            return equation

        elif group is scopes:
            scope_type, scope_group = random.choice(list(scopes.items()))
            scope_operator = random.choice(scope_group)
            equation += scope_operator

            if scope_type == "single":
                equation += "{ " + _generate_equation_recursive(size_left, depth_used + 1)

            elif scope_type == "double_no_delimiters":
                equation += "{ " + _generate_equation_recursive(size_left // 2, depth_used + 1) + " } { " + \
                            _generate_equation_recursive(size_left // 2, depth_used + 1)

            elif scope_type == "double_with_delimiters":
                equation += "^ { " + _generate_equation_recursive(size_left // 2, depth_used + 1) + " } _ { " + \
                            _generate_equation_recursive(size_left // 2, depth_used + 1)

            # BUG FIX: close the still-open scope group *before* generating
            # the post-scope content; previously the post-scope equation was
            # emitted inside the last brace group, contradicting both the
            # name `post_scope_size` and the pairs branch above.
            equation += " } " + _generate_equation_recursive(post_scope_size, depth_used)

        return equation

    return _generate_equation_recursive()
77
+
78
+
79
def generate_image(directory, latex, filename, max_depth, equation_length, distribution_fraction):
    """
    Generates a random tex file and corresponding image
    -------
    params:
    :directory: -- dir where to save files
    :latex: -- dict with parameters to generate tex
    :filename: -- absolute filename for the generated files
    :max_depth: -- max nested level of tex scopes
    :equation_length: -- max length of equation
    :distribution_fraction: -- fraction of whole available tex tokens to use

    NOTE(review): this mutates `latex` in place while truncating token groups.
    Under the multiprocessing Pool each task receives its own pickled copy,
    but confirm no caller reuses the dict across direct calls.
    """

    def fracture(sequence):
        # Keep at least one element so generation never runs out of choices.
        return sequence[:max(1, int(len(sequence) * distribution_fraction))]

    for group in ["tokens", "pairs", "fonts", "font_sizes"]:
        latex[group] = fracture(latex[group])
    for key, value in list(latex["scope_manipulators"].items()):
        # BUG FIX: was latex["scope_manipulators"]['key'] = ..., which stored
        # everything under the literal string 'key' and left every
        # manipulator group unfractured.
        latex["scope_manipulators"][key] = fracture(value)

    size = random.randint((equation_length + 1) // 2, equation_length)
    equation = generate_equation(latex, size=size, max_depth=max_depth)

    font, font_options = random.choice(latex["fonts"])
    font_option = random.choice([""] + font_options)
    font_size = random.choice(latex["font_sizes"])
    template = string.Template(latex["template"])
    tex = template.substitute(font=font, font_option=font_option, fontsize=font_size, equation=equation)

    filepath = os.path.join(directory, filename)
    with open(f"{filepath}.tex", mode="w") as file:
        file.write(tex)

    try:
        pdflatex_process = subprocess.run(
            f"{PDFLATEX} -output-directory={directory} {filepath}.tex".split(),
            stderr=subprocess.DEVNULL,
            stdout=subprocess.DEVNULL,
            timeout=1
        )
    except subprocess.TimeoutExpired:
        # pdflatex hung (e.g. waiting for input after an error): drop the tex
        # so no half-finished example survives.
        os.remove(filepath + ".tex")
        return

    if pdflatex_process.returncode != 0:
        # Compilation failed: remove the tex so tex/png pairs stay consistent.
        os.remove(filepath + ".tex")
        return

    subprocess.run(
        f"{GHOSTSCRIPT} -sDEVICE=png16m -dTextAlphaBits=4 -r200 -dSAFER -dBATCH -dNOPAUSE"
        f" -o {filepath}.png {filepath}.pdf".split(),
        stderr=subprocess.DEVNULL,
        stdout=subprocess.DEVNULL,
    )
 
134
  return generate_image(*args)
135
 
136
 
137
def generate_data(examples_count, max_depth, equation_length, distribution_fraction) -> None:
    """
    Clears a directory and generates a latex dataset in given directory
    -------
    params:
    :examples_count: -- how many latex-image examples to generate
    :max_depth: -- max nested level of tex scopes per equation
    :equation_length: -- max length of each generated equation
    :distribution_fraction: -- fraction of the tex vocabulary to sample from

    :raises RuntimeError: when a full pass produces no new images (broken
        pdflatex/ghostscript setup would otherwise loop forever)
    """

    directory = os.path.abspath(DATA_DIR)
    # ignore_errors: on the first run the data dir does not exist yet;
    # either way we recreate it empty below.
    shutil.rmtree(directory, ignore_errors=True)
    os.mkdir(directory)

    with open(LATEX_PATH) as file:
        latex = json.load(file)

    # Zero-padded names so files sort in generation order.
    filenames = set(f"{i:0{len(str(examples_count - 1))}d}" for i in range(examples_count))
    files_before = set(os.listdir())

    while filenames:
        with Pool() as pool:
            list(tqdm.tqdm(
                pool.imap(_generate_image_wrapper,
                          ((directory, latex, filename, max_depth, equation_length, distribution_fraction)
                           for filename in sorted(filenames))),
                "Generating images",
                total=len(filenames)
            ))
        generated = set(
            os.path.splitext(filename)[0] for filename in os.listdir(directory) if filename.endswith(".png"))
        remaining = filenames - generated
        # Retry failed examples, but bail out instead of spinning forever if
        # an entire pass made no progress (e.g. pdflatex missing).
        if remaining == filenames:
            raise RuntimeError(f"No progress generating images; {len(remaining)} examples keep failing")
        filenames = remaining

    # Remove build artifacts left in the data dir and any new ones in cwd.
    for file in set(i.path for i in os.scandir(directory)) | set(os.listdir()) - files_before:
        if any(file.endswith(ext) for ext in [".aux", ".pdf", ".log", ".sh"]):
            os.remove(file)
 
 
 
data_preprocessing.py CHANGED
@@ -12,10 +12,7 @@ import tqdm
12
  import random
13
  import re
14
 
15
- TEX_VOCAB_SIZE = 300
16
- IMAGE_WIDTH = 1024
17
- IMAGE_HEIGHT = 128
18
- BATCH_SIZE = 16
19
  NUM_WORKERS = 4
20
  PERSISTENT_WORKERS = True # whether to shut down workers at the end of epoch
21
  PIN_MEMORY = False # probably causes cuda oom error if True
@@ -60,22 +57,6 @@ class TexImageDataset(Dataset):
60
  return {"image": image, "tex": tex}
61
 
62
 
63
- def generate_normalize_transform(dataset: TexImageDataset):
64
- """Returns a normalize layer with mean and std computed after iterating over dataset"""
65
-
66
- mean = 0
67
- std = 0
68
- for item in tqdm.tqdm(dataset, "Computing dataset image stats"):
69
- image = item['image']
70
- mean += image.mean()
71
- std += image.std()
72
-
73
- mean /= len(dataset)
74
- std /= len(dataset)
75
- normalize = T.Normalize(mean, std)
76
- return normalize
77
-
78
-
79
  class BatchCollator(object):
80
  """Image, tex batch collator"""
81
 
@@ -94,39 +75,30 @@ class BatchCollator(object):
94
  return {'images': images, 'tex_ids': tex_ids, 'tex_attention_masks': attention_masks}
95
 
96
 
97
- class StandardizeImageTransform(object):
98
- """Pad and crop image to a given size, grayscale and invert"""
99
-
100
- def __init__(self, width=IMAGE_WIDTH, height=IMAGE_HEIGHT):
101
- self.standardize = T.Compose((
102
- T.Resize(height),
103
- T.Grayscale(),
104
- T.functional.invert,
105
- T.CenterCrop((height, width)),
106
- T.ConvertImageDtype(torch.float32)
107
- ))
108
-
109
- def __call__(self, image):
110
- image = self.standardize(image)
111
- return image
112
-
113
-
114
  class RandomizeImageTransform(object):
115
  """Standardize image and randomly augment"""
116
 
117
- def __init__(self, width=IMAGE_WIDTH, height=IMAGE_HEIGHT, random_magnitude=5):
118
- assert random_magnitude > 0
119
- eps = 0.01
120
- self.transform = T.Compose((
121
- T.ColorJitter(brightness=((1 - eps) / (random_magnitude + eps), 1 - eps)),
122
- T.Resize(height),
123
- T.Grayscale(),
124
- T.functional.invert,
125
- T.CenterCrop((height, width)),
126
- torch.Tensor.contiguous,
127
- T.RandAugment(magnitude=random_magnitude),
128
- T.ConvertImageDtype(torch.float32)
129
- ))
 
 
 
 
 
 
 
 
130
 
131
  def __call__(self, image):
132
  image = self.transform(image)
@@ -148,7 +120,7 @@ class ExtractEquationFromTexTransform(object):
148
  return equation
149
 
150
 
151
- def generate_tex_tokenizer(dataloader, vocab_size):
152
  """Returns a tokenizer trained on texs from given dataset"""
153
 
154
  texs = list(tqdm.tqdm((batch['tex'] for batch in dataloader), "Training tokenizer", total=len(dataloader)))
@@ -156,7 +128,6 @@ def generate_tex_tokenizer(dataloader, vocab_size):
156
  os.environ['TOKENIZERS_PARALLELISM'] = 'false'
157
  tokenizer = tokenizers.Tokenizer(tokenizers.models.BPE(unk_token="[UNK]"))
158
  tokenizer_trainer = tokenizers.trainers.BpeTrainer(
159
- vocab_size=vocab_size,
160
  special_tokens=["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"]
161
  )
162
  tokenizer.pre_tokenizer = tokenizers.pre_tokenizers.Whitespace()
@@ -174,34 +145,30 @@ def generate_tex_tokenizer(dataloader, vocab_size):
174
 
175
 
176
  class LatexImageDataModule(pl.LightningDataModule):
177
- def __init__(self, batch_size=BATCH_SIZE):
178
  super().__init__()
179
- torch.manual_seed(0)
180
- self.batch_size = batch_size
 
 
 
 
181
 
182
- self.train_dataset = TexImageDataset(
183
- root_dir=DATA_DIR,
184
- image_transform=RandomizeImageTransform(),
185
- tex_transform=ExtractEquationFromTexTransform()
186
- )
187
- self.val_dataset = TexImageDataset(
188
- root_dir=DATA_DIR,
189
- image_transform=RandomizeImageTransform(),
190
- tex_transform=ExtractEquationFromTexTransform()
191
- )
192
- self.test_dataset = TexImageDataset(
193
- root_dir=DATA_DIR,
194
- image_transform=RandomizeImageTransform(),
195
- tex_transform=ExtractEquationFromTexTransform()
196
- )
197
  train_indices, val_indices, test_indices = self.train_val_test_split(len(self.train_dataset))
198
  self.train_dataset = torch.utils.data.Subset(self.train_dataset, train_indices)
199
  self.val_dataset = torch.utils.data.Subset(self.val_dataset, val_indices)
200
  self.test_dataset = torch.utils.data.Subset(self.test_dataset, test_indices)
201
 
202
- self.tex_tokenizer = generate_tex_tokenizer(
203
- DataLoader(self.train_dataset, batch_size=32, num_workers=16),
204
- vocab_size=TEX_VOCAB_SIZE)
 
 
 
 
 
 
 
205
  self.collate_fn = BatchCollator(self.tex_tokenizer)
206
 
207
  @staticmethod
@@ -213,8 +180,9 @@ class LatexImageDataModule(pl.LightningDataModule):
213
  return indices[:train_split], indices[train_split: val_split], indices[val_split:]
214
 
215
  def train_dataloader(self):
216
- return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, collate_fn=self.collate_fn,
217
- pin_memory=PIN_MEMORY, num_workers=NUM_WORKERS, persistent_workers=PERSISTENT_WORKERS)
 
218
 
219
  def val_dataloader(self):
220
  return DataLoader(self.val_dataset, batch_size=self.batch_size, collate_fn=self.collate_fn,
 
12
  import random
13
  import re
14
 
15
+ TOKENIZER_PATH = "resources/tokenizer.pt"
 
 
 
16
  NUM_WORKERS = 4
17
  PERSISTENT_WORKERS = True # whether to shut down workers at the end of epoch
18
  PIN_MEMORY = False # probably causes cuda oom error if True
 
57
  return {"image": image, "tex": tex}
58
 
59
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
  class BatchCollator(object):
61
  """Image, tex batch collator"""
62
 
 
75
  return {'images': images, 'tex_ids': tex_ids, 'tex_attention_masks': attention_masks}
76
 
77
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
  class RandomizeImageTransform(object):
79
  """Standardize image and randomly augment"""
80
 
81
+ def __init__(self, width, height, random_magnitude):
82
+ if random_magnitude > 0:
83
+ self.transform = T.Compose((
84
+ T.ColorJitter(brightness=random_magnitude / 10, contrast=random_magnitude / 10,
85
+ saturation=random_magnitude / 10, hue=min(0.5, random_magnitude / 10)),
86
+ T.Resize(height),
87
+ T.Grayscale(),
88
+ T.functional.invert,
89
+ T.CenterCrop((height, width)),
90
+ torch.Tensor.contiguous,
91
+ T.RandAugment(magnitude=random_magnitude),
92
+ T.ConvertImageDtype(torch.float32)
93
+ ))
94
+ else:
95
+ self.transform = T.Compose((
96
+ T.Resize(height),
97
+ T.Grayscale(),
98
+ T.functional.invert,
99
+ T.CenterCrop((height, width)),
100
+ T.ConvertImageDtype(torch.float32)
101
+ ))
102
 
103
  def __call__(self, image):
104
  image = self.transform(image)
 
120
  return equation
121
 
122
 
123
+ def generate_tex_tokenizer(dataloader):
124
  """Returns a tokenizer trained on texs from given dataset"""
125
 
126
  texs = list(tqdm.tqdm((batch['tex'] for batch in dataloader), "Training tokenizer", total=len(dataloader)))
 
128
  os.environ['TOKENIZERS_PARALLELISM'] = 'false'
129
  tokenizer = tokenizers.Tokenizer(tokenizers.models.BPE(unk_token="[UNK]"))
130
  tokenizer_trainer = tokenizers.trainers.BpeTrainer(
 
131
  special_tokens=["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"]
132
  )
133
  tokenizer.pre_tokenizer = tokenizers.pre_tokenizers.Whitespace()
 
145
 
146
 
147
  class LatexImageDataModule(pl.LightningDataModule):
148
    def __init__(self, image_width, image_height, batch_size, random_magnitude):
        """Build train/val/test datasets over DATA_DIR with shared transforms.

        :image_width: -- target image width after crop
        :image_height: -- target image height after crop
        :batch_size: -- batch size used by all dataloaders
        :random_magnitude: -- augmentation strength passed to RandomizeImageTransform
        """
        super().__init__()
        image_transform = RandomizeImageTransform(image_width, image_height, random_magnitude)
        tex_transform = ExtractEquationFromTexTransform()

        # Three separate dataset objects over the same directory; they are
        # reduced to disjoint Subsets below, so no example is shared.
        self.train_dataset = TexImageDataset(DATA_DIR, image_transform, tex_transform)
        self.val_dataset = TexImageDataset(DATA_DIR, image_transform, tex_transform)
        self.test_dataset = TexImageDataset(DATA_DIR, image_transform, tex_transform)

        train_indices, val_indices, test_indices = self.train_val_test_split(len(self.train_dataset))
        self.train_dataset = torch.utils.data.Subset(self.train_dataset, train_indices)
        self.val_dataset = torch.utils.data.Subset(self.val_dataset, val_indices)
        self.test_dataset = torch.utils.data.Subset(self.test_dataset, test_indices)

        self.batch_size = batch_size
        # Records the init arguments into self.hparams (used by train.py).
        self.save_hyperparameters()
164
+
165
    def prepare_data(self):
        """Train a BPE tokenizer on the train split and persist it to TOKENIZER_PATH.

        The saved file is the hand-off point to setup(), which reloads it
        (Lightning's contract runs prepare_data before setup).
        """
        tokenizer = generate_tex_tokenizer(DataLoader(self.train_dataset, batch_size=32, num_workers=16))
        print(f"Vocabulary size: {tokenizer.get_vocab_size()}")
        # NOTE(review): torch.save pickles the tokenizer object -- confirm it
        # round-trips across tokenizers library versions.
        torch.save(tokenizer, TOKENIZER_PATH)
169
+
170
    def setup(self, stage=None):
        """Load the tokenizer written by prepare_data and build the batch collator.

        :stage: -- Lightning stage hint; unused because every stage needs
            the same tokenizer and collate_fn
        """
        self.tex_tokenizer = torch.load(TOKENIZER_PATH)
        self.collate_fn = BatchCollator(self.tex_tokenizer)
173
 
174
  @staticmethod
 
180
  return indices[:train_split], indices[train_split: val_split], indices[val_split:]
181
 
182
  def train_dataloader(self):
183
+ return DataLoader(self.train_dataset, batch_size=self.batch_size, collate_fn=self.collate_fn,
184
+ pin_memory=PIN_MEMORY, num_workers=NUM_WORKERS, persistent_workers=PERSISTENT_WORKERS,
185
+ shuffle=True)
186
 
187
  def val_dataloader(self):
188
  return DataLoader(self.val_dataset, batch_size=self.batch_size, collate_fn=self.collate_fn,
generate.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from data_generator import generate_data
2
+
3
+ import argparse
4
+
5
+
6
def parse_args():
    """Parse the four positional options controlling dataset generation."""
    arg_parser = argparse.ArgumentParser(description="Clear old dataset and generate new one")
    # (name, help text, type) for each positional argument, in CLI order.
    positionals = [
        ("size", "size of new dataset", int),
        ("depth", "max_depth scope depth of generated equation, no less than 1", int),
        ("length", "length of equation will be in range length/2..length", int),
        ("fraction", "fraction of tex vocab to sample tokens from, float in range 0..1", float),
    ]
    for arg_name, help_text, value_type in positionals:
        arg_parser.add_argument(arg_name, help=help_text, type=value_type)
    return arg_parser.parse_args()
+
15
+
16
def main():
    """Entry point: read CLI options and regenerate the dataset."""
    options = parse_args()
    generate_data(
        examples_count=options.size,
        max_depth=options.depth,
        equation_length=options.length,
        distribution_fraction=options.fraction,
    )
20
+
21
+
22
+ if __name__ == "__main__":
23
+ main()
model.py CHANGED
@@ -11,10 +11,6 @@ class AddPositionalEncoding(nn.Module):
11
  def __init__(self, d_model, max_sequence_len=5000):
12
  super().__init__()
13
 
14
- # pos - position in sequence, i - index of element embedding
15
- # PE(pos, 2i) = sin(pos / 10000**(2i / d_model)) = sin(pos * e**(2i * (-log(10000))/d_model))
16
- # PE(pos, 2i+1) = cos(pos / 10000**(2i / d_model)) = cos(pos * e**(2i * (-log(10000))/d_model))
17
-
18
  positions = torch.arange(max_sequence_len)
19
  even_embedding_indices = torch.arange(0, d_model, 2)
20
 
@@ -103,7 +99,7 @@ class Transformer(pl.LightningModule):
103
  def __init__(self,
104
  num_encoder_layers: int,
105
  num_decoder_layers: int,
106
- emb_size: int,
107
  nhead: int,
108
  image_width: int,
109
  image_height: int,
@@ -114,7 +110,7 @@ class Transformer(pl.LightningModule):
114
  ):
115
  super().__init__()
116
 
117
- self.transformer = nn.Transformer(d_model=emb_size,
118
  nhead=nhead,
119
  num_encoder_layers=num_encoder_layers,
120
  num_decoder_layers=num_decoder_layers,
@@ -125,10 +121,10 @@ class Transformer(pl.LightningModule):
125
  if p.dim() > 1:
126
  nn.init.xavier_uniform_(p)
127
 
128
- self.d_model = emb_size
129
- self.src_tok_emb = ImageEmbedding(emb_size, image_width, image_height, patch_size=16, dropout=dropout)
130
- self.tgt_tok_emb = TexEmbedding(emb_size, tgt_vocab_size, dropout=dropout)
131
- self.generator = nn.Linear(emb_size, tgt_vocab_size)
132
  # Make embedding and generator share weight because they do the same thing
133
  self.tgt_tok_emb.embedding.weight = self.generator.weight
134
  self.loss_fn = torch.nn.CrossEntropyLoss(ignore_index=pad_idx, label_smoothing=.1)
 
11
  def __init__(self, d_model, max_sequence_len=5000):
12
  super().__init__()
13
 
 
 
 
 
14
  positions = torch.arange(max_sequence_len)
15
  even_embedding_indices = torch.arange(0, d_model, 2)
16
 
 
99
  def __init__(self,
100
  num_encoder_layers: int,
101
  num_decoder_layers: int,
102
+ d_model: int,
103
  nhead: int,
104
  image_width: int,
105
  image_height: int,
 
110
  ):
111
  super().__init__()
112
 
113
+ self.transformer = nn.Transformer(d_model=d_model,
114
  nhead=nhead,
115
  num_encoder_layers=num_encoder_layers,
116
  num_decoder_layers=num_decoder_layers,
 
121
  if p.dim() > 1:
122
  nn.init.xavier_uniform_(p)
123
 
124
+ self.d_model = d_model
125
+ self.src_tok_emb = ImageEmbedding(d_model, image_width, image_height, patch_size=16, dropout=dropout)
126
+ self.tgt_tok_emb = TexEmbedding(d_model, tgt_vocab_size, dropout=dropout)
127
+ self.generator = nn.Linear(d_model, tgt_vocab_size)
128
  # Make embedding and generator share weight because they do the same thing
129
  self.tgt_tok_emb.embedding.weight = self.generator.weight
130
  self.loss_fn = torch.nn.CrossEntropyLoss(ignore_index=pad_idx, label_smoothing=.1)
train.py CHANGED
@@ -1,73 +1,57 @@
1
- from data_generator import generate_data
2
- from data_preprocessing import LatexImageDataModule, IMAGE_WIDTH, IMAGE_HEIGHT
3
  from model import Transformer
4
  from utils import LogImageTexCallback
5
 
6
  import argparse
 
7
  from pytorch_lightning.callbacks import LearningRateMonitor
8
  from pytorch_lightning.loggers import WandbLogger
9
- from pytorch_lightning import Trainer, seed_everything
10
  import torch
11
 
12
- DATASET_PATH = "resources/dataset.pt"
13
  TRAINER_DIR = "resources/pl_trainer_checkpoints"
14
- TUNER_DIR = "resources/pl_tuner_checkpoints"
15
- BEST_MODEL_CHECKPOINT = "best_model.ckpt"
16
 
17
 
 
 
 
 
18
  def parse_args():
19
- parser = argparse.ArgumentParser()
20
- parser.add_argument(
21
- "-m", "-max-epochs", help="limit the number of training epochs", type=int, dest="max_epochs"
22
- )
23
- parser.add_argument(
24
- "-n", "-new-dataset", help="clear old dataset and generate provided number of new examples", type=int,
25
- dest="new_dataset"
26
- )
27
- parser.add_argument(
28
- "-g", "-gpus", metavar="GPUS", help="ids of gpus to train on, if not provided then trains on cpu", nargs="+",
29
- type=int, dest="gpus", choices=list(range(torch.cuda.device_count())),
30
- )
31
- parser.add_argument(
32
- "-l", "-log", help="whether to save logs of run to w&b logger, default False", default=False,
33
- action="store_true", dest="log"
34
- )
35
- parser.add_argument(
36
- "-d", "-deterministic", help="whether to seed all rngs for reproducibility, default False", default=False,
37
- action="store_true", dest="deterministic"
38
- )
39
- # parser.add_argument(
40
- # "-t", "-tune", help="whether to tune model for batch size before training, default False", default=False,
41
- # action="store_true", dest="tune"
42
- # )
43
 
44
  args = parser.parse_args()
 
 
 
45
  return args
46
 
47
 
48
- # TODO: update python, maybe model doesnt train bc of ignore special index in CrossEntropyLoss?
49
- # crop image, adjust brightness, lr warmup?, make tex tokens always decodable,
50
- # take loss that doesn't punish so much for offsets, take a look at weights,
51
-
52
  def main():
53
  args = parse_args()
54
-
55
- if args.deterministic:
56
- seed_everything(42, workers=True)
57
-
58
- if args.new_dataset is not None:
59
- generate_data(args.new_dataset)
60
- datamodule = LatexImageDataModule()
61
- torch.save(datamodule, DATASET_PATH)
62
- else:
63
- datamodule = torch.load(DATASET_PATH)
64
-
65
  if args.log:
66
  logger = WandbLogger(f"img2tex", log_model=True)
67
- callbacks = [
68
- LogImageTexCallback(logger, datamodule.tex_tokenizer),
69
- LearningRateMonitor(logging_interval='step')
70
- ]
71
  else:
72
  logger = None
73
  callbacks = []
@@ -79,24 +63,22 @@ def main():
79
  strategy="ddp",
80
  enable_progress_bar=True,
81
  default_root_dir=TRAINER_DIR,
82
- callbacks=callbacks,
83
- )
84
-
85
- transformer = Transformer(num_encoder_layers=3,
86
- num_decoder_layers=3,
87
- emb_size=512,
88
- nhead=8,
89
- image_width=IMAGE_WIDTH,
90
- image_height=IMAGE_HEIGHT,
 
91
  tgt_vocab_size=datamodule.tex_tokenizer.get_vocab_size(),
92
- pad_idx=datamodule.tex_tokenizer.token_to_id("[PAD]"),
93
- dim_feedforward=512,
94
- dropout=0.1,
95
- )
96
 
97
  trainer.fit(transformer, datamodule=datamodule)
98
- trainer.test(datamodule=datamodule)
99
- trainer.save_checkpoint(BEST_MODEL_CHECKPOINT)
100
 
101
 
102
  if __name__ == "__main__":
 
1
+ from data_preprocessing import LatexImageDataModule
 
2
  from model import Transformer
3
  from utils import LogImageTexCallback
4
 
5
  import argparse
6
+ import os
7
  from pytorch_lightning.callbacks import LearningRateMonitor
8
  from pytorch_lightning.loggers import WandbLogger
9
+ from pytorch_lightning import Trainer
10
  import torch
11
 
 
12
  TRAINER_DIR = "resources/pl_trainer_checkpoints"
 
 
13
 
14
 
15
+ # TODO: update python, maybe model doesnt train bc of ignore special index in CrossEntropyLoss?
16
+ # crop image, adjust brightness, make tex tokens always decodable,
17
+ # save only datamodule state?, ensemble last checkpoints, early stopping
18
+
19
def parse_args():
    """
    Parse training command-line options.

    Returns an argparse.Namespace in which `transformer_args` has been
    turned into a dict of transformer init kwargs; values supplied via -t
    override the defaults positionally and are coerced to the default's type.
    """
    parser = argparse.ArgumentParser(allow_abbrev=True, formatter_class=argparse.RawTextHelpFormatter)

    parser.add_argument("-m", "-max-epochs", help="limit the number of training epochs", type=int, dest="max_epochs")
    parser.add_argument("-g", "-gpus", metavar="GPUS", type=int, choices=list(range(torch.cuda.device_count())),
                        help="ids of gpus to train on, if not provided, then trains on cpu", nargs="+", dest="gpus")
    parser.add_argument("-l", "-log", help="whether to save logs of run to w&b logger, default False", default=False,
                        action="store_true", dest="log")
    parser.add_argument("-width", help="width of images, default 1024", default=1024, type=int)
    parser.add_argument("-height", help="height of images, default 128", default=128, type=int)
    parser.add_argument("-r", "-randomize", default=5, type=int, dest="random_magnitude", choices=range(10),
                        help="add random augments to images of provided magnitude in range 0..9, default 5")
    parser.add_argument("-b", "-batch-size", help="batch size, default 16", default=16,
                        type=int, dest="batch_size")
    transformer_args = [("num_encoder_layers", 6), ("num_decoder_layers", 6), ("d_model", 512), ("nhead", 8),
                        ("dim_feedforward", 2048), ("dropout", 0.1)]
    parser.add_argument("-t", "-transformer-args", dest="transformer_args", nargs='+', default=[],
                        help="transformer init args:\n" + "\n".join(f"{k}\t{v}" for k, v in transformer_args))

    args = parser.parse_args()
    # BUG FIX: the overrides arrive as strings and the (name, default) pairs
    # are tuples, so `transformer_args[i][1] = parameter` raised TypeError.
    # Rebuild each pair, converting the string to the default's type.
    for i, parameter in enumerate(args.transformer_args):
        name, default = transformer_args[i]
        transformer_args[i] = (name, type(default)(parameter))
    args.transformer_args = dict(transformer_args)
    return args
43
 
44
 
 
 
 
 
45
  def main():
46
  args = parse_args()
47
+ datamodule = LatexImageDataModule(image_width=args.width, image_height=args.height,
48
+ batch_size=args.batch_size, random_magnitude=args.random_magnitude)
49
+ datamodule.prepare_data()
50
+ datamodule.setup()
 
 
 
 
 
 
 
51
  if args.log:
52
  logger = WandbLogger(f"img2tex", log_model=True)
53
+ callbacks = [LogImageTexCallback(logger, datamodule.tex_tokenizer),
54
+ LearningRateMonitor(logging_interval='step')]
 
 
55
  else:
56
  logger = None
57
  callbacks = []
 
63
  strategy="ddp",
64
  enable_progress_bar=True,
65
  default_root_dir=TRAINER_DIR,
66
+ callbacks=callbacks)
67
+
68
+ transformer = Transformer(num_encoder_layers=args.transformer_args['num_encoder_layers'],
69
+ num_decoder_layers=args.transformer_args['num_decoder_layers'],
70
+ d_model=args.transformer_args['d_model'],
71
+ nhead=args.transformer_args['nhead'],
72
+ dim_feedforward=args.transformer_args['dim_feedforward'],
73
+ dropout=args.transformer_args['dropout'],
74
+ image_width=datamodule.hparams['image_width'],
75
+ image_height=datamodule.hparams['image_height'],
76
  tgt_vocab_size=datamodule.tex_tokenizer.get_vocab_size(),
77
+ pad_idx=datamodule.tex_tokenizer.token_to_id("[PAD]"))
 
 
 
78
 
79
  trainer.fit(transformer, datamodule=datamodule)
80
+ trainer.test(datamodule=datamodule, ckpt_path='best')
81
+ trainer.save_checkpoint(os.path.join(TRAINER_DIR, "best_model.ckpt"))
82
 
83
 
84
  if __name__ == "__main__":
utils.py CHANGED
@@ -22,57 +22,84 @@ class LogImageTexCallback(Callback):
22
  image = self.tensor_to_PIL(image)
23
  tex_true = self.tex_tokenizer.decode(list(batch['tex_ids'][sample_id].to('cpu', torch.int)),
24
  skip_special_tokens=True)
25
- self.logger.log_image(key="samples", images=[image],
26
- caption=[f"True: {tex_true}\nPredicted: {tex_predicted}\nIds: {tex_ids}"])
27
 
28
- # if args.new_dataset:
29
- # datamodule.batch_size = 1
30
- # transformer_for_tuning = TransformerTuner(**transformer.hparams).cuda()
31
- # tuner = Trainer(accelerator="gpu" if args.gpus else "cpu",
32
- # gpus=args.gpus,
33
- # strategy=TRAINER_STRATEGY,
34
- # enable_progress_bar=True,
35
- # enable_checkpointing=False,
36
- # auto_scale_batch_size=True,
37
- # num_sanity_val_steps=0,
38
- # logger=False
39
- # )
40
- # tuner.tune(transformer_for_tuning, datamodule=datamodule)
41
- # torch.save(datamodule, DATASET_PATH)
42
-
43
- class _TransformerTuner(Transformer):
44
- """
45
- When using trainer.tune, batches from dataloader get passed directly to forward,
46
- so this subclass takes care of that
47
- """
48
 
49
- def forward(self, batch, batch_idx):
50
- src = batch['images']
51
- tgt = batch['tex_ids']
52
- tgt_input = tgt[:, :-1]
53
- tgt_output = tgt[:, 1:]
54
- src_mask = None
55
- tgt_mask = self.transformer.generate_square_subsequent_mask(tgt_input.shape[1]).to(self.device,
56
- torch.ByteTensor.dtype)
57
- memory_mask = None
58
- src_padding_mask = None
59
- tgt_padding_mask = batch['tex_attention_masks'][:, :-1]
60
- tgt_padding_mask = tgt_padding_mask.masked_fill(
61
- tgt_padding_mask == 0, float('-inf')
62
- ).masked_fill(
63
- tgt_padding_mask == 1, 0
64
- )
65
 
66
- src = self.src_tok_emb(src)
67
- tgt_input = self.tgt_tok_emb(tgt_input)
68
- outs = self.transformer(src, tgt_input, src_mask, tgt_mask, memory_mask, src_padding_mask, tgt_padding_mask)
69
- outs = self.generator(outs)
70
-
71
- loss = self.loss_fn(einops.rearrange(outs, 'b n prob -> b prob n'), tgt_output.long())
72
- return loss
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
 
74
- def validation_step(self, batch, batch_idx):
75
- return self(batch, batch_idx)
76
 
77
  @torch.inference_mode()
78
  def decode(transformer, tex_tokenizer, image):
@@ -87,4 +114,4 @@ def decode(transformer, tex_tokenizer, image):
87
  next_id = outs[0, :, -1].argmax().item()
88
  tex_ids.append(next_id)
89
  tex = tex_tokenizer.decode(tex_ids, skip_special_tokens=True)
90
- return tex, tex_ids
 
22
  image = self.tensor_to_PIL(image)
23
  tex_true = self.tex_tokenizer.decode(list(batch['tex_ids'][sample_id].to('cpu', torch.int)),
24
  skip_special_tokens=True)
25
+ self.logger.log_image(key="samples", images=[image], caption=[f"True: {tex_true}\nPredicted: {tex_predicted}"])
 
26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
+ # parser.add_argument(
29
+ # "-t", "-tune", help="whether to tune model for batch size before training, default False", default=False,
30
+ # action="store_true", dest="tune"
31
+ # )
 
 
 
 
 
 
 
 
 
 
 
 
32
 
33
+ # if args.new_dataset:
34
+ # datamodule.batch_size = 1
35
+ # transformer_for_tuning = TransformerTuner(**transformer.hparams).cuda()
36
+ # tuner = Trainer(accelerator="gpu" if args.gpus else "cpu",
37
+ # gpus=args.gpus,
38
+ # strategy=TRAINER_STRATEGY,
39
+ # enable_progress_bar=True,
40
+ # enable_checkpointing=False,
41
+ # auto_scale_batch_size=True,
42
+ # num_sanity_val_steps=0,
43
+ # logger=False
44
+ # )
45
+ # tuner.tune(transformer_for_tuning, datamodule=datamodule)
46
+ # torch.save(datamodule, DATASET_PATH)
47
+ # TUNER_DIR = "resources/pl_tuner_checkpoints"
48
+ # from pytorch_lightning import seed_everything
49
+ # parser.add_argument(
50
+ # "-d", "-deterministic", help="whether to seed all rngs for reproducibility, default False", default=False,
51
+ # action="store_true", dest="deterministic"
52
+ # )
53
+ # if args.deterministic:
54
+ # seed_everything(42, workers=True)
55
+ # def generate_normalize_transform(dataset: TexImageDataset):
56
+ # """Returns a normalize layer with mean and std computed after iterating over dataset"""
57
+ #
58
+ # mean = 0
59
+ # std = 0
60
+ # for item in tqdm.tqdm(dataset, "Computing dataset image stats"):
61
+ # image = item['image']
62
+ # mean += image.mean()
63
+ # std += image.std()
64
+ #
65
+ # mean /= len(dataset)
66
+ # std /= len(dataset)
67
+ # normalize = T.Normalize(mean, std)
68
+ # return normalize
69
+ # class _TransformerTuner(Transformer):
70
+ # """
71
+ # When using trainer.tune, batches from dataloader get passed directly to forward,
72
+ # so this subclass takes care of that
73
+ # """
74
+ #
75
+ # def forward(self, batch, batch_idx):
76
+ # src = batch['images']
77
+ # tgt = batch['tex_ids']
78
+ # tgt_input = tgt[:, :-1]
79
+ # tgt_output = tgt[:, 1:]
80
+ # src_mask = None
81
+ # tgt_mask = self.transformer.generate_square_subsequent_mask(tgt_input.shape[1]).to(self.device,
82
+ # torch.ByteTensor.dtype)
83
+ # memory_mask = None
84
+ # src_padding_mask = None
85
+ # tgt_padding_mask = batch['tex_attention_masks'][:, :-1]
86
+ # tgt_padding_mask = tgt_padding_mask.masked_fill(
87
+ # tgt_padding_mask == 0, float('-inf')
88
+ # ).masked_fill(
89
+ # tgt_padding_mask == 1, 0
90
+ # )
91
+ #
92
+ # src = self.src_tok_emb(src)
93
+ # tgt_input = self.tgt_tok_emb(tgt_input)
94
+ # outs = self.transformer(src, tgt_input, src_mask, tgt_mask, memory_mask, src_padding_mask, tgt_padding_mask)
95
+ # outs = self.generator(outs)
96
+ #
97
+ # loss = self.loss_fn(einops.rearrange(outs, 'b n prob -> b prob n'), tgt_output.long())
98
+ # return loss
99
+ #
100
+ # def validation_step(self, batch, batch_idx):
101
+ # return self(batch, batch_idx)
102
 
 
 
103
 
104
  @torch.inference_mode()
105
  def decode(transformer, tex_tokenizer, image):
 
114
  next_id = outs[0, :, -1].argmax().item()
115
  tex_ids.append(next_id)
116
  tex = tex_tokenizer.decode(tex_ids, skip_special_tokens=True)
117
+ return tex, tex_ids