dkoshman committed
Commit ae308b4
1 parent: e949d7b

script now generates a small dataset

Files changed (3):
  1. data_generator.py +3 -5
  2. data_preprocessing.py +20 -20
  3. train.py +14 -7
data_generator.py CHANGED
@@ -153,14 +153,12 @@ def generate_image(directory: str, latex_path: str, filename: str, max_length=20
     assert (pr2.returncode == 0)


-def generate_dataset(
+def generate_data(
         filenames: iter(str),
-        directory: str = "/external2/dkkoshman/repos/ML2TransformerApp/data/",
-        latex_path: str = "/external2/dkkoshman/repos/ML2TransformerApp/resources/latex.json",
+        directory: str,
+        latex_path: str,
         overwrite: bool = False
 ) -> None:
-
-
     """
     Generates a latex dataset in given directory
     -------
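For reference, the renamed generate_data drops the hard-coded absolute paths, so callers must now pass both paths explicitly. A minimal usage sketch, mirroring the call train.py makes in this commit (the relative paths are the ones train.py defines; the small filename range here is illustrative only):

# Sketch of calling the renamed generate_data; the paths mirror train.py in
# this commit, the range of 10 filenames is illustrative only.
from data_generator import generate_data

generate_data(
    filenames=map(str, range(10)),       # files named "0" through "9"
    directory='data',                    # output directory, as in train.py
    latex_path='resources/latex.json',   # latex config path, as in train.py
    overwrite=False,                     # keep already-generated files (default)
)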
data_preprocessing.py CHANGED
@@ -67,26 +67,6 @@ class TexImageDataset(Dataset):
         else:
             self.image_transform = normalize

-    def subjoin_tex_tokenize_transform(self, texs, vocab_size=300):
-        """Returns a tokenizer trained on given tex strings"""
-
-        # os.environ['TOKENIZERS_PARALLELISM'] = 'false'
-        tokenizer = tokenizers.Tokenizer(tokenizers.models.BPE(unk_token="[UNK]"))
-        tokenizer_trainer = tokenizers.trainers.BpeTrainer(
-            vocab_size=vocab_size,
-            special_tokens=["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"]
-        )
-        tokenizer.pre_tokenizer = tokenizers.pre_tokenizers.Whitespace()
-        tokenizer.train_from_iterator(texs, trainer=tokenizer_trainer)
-        tokenizer.post_processor = tokenizers.processors.TemplateProcessing(
-            single="$A [SEP]",
-            special_tokens=[("[SEP]", tokenizer.token_to_id("[SEP]"))]
-        )
-        tokenizer.enable_padding(pad_id=tokenizer.token_to_id("[PAD]"), pad_token="[PAD]")
-
-        self.tokenizer = tokenizer
-        return tokenizer
-

 class BatchCollator(object):
     """Image, tex batch collator"""
@@ -156,3 +136,23 @@ class ExtractEquationFromTexTransform(object):
         equation = equation.strip()
         equation = self.spaces.sub(' ', equation)
         return equation
+
+
+def generate_tex_tokenizer(texs, vocab_size=300):
+    """Returns a tokenizer trained on given tex strings"""
+
+    # os.environ['TOKENIZERS_PARALLELISM'] = 'false'
+    tokenizer = tokenizers.Tokenizer(tokenizers.models.BPE(unk_token="[UNK]"))
+    tokenizer_trainer = tokenizers.trainers.BpeTrainer(
+        vocab_size=vocab_size,
+        special_tokens=["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"]
+    )
+    tokenizer.pre_tokenizer = tokenizers.pre_tokenizers.Whitespace()
+    tokenizer.train_from_iterator(texs, trainer=tokenizer_trainer)
+    tokenizer.post_processor = tokenizers.processors.TemplateProcessing(
+        single="$A [SEP]",
+        special_tokens=[("[SEP]", tokenizer.token_to_id("[SEP]"))]
+    )
+    tokenizer.enable_padding(pad_id=tokenizer.token_to_id("[PAD]"), pad_token="[PAD]")
+
+    return tokenizer
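Since the tokenizer builder is now a module-level function rather than a dataset method, it can be exercised on its own. A minimal sketch, assuming a couple of illustrative tex strings (the tokenizers-library calls are the same ones the function uses internally):

# Sketch: train and use the standalone tokenizer on sample tex strings.
from data_preprocessing import generate_tex_tokenizer

texs = [r"\frac{a}{b} + c", r"\sum_{i=0}^{n} x_i"]   # illustrative formulas
tokenizer = generate_tex_tokenizer(texs, vocab_size=300)

encoding = tokenizer.encode(r"\frac{a}{b}")
print(encoding.tokens)   # BPE tokens, ending with the [SEP] appended by the template
print(encoding.ids)      # corresponding vocabulary ids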
train.py CHANGED
@@ -1,23 +1,30 @@
+from data_generator import generate_data
 from data_preprocessing import TexImageDataset, RandomizeImageTransform, ExtractEquationFromTexTransform, \
-    generate_tex_tokenizer, BatchCollator
+    BatchCollator, generate_tex_tokenizer

 import torch
 from torch.utils.data import DataLoader
-import tqdm
+
+DATA_DIR = 'data'
+LATEX_PATH = 'resources/latex.json'

 if __name__ == '__main__':
+    generate_data(
+        filenames=map(str, range(1000)),
+        directory=DATA_DIR,
+        latex_path=LATEX_PATH,
+    )
+
     image_transform = RandomizeImageTransform()
     tex_transform = ExtractEquationFromTexTransform()
-    dataset = TexImageDataset('data', image_transform=image_transform, tex_transform=tex_transform)
+    dataset = TexImageDataset(DATA_DIR, image_transform=image_transform, tex_transform=tex_transform)
     dataset.subjoin_image_normalize_transform()
     train_dataset, test_dataset = torch.utils.data.random_split(
         dataset,
         [len(dataset) * 9 // 10, len(dataset) // 10]
     )
-    train_dataloader = DataLoader(train_dataset, batch_size=16, num_workers=16)
-    texs = list(tqdm.tqdm(batch['tex'] for batch in train_dataloader))
-    tokenizer = generate_tex_tokenizer(texs)
-    collate_fn = BatchCollator(tokenizer)
+    tex_tokenizer = generate_tex_tokenizer(dataset.texs)
+    collate_fn = BatchCollator(tex_tokenizer)

     train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=16,
                                   collate_fn=collate_fn)
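One detail worth noting about the split above: with the 1000 files this commit generates, the integer arithmetic yields a clean 900/100 split, but random_split requires the lengths to sum exactly to the dataset size, so other dataset sizes would need care. A quick check of the arithmetic:

# Sketch: the 90/10 split arithmetic from train.py, checked for n = 1000.
n = 1000
train_len, test_len = n * 9 // 10, n // 10   # 900, 100
assert train_len + test_len == n             # holds for n divisible by 10
# e.g. n = 1001 gives 900 + 100 = 1000 != 1001, which random_split rejects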