dkoshman committed on
Commit 6e82d4a • 1 Parent(s): 41c9661

data_preprocessing, base train script

latex_generator.py → data_generator.py RENAMED
@@ -11,14 +11,15 @@ class DotDict(dict):
     __getattr__ = dict.get
     __setattr__ = dict.__setitem__
     __delattr__ = dict.__delitem__
-
+
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         if len(args) > 0 and isinstance(args[0], dict):
             for key, value in self.items():
                 if isinstance(value, dict):
                     self.__setitem__(key, DotDict(value))
-
+
+
 def _generate_equation(size_left, depth_left, latex, tokens):
     if size_left <= 0:
         return ""
@@ -27,17 +28,17 @@ def _generate_equation(size_left, depth_left, latex, tokens):
     pairs, scopes, special = latex.pairs, latex.scopes, latex.special
     weights = [3, depth_left > 0, depth_left > 0]
     group, = random.choices([tokens, pairs, scopes], weights=weights)
-
+
     if group is tokens:
         equation += ' '.join([
             random.choice(tokens),
             _generate_equation(size_left - 1, depth_left, latex, tokens)
         ])
         return equation
-
+
     post_scope_size = round(abs(random.gauss(0, size_left / 2)))
     size_left -= post_scope_size + 1
-
+
     if group is pairs:
         pair = random.choice(pairs)
         equation += ' '.join([
@@ -47,18 +48,18 @@ def _generate_equation(size_left, depth_left, latex, tokens):
             _generate_equation(post_scope_size, depth_left, latex, tokens)
         ])
         return equation
-
+
     elif group is scopes:
         scope_type, scope_group = random.choice(list(scopes.items()))
         scope_operator = random.choice(scope_group)
         equation += scope_operator
-
+
         if scope_type == 'single':
             equation += ' '.join([
                 special.left_bracket,
                 _generate_equation(size_left, depth_left - 1, latex, tokens)
             ])
-
+
         elif scope_type == 'double_no_delimiters':
             equation += ' '.join([
                 special.left_bracket,
@@ -66,7 +67,7 @@ def _generate_equation(size_left, depth_left, latex, tokens):
                 special.right_bracket + special.left_bracket,
                 _generate_equation(size_left // 2, depth_left - 1, latex, tokens)
             ])
-
+
         elif scope_type == 'double_with_delimiters':
             equation += ' '.join([
                 special.caret,
@@ -77,14 +78,15 @@ def _generate_equation(size_left, depth_left, latex, tokens):
                 special.left_bracket,
                 _generate_equation(size_left // 2, depth_left - 1, latex, tokens)
             ])
-
+
         equation += ' '.join([
             special.right_bracket,
             _generate_equation(post_scope_size, depth_left, latex, tokens)
         ])
         return equation
-
-def generate_equation(latex: dict, size, depth=3):
+
+
+def generate_equation(latex: DotDict, size, depth=3):
     """
     Generates a random latex equation
     -------
@@ -98,6 +100,7 @@ def generate_equation(latex: dict, size, depth=3):
     equation = _generate_equation(size, depth, latex, tokens)
     return equation
 
+
 def generate_image(directory: str, latex_path: str, filename: str, max_length=20):
     """
     Generates a random tex file and corresponding image
@@ -108,29 +111,29 @@ def generate_image(directory: str, latex_path: str, filename: str, max_length=20):
     :filename: -- name for the generated files
     :max_length: -- max size of equation
     """
-    #TODO ARGPARSE, path parse
+    # TODO ARGPARSE, path parse
     filepath = directory + filename
-
+
     with open(latex_path) as file:
         latex = json.load(file)
     latex = DotDict(latex)
-
+
     template = string.Template(latex.template)
     font, font_options = random.choice(latex.fonts)
     font_option = random.choice([''] + font_options)
     fontsize = random.choice(latex.fontsizes)
-    equation = generate_equation(latex, 20)
+    equation = generate_equation(latex, max_length)
     tex = template.substitute(font=font, font_option=font_option, fontsize=fontsize, equation=equation)
-
+
     files_before = set(os.listdir(directory))
     with open(f"{filepath}.tex", mode='w') as file:
         file.write(tex)
-
+
     pr1 = subprocess.run(
         f"pdflatex -output-directory={directory} {filepath}.tex".split(),
         stderr=subprocess.PIPE,
     )
-
+
     files_after = set(os.listdir(directory))
     if pr1.returncode != 0:
         files_to_delete = files_after - files_before
@@ -138,41 +141,43 @@ def generate_image(directory: str, latex_path: str, filename: str, max_length=20):
         subprocess.run(['rm'] + [directory + file for file in files_to_delete])
         print(pr1.stderr.decode(), tex)
         return
-
+
     pr2 = subprocess.run(
         f"gs -sDEVICE=png16m -dTextAlphaBits=4 -r200 -dSAFER -dBATCH -dNOPAUSE -o {filepath}.png {filepath}.pdf".split(),
         stderr=subprocess.PIPE,
     )
-
-    files_to_delete = files_after - files_before - set([filename + '.png', filename + '.tex'])
+
+    files_to_delete = files_after - files_before - {filename + '.png', filename + '.tex'}
     if files_to_delete:
         subprocess.run(['rm'] + [directory + file for file in files_to_delete])
-    assert(pr2.returncode == 0)
-
+    assert (pr2.returncode == 0)
+
+
 def generate_dataset(
-        filenames,
-        directory="/external2/dkkoshman/repos/ML2TransformerApp/data/",
-        latex_path="/external2/dkkoshman/repos/ML2TransformerApp/resources/latex.json",
-        overwrite: bool=False
-):
+        filenames,
+        directory: str = "/external2/dkkoshman/repos/ML2TransformerApp/data/",
+        latex_path: str = "/external2/dkkoshman/repos/ML2TransformerApp/resources/latex.json",
+        overwrite: bool = False
+) -> None:
+
+
     """
-    Generates a latex dataset
+    Generates a latex dataset in given directory
     -------
     params:
     :filenames: - iterable of filenames to create, without extension
     :directory: - where to create
     :latex_path: - full path to latex json
-    :ovewrite: - whether to overwrite exsisting files
+    :overwrite: - whether to overwrite existing files
     """
-
+
     filenames = set(filenames)
     if not overwrite:
         existing = set(file.split('.')[0] for file in os.listdir(directory) if file.endswith('.png'))
         filenames -= existing
-
+
     while filenames:
         with Pool() as pool:
             pool.starmap(generate_image, ((directory, latex_path, name) for name in filenames))
         existing = set(file.split('.')[0] for file in os.listdir(directory) if file.endswith('.png'))
        filenames -= existing
-
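
The renamed module takes its equation size from `max_length` and exposes `generate_dataset` as the entry point. A minimal usage sketch, not part of the commit (the relative paths and the zero-padded naming scheme are assumptions; `pdflatex` and Ghostscript's `gs` must be on PATH, and `directory` needs its trailing slash because paths are joined by string concatenation):

from data_generator import generate_dataset

if __name__ == '__main__':  # multiprocessing.Pool needs the main-module guard on spawn platforms
    # Render ten random equations as 0000.png .. 0009.png with matching .tex sources.
    generate_dataset(
        (f"{i:04d}" for i in range(10)),
        directory="data/",
        latex_path="resources/latex.json",
        overwrite=False,
    )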
 
data_preprocessing.py CHANGED
@@ -1,96 +1,139 @@
+import einops
 import os
-import re
 import tokenizers
 import torch
 import torchvision
 import torchvision.transforms as T
+from torch.utils.data import Dataset
 import tqdm
-import PIL
-from torch.utils.data import Dataset, DataLoader
-
+import re
 
-directory = "/external2/dkkoshman/repos/ML2TransformerApp/data/"
 
 class TexImageDataset(Dataset):
     """Image to tex dataset."""
 
-    def __init__(self, root_dir, image_preprocessing=None, tex_preprocessing=None):
+    def __init__(self, root_dir, image_transform=None, tex_transform=None):
         """
         Args:
             root_dir (string): Directory with all the images and tex files.
-            transform (callable, optional): Optional transform to be applied
-                on a sample.
-
-            image_preprocessing: callable image preprocessing
-
-            tex_preprocessing: callable tex preprocessing
+            image_transform: callable image preprocessing
+            tex_transform: callable tex preprocessing
         """
-
+
         torch.multiprocessing.set_sharing_strategy('file_system')
         self.root_dir = root_dir
-        filenames = sorted(
-            set(os.path.splitext(filename)[0] for filename in os.listdir(root_dir) if filename.endswith('png'))
-        )
-        self.data = []
-
-        for filename in tqdm.tqdm(filenames):
-            tex_path = self.root_dir + filename + '.tex'
-            image_path = self.root_dir + filename + '.png'
-
-            with open(tex_path) as file:
-                tex = file.read()
-                if tex_preprocessing:
-                    tex = tex_preprocessing(tex)
-
-            image = torchvision.io.read_image(image_path)
-            if image_preprocessing:
-                image = image_preprocessing(image)
-
-            self.data.append((image, tex))
+        self.filenames = sorted(set(
+            os.path.splitext(filename)[0] for filename in os.listdir(root_dir) if filename.endswith('png')
+        ))
+        self.image_transform = image_transform
+        self.tex_transform = tex_transform
 
     def __len__(self):
-        return len(self.data)
+        return len(self.filenames)
 
     def __getitem__(self, idx):
-        image, tex = self.data[idx]
+        filename = self.filenames[idx]
+        tex_path = self.root_dir + filename + '.tex'
+        image_path = self.root_dir + filename + '.png'
+
+        with open(tex_path) as file:
+            tex = file.read()
+            if self.tex_transform:
+                tex = self.tex_transform(tex)
+
+        image = torchvision.io.read_image(image_path)
+        if self.image_transform:
+            image = self.image_transform(image)
+
         return {"image": image, "tex": tex}
-
-
-class StandardizeImage(object):
-    """Pad and crop image to a given size, invert and normalize"""
+
+    def subjoin_normalize_layer(self):
+        """Appends a normalize layer with mean and std computed after iterating over dataset"""
+        mean = 0
+        std = 0
+        for item in tqdm.tqdm(self):
+            image = item['image']
+            mean += image.mean()
+            std += image.std()
+
+        mean /= len(self)
+        std /= len(self)
+        normalize = T.Normalize(mean, std)
+
+        if self.image_transform:
+            self.image_transform = T.Compose((self.image_transform, normalize))
+        else:
+            self.image_transform = normalize
+
+    @staticmethod
+    def collate_batch(batch):
+        images = [i['image'] for i in batch]
+        images = einops.rearrange(images, 'b c h w -> b c h w')  # stacks the list into one batch tensor
+
+        texs = [item['tex'] for item in batch]
+        texs = tokenizer.encode_batch(texs)  # NB: resolves a module-level `tokenizer`, e.g. from generate_tex_tokenizer
+        tex_ids = torch.Tensor([encoding.ids for encoding in texs])
+        attention_masks = torch.Tensor([encoding.attention_mask for encoding in texs])
+
+        return {'images': images, 'tex_ids': tex_ids, 'tex_attention_masks': attention_masks}
+
+
+class StandardizeImageTransform(object):
+    """Pad and crop image to a given size, grayscale and invert"""
 
     def __init__(self, width=1024, height=128):
-        self.transform = T.Compose((
+        self.standardize = T.Compose((
             T.Resize(height),
             T.Grayscale(),
             T.functional.invert,
-            T.CenterCrop((height, width))
+            T.CenterCrop((height, width)),
+            T.ConvertImageDtype(torch.float32)
         ))
 
     def __call__(self, image):
-        image = self.transform(image)
+        image = self.standardize(image)
         return image
-
-
-class RandomTransformImage(object):
+
+
+class RandomizeImageTransform(object):
     """Standardize image and randomly augment"""
 
-    def __init__(self, standardize, random_magnitude=5):
-        self.brighten = T.ColorJitter(brightness=(1/random_magnitude, 1 + 1/random_magnitude))
-        self.standardize = standardize
-        self.rand_aug = T.RandAugment(magnitude=random_magnitude)
+    def __init__(self, width=1024, height=128, random_magnitude=5):
+        self.transform = T.Compose((
+            T.ColorJitter(brightness=random_magnitude / 10),
+            T.Resize(height),
+            T.Grayscale(),
+            T.functional.invert,
+            T.CenterCrop((height, width)),
+            torch.Tensor.contiguous,
+            T.RandAugment(magnitude=random_magnitude),
+            T.ConvertImageDtype(torch.float32)
+        ))
 
     def __call__(self, image):
-        image = self.brighten(image)
-        image = self.standardize(image)
-        image = image.contiguous()
-        image = self.rand_aug(image)
+        image = self.transform(image)
         return image
 
 
-def generate_tex_tokenizer(texs):
-    """Returns a tokeniser trained on tex strings from dataset"""
-
+class ExtractEquationFromTexTransform(object):
+    """Extracts ...\[ equation \]... from tex file"""
+
+    def __init__(self):
+        self.equation_pattern = re.compile(r'\\\[(?P<equation>.*)\\\]', flags=re.DOTALL)
+        self.spaces = re.compile(r' +')
+
+    def __call__(self, tex):
+        equation = self.equation_pattern.search(tex)
+        equation = equation.group('equation')
+        equation = equation.strip()
+        equation = self.spaces.sub(' ', equation)
+        return equation
+
+
+def generate_tex_tokenizer(texs):
+    """Returns a tokenizer trained on given tex strings"""
+
+    os.environ['TOKENIZERS_PARALLELISM'] = 'false'
     tokenizer = tokenizers.Tokenizer(tokenizers.models.BPE(unk_token="[UNK]"))
     tokenizer_trainer = tokenizers.trainers.BpeTrainer(
         vocab_size=300,
@@ -103,5 +146,5 @@ def generate_tex_tokenizer(texs):
         special_tokens=[("[SEP]", tokenizer.token_to_id("[SEP]"))]
     )
     tokenizer.enable_padding(pad_id=tokenizer.token_to_id("[PAD]"), pad_token="[PAD]")
-
+
    return tokenizer
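
The dataset now loads one image/tex pair lazily per `__getitem__` instead of caching everything in `__init__`. A minimal sketch of wiring the pieces together, not part of the commit (`'data/'` is an assumed output directory of `data_generator.generate_dataset`, and it assumes `generate_tex_tokenizer` consumes an iterable of strings, as its docstring suggests):

from data_preprocessing import (TexImageDataset, StandardizeImageTransform,
                                ExtractEquationFromTexTransform, generate_tex_tokenizer)

dataset = TexImageDataset(
    'data/',  # paths are joined by concatenation, so keep the trailing slash
    image_transform=StandardizeImageTransform(),
    tex_transform=ExtractEquationFromTexTransform(),
)
dataset.subjoin_normalize_layer()  # one full pass to estimate mean/std, then appends T.Normalize
tokenizer = generate_tex_tokenizer(item['tex'] for item in dataset)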
model.py ADDED
File without changes
resources/latex.json CHANGED
@@ -1 +1,257 @@
-{"special": {"dollar": "$", "underscore": "_", "caret": "^", "left_bracket": "{", "right_bracket": "}", "ampersand": "&"}, "chars": "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!\"'()*+,-./:;<=>?@[]`|~", "greek": ["\\alpha", "\\beta", "\\gamma", "\\delta", "\\epsilon", "\\varepsilon", "\\zeta", "\\eta", "\\theta", "\\vartheta", "\\iota", "\\kappa", "\\lambda", "\\mu", "\\nu", "\\xi", "\\pi", "\\varpi", "\\rho", "\\varrho", "\\sigma", "\\varsigma", "\\tau", "\\upsilon", "\\phi", "\\varphi", "\\chi", "\\psi", "\\omega", "\\Gamma", "\\Delta", "\\Theta", "\\Lambda", "\\Xi", "\\Pi", "\\Sigma", "\\Upsilon", "\\Phi", "\\Psi", "\\Omega"], "functions": ["\\forall", "\\exists", "\\arccos", "\\arcsin", "\\arctan", "\\cos", "\\cosh", "\\cot", "\\coth", "\\csc", "\\deg", "\\det", "\\dim", "\\exp", "\\gcd", "\\hom", "\\inf", "\\ker", "\\lg", "\\lim", "\\liminf", "\\limsup", "\\ln", "\\log", "\\max", "\\min", "\\sec", "\\sin", "\\sinh", "\\sup", "\\tan", "\\tanh"], "operators": ["--", "---", "\\pm", "\\mp", "\\times", "\\div", "\\ast", "\\star", "\\bullet", "\\circ", "\\cdot", "\\leq", "\\ll", "\\subset", "\\geq", "\\gg", "\\equiv", "\\sim", "\\simeq", "\\approx", "\\neq", "\\propto", "\\not", "\\mid", "\\leftarrow", "\\Leftarrow", "\\longleftarrow", "\\Longleftarrow", "\\rightarrow", "\\Rightarrow", "\\longrightarrow", "\\Longrightarrow", "\\leftrightarrow", "\\Leftrightarrow", "\\longleftrightarrow", "\\uparrow", "\\downarrow", "\\Uparrow", "\\cdots", "\\ddots", "\\ldots", "\\vdots"], "pairs": [["\\left(", "\\right)"], ["\\left[", "\\right]"], ["\\left\\{", "\\right\\}"], ["\\langle", "\\rangle"]], "spaces": ["\\;", "\\:", "\\,", "\\!"], "fonts": [["sfmath", []], ["lmodern", []], ["eulervm", []], ["euler", []], ["beton", []], ["drm", []], ["boisik", []], ["gfsartemisia-euler", []], ["gfsartemisia", []], ["arev", []], ["anttor", ["math", "light,math", "condensed,math", "light,condensed,math"]]], "fontsizes": [6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20], "template": "\\documentclass[preview]{standalone}\n\\usepackage[$font_option]{$font}\n\\usepackage[T1]{fontenc}\n\\begin{document}\n{\\fontsize{$fontsize pt}{12 pt}\\selectfont \n\\[\n$equation\n\\]\n}\n\\end{document}", "scopes": {"single": ["^", "_", "\\sqrt", "\\underbrace", "\\underline", "\\boldmath", "\\hat", "\\widehat", "\\check", "\\tilde", "\\widetilde", "\\acute", "\\grave", "\\dot", "\\ddot", "\\breve", "\\bar", "\\vec"], "double_with_delimiters": ["\"\\sum", "\\prod", "\\int", "\\bigcup", "\\bigcap"], "double_no_delimiters": ["\\frac", "\\stackrel"]}}
+{
+    "special": {
+        "dollar": "$",
+        "underscore": "_",
+        "caret": "^",
+        "left_bracket": "{",
+        "right_bracket": "}",
+        "ampersand": "&"
+    },
+    "chars": "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!\"'()*+,-./:;<=>?@[]`|~",
+    "greek": [
+        "\\alpha",
+        "\\beta",
+        "\\gamma",
+        "\\delta",
+        "\\epsilon",
+        "\\varepsilon",
+        "\\zeta",
+        "\\eta",
+        "\\theta",
+        "\\vartheta",
+        "\\iota",
+        "\\kappa",
+        "\\lambda",
+        "\\mu",
+        "\\nu",
+        "\\xi",
+        "\\pi",
+        "\\varpi",
+        "\\rho",
+        "\\varrho",
+        "\\sigma",
+        "\\varsigma",
+        "\\tau",
+        "\\upsilon",
+        "\\phi",
+        "\\varphi",
+        "\\chi",
+        "\\psi",
+        "\\omega",
+        "\\Gamma",
+        "\\Delta",
+        "\\Theta",
+        "\\Lambda",
+        "\\Xi",
+        "\\Pi",
+        "\\Sigma",
+        "\\Upsilon",
+        "\\Phi",
+        "\\Psi",
+        "\\Omega"
+    ],
+    "functions": [
+        "\\forall",
+        "\\exists",
+        "\\arccos",
+        "\\arcsin",
+        "\\arctan",
+        "\\cos",
+        "\\cosh",
+        "\\cot",
+        "\\coth",
+        "\\csc",
+        "\\deg",
+        "\\det",
+        "\\dim",
+        "\\exp",
+        "\\gcd",
+        "\\hom",
+        "\\inf",
+        "\\ker",
+        "\\lg",
+        "\\lim",
+        "\\liminf",
+        "\\limsup",
+        "\\ln",
+        "\\log",
+        "\\max",
+        "\\min",
+        "\\sec",
+        "\\sin",
+        "\\sinh",
+        "\\sup",
+        "\\tan",
+        "\\tanh"
+    ],
+    "operators": [
+        "--",
+        "---",
+        "\\pm",
+        "\\mp",
+        "\\times",
+        "\\div",
+        "\\ast",
+        "\\star",
+        "\\bullet",
+        "\\circ",
+        "\\cdot",
+        "\\leq",
+        "\\ll",
+        "\\subset",
+        "\\geq",
+        "\\gg",
+        "\\equiv",
+        "\\sim",
+        "\\simeq",
+        "\\approx",
+        "\\neq",
+        "\\propto",
+        "\\not",
+        "\\mid",
+        "\\leftarrow",
+        "\\Leftarrow",
+        "\\longleftarrow",
+        "\\Longleftarrow",
+        "\\rightarrow",
+        "\\Rightarrow",
+        "\\longrightarrow",
+        "\\Longrightarrow",
+        "\\leftrightarrow",
+        "\\Leftrightarrow",
+        "\\longleftrightarrow",
+        "\\uparrow",
+        "\\downarrow",
+        "\\Uparrow",
+        "\\cdots",
+        "\\ddots",
+        "\\ldots",
+        "\\vdots"
+    ],
+    "pairs": [
+        [
+            "\\left(",
+            "\\right)"
+        ],
+        [
+            "\\left[",
+            "\\right]"
+        ],
+        [
+            "\\left\\{",
+            "\\right\\}"
+        ],
+        [
+            "\\langle",
+            "\\rangle"
+        ]
+    ],
+    "spaces": [
+        "\\;",
+        "\\:",
+        "\\,",
+        "\\!"
+    ],
+    "fonts": [
+        [
+            "sfmath",
+            []
+        ],
+        [
+            "lmodern",
+            []
+        ],
+        [
+            "eulervm",
+            []
+        ],
+        [
+            "euler",
+            []
+        ],
+        [
+            "beton",
+            []
+        ],
+        [
+            "drm",
+            []
+        ],
+        [
+            "boisik",
+            []
+        ],
+        [
+            "gfsartemisia-euler",
+            []
+        ],
+        [
+            "gfsartemisia",
+            []
+        ],
+        [
+            "arev",
+            []
+        ],
+        [
+            "anttor",
+            [
+                "math",
+                "light,math",
+                "condensed,math",
+                "light,condensed,math"
+            ]
+        ]
+    ],
+    "fontsizes": [
+        6,
+        7,
+        8,
+        9,
+        10,
+        11,
+        12,
+        13,
+        14,
+        15,
+        16,
+        17,
+        18,
+        19,
+        20
+    ],
+    "template": "\\documentclass[preview]{standalone}\n\\usepackage[$font_option]{$font}\n\\usepackage[T1]{fontenc}\n\\begin{document}\n{\\fontsize{$fontsize pt}{12 pt}\\selectfont \n\\[\n$equation\n\\]\n}\n\\end{document}",
+    "scopes": {
+        "single": [
+            "^",
+            "_",
+            "\\sqrt",
+            "\\underbrace",
+            "\\underline",
+            "\\boldmath",
+            "\\hat",
+            "\\widehat",
+            "\\check",
+            "\\tilde",
+            "\\widetilde",
+            "\\acute",
+            "\\grave",
+            "\\dot",
+            "\\ddot",
+            "\\breve",
+            "\\bar",
+            "\\vec"
+        ],
+        "double_with_delimiters": [
+            "\"\\sum",
+            "\\prod",
+            "\\int",
+            "\\bigcup",
+            "\\bigcap"
+        ],
+        "double_no_delimiters": [
+            "\\frac",
+            "\\stackrel"
+        ]
+    }
+}
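
The `template` field is consumed by `generate_image` through `string.Template`, which fills the four `$`-placeholders. A small sketch of that substitution, not part of the commit (the chosen values are arbitrary examples; the path is relative to the repo root):

import json
import string

with open('resources/latex.json') as file:
    latex = json.load(file)

# Mirrors generate_image: an empty font_option is one of the valid choices.
tex = string.Template(latex['template']).substitute(
    font='lmodern', font_option='', fontsize=12, equation='x ^ { 2 }')
print(tex)  # a standalone-class document with the equation wrapped in \[ ... \]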
train.py ADDED
@@ -0,0 +1,19 @@
+from data_preprocessing import TexImageDataset, RandomizeImageTransform, ExtractEquationFromTexTransform
+
+import torch
+from torch.utils.data import DataLoader
+
+if __name__ == '__main__':
+    image_transform = RandomizeImageTransform()
+    tex_transform = ExtractEquationFromTexTransform()
+    dataset = TexImageDataset('data/', image_transform=image_transform, tex_transform=tex_transform)
+
+    train_dataset, test_dataset = torch.utils.data.random_split(
+        dataset,
+        [len(dataset) * 9 // 10, len(dataset) - len(dataset) * 9 // 10]  # sizes must sum to len(dataset)
+    )
+
+    train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=16,
+                                  collate_fn=TexImageDataset.collate_batch)  # random_split's Subset has no collate method of its own
+    batch = next(iter(train_dataloader))
+    print(batch['tex_ids'])
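
Note that `collate_batch` resolves a module-level name `tokenizer` that neither module defines yet. One hedged way to provide it before iterating (an assumption about the intended wiring, and only reliable with `num_workers=0`, since worker processes re-import the module):

import data_preprocessing

# Train the BPE tokenizer on the extracted equations, then expose it where
# collate_batch looks it up.
data_preprocessing.tokenizer = data_preprocessing.generate_tex_tokenizer(
    item['tex'] for item in dataset
)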