Spaces:
Runtime error
Runtime error
dkoshman
commited on
Commit
•
ae308b4
1
Parent(s):
e949d7b
script now generates a small dataset
Browse files- data_generator.py +3 -5
- data_preprocessing.py +20 -20
- train.py +14 -7
data_generator.py
CHANGED
@@ -153,14 +153,12 @@ def generate_image(directory: str, latex_path: str, filename: str, max_length=20
|
|
153 |
assert (pr2.returncode == 0)
|
154 |
|
155 |
|
156 |
-
def
|
157 |
filenames: iter(str),
|
158 |
-
directory: str
|
159 |
-
latex_path: str
|
160 |
overwrite: bool = False
|
161 |
) -> None:
|
162 |
-
|
163 |
-
|
164 |
"""
|
165 |
Generates a latex dataset in given directory
|
166 |
-------
|
|
|
153 |
assert (pr2.returncode == 0)
|
154 |
|
155 |
|
156 |
+
def generate_data(
|
157 |
filenames: iter(str),
|
158 |
+
directory: str,
|
159 |
+
latex_path: str,
|
160 |
overwrite: bool = False
|
161 |
) -> None:
|
|
|
|
|
162 |
"""
|
163 |
Generates a latex dataset in given directory
|
164 |
-------
|
data_preprocessing.py
CHANGED
@@ -67,26 +67,6 @@ class TexImageDataset(Dataset):
|
|
67 |
else:
|
68 |
self.image_transform = normalize
|
69 |
|
70 |
-
def subjoin_tex_tokenize_transform(self, texs, vocab_size=300):
|
71 |
-
"""Returns a tokenizer trained on given tex strings"""
|
72 |
-
|
73 |
-
# os.environ['TOKENIZERS_PARALLELISM'] = 'false'
|
74 |
-
tokenizer = tokenizers.Tokenizer(tokenizers.models.BPE(unk_token="[UNK]"))
|
75 |
-
tokenizer_trainer = tokenizers.trainers.BpeTrainer(
|
76 |
-
vocab_size=vocab_size,
|
77 |
-
special_tokens=["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"]
|
78 |
-
)
|
79 |
-
tokenizer.pre_tokenizer = tokenizers.pre_tokenizers.Whitespace()
|
80 |
-
tokenizer.train_from_iterator(texs, trainer=tokenizer_trainer)
|
81 |
-
tokenizer.post_processor = tokenizers.processors.TemplateProcessing(
|
82 |
-
single="$A [SEP]",
|
83 |
-
special_tokens=[("[SEP]", tokenizer.token_to_id("[SEP]"))]
|
84 |
-
)
|
85 |
-
tokenizer.enable_padding(pad_id=tokenizer.token_to_id("[PAD]"), pad_token="[PAD]")
|
86 |
-
|
87 |
-
self.tokenizer = tokenizer
|
88 |
-
return tokenizer
|
89 |
-
|
90 |
|
91 |
class BatchCollator(object):
|
92 |
"""Image, tex batch collator"""
|
@@ -156,3 +136,23 @@ class ExtractEquationFromTexTransform(object):
|
|
156 |
equation = equation.strip()
|
157 |
equation = self.spaces.sub(' ', equation)
|
158 |
return equation
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
67 |
else:
|
68 |
self.image_transform = normalize
|
69 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
70 |
|
71 |
class BatchCollator(object):
|
72 |
"""Image, tex batch collator"""
|
|
|
136 |
equation = equation.strip()
|
137 |
equation = self.spaces.sub(' ', equation)
|
138 |
return equation
|
139 |
+
|
140 |
+
|
141 |
+
def generate_tex_tokenizer(texs, vocab_size=300):
|
142 |
+
"""Returns a tokenizer trained on given tex strings"""
|
143 |
+
|
144 |
+
# os.environ['TOKENIZERS_PARALLELISM'] = 'false'
|
145 |
+
tokenizer = tokenizers.Tokenizer(tokenizers.models.BPE(unk_token="[UNK]"))
|
146 |
+
tokenizer_trainer = tokenizers.trainers.BpeTrainer(
|
147 |
+
vocab_size=vocab_size,
|
148 |
+
special_tokens=["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"]
|
149 |
+
)
|
150 |
+
tokenizer.pre_tokenizer = tokenizers.pre_tokenizers.Whitespace()
|
151 |
+
tokenizer.train_from_iterator(texs, trainer=tokenizer_trainer)
|
152 |
+
tokenizer.post_processor = tokenizers.processors.TemplateProcessing(
|
153 |
+
single="$A [SEP]",
|
154 |
+
special_tokens=[("[SEP]", tokenizer.token_to_id("[SEP]"))]
|
155 |
+
)
|
156 |
+
tokenizer.enable_padding(pad_id=tokenizer.token_to_id("[PAD]"), pad_token="[PAD]")
|
157 |
+
|
158 |
+
return tokenizer
|
train.py
CHANGED
@@ -1,23 +1,30 @@
|
|
|
|
1 |
from data_preprocessing import TexImageDataset, RandomizeImageTransform, ExtractEquationFromTexTransform, \
|
2 |
-
|
3 |
|
4 |
import torch
|
5 |
from torch.utils.data import DataLoader
|
6 |
-
|
|
|
|
|
7 |
|
8 |
if __name__ == '__main__':
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
image_transform = RandomizeImageTransform()
|
10 |
tex_transform = ExtractEquationFromTexTransform()
|
11 |
-
dataset = TexImageDataset(
|
12 |
dataset.subjoin_image_normalize_transform()
|
13 |
train_dataset, test_dataset = torch.utils.data.random_split(
|
14 |
dataset,
|
15 |
[len(dataset) * 9 // 10, len(dataset) // 10]
|
16 |
)
|
17 |
-
|
18 |
-
|
19 |
-
tokenizer = generate_tex_tokenizer(texs)
|
20 |
-
collate_fn = BatchCollator(tokenizer)
|
21 |
|
22 |
train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=16,
|
23 |
collate_fn=collate_fn)
|
|
|
1 |
+
from data_generator import generate_data
|
2 |
from data_preprocessing import TexImageDataset, RandomizeImageTransform, ExtractEquationFromTexTransform, \
|
3 |
+
BatchCollator, generate_tex_tokenizer
|
4 |
|
5 |
import torch
|
6 |
from torch.utils.data import DataLoader
|
7 |
+
|
8 |
+
DATA_DIR = 'data'
|
9 |
+
LATEX_PATH = 'resources/latex.json'
|
10 |
|
11 |
if __name__ == '__main__':
|
12 |
+
generate_data(
|
13 |
+
filenames=map(str, range(1000)),
|
14 |
+
directory=DATA_DIR,
|
15 |
+
latex_path=LATEX_PATH,
|
16 |
+
)
|
17 |
+
|
18 |
image_transform = RandomizeImageTransform()
|
19 |
tex_transform = ExtractEquationFromTexTransform()
|
20 |
+
dataset = TexImageDataset(DATA_DIR, image_transform=image_transform, tex_transform=tex_transform)
|
21 |
dataset.subjoin_image_normalize_transform()
|
22 |
train_dataset, test_dataset = torch.utils.data.random_split(
|
23 |
dataset,
|
24 |
[len(dataset) * 9 // 10, len(dataset) // 10]
|
25 |
)
|
26 |
+
tex_tokenizer = generate_tex_tokenizer(dataset.texs)
|
27 |
+
collate_fn = BatchCollator(tex_tokenizer)
|
|
|
|
|
28 |
|
29 |
train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=16,
|
30 |
collate_fn=collate_fn)
|