Spaces:
Runtime error
Runtime error
dkoshman
commited on
Commit
•
41a34cd
1
Parent(s):
e932abd
dedicated generate.py script
Browse files- data_generator.py +100 -135
- data_preprocessing.py +43 -75
- generate.py +23 -0
- model.py +6 -10
- train.py +46 -64
- utils.py +75 -48
data_generator.py
CHANGED
@@ -7,109 +7,76 @@ import subprocess
|
|
7 |
import random
|
8 |
import tqdm
|
9 |
|
10 |
-
DATA_DIR =
|
11 |
-
LATEX_PATH =
|
12 |
-
|
13 |
-
|
14 |
-
class DotDict(dict):
|
15 |
-
"""dot.notation access to dictionary attributes"""
|
16 |
-
__getattr__ = dict.get
|
17 |
-
__setattr__ = dict.__setitem__
|
18 |
-
__delattr__ = dict.__delitem__
|
19 |
-
|
20 |
-
def __init__(self, *args, **kwargs):
|
21 |
-
super().__init__(*args, **kwargs)
|
22 |
-
if len(args) > 0 and isinstance(args[0], dict):
|
23 |
-
for key, value in self.items():
|
24 |
-
if isinstance(value, dict):
|
25 |
-
self.__setitem__(key, DotDict(value))
|
26 |
-
|
27 |
-
|
28 |
-
def _generate_equation(size_left, depth_left, latex, tokens):
|
29 |
-
if size_left <= 0:
|
30 |
-
return ""
|
31 |
-
|
32 |
-
equation = ""
|
33 |
-
pairs, scopes, special = latex.pairs, latex.scopes, latex.special
|
34 |
-
weights = [3, depth_left > 0, depth_left > 0]
|
35 |
-
group, = random.choices([tokens, pairs, scopes], weights=weights)
|
36 |
-
|
37 |
-
if group is tokens:
|
38 |
-
equation += ' '.join([
|
39 |
-
random.choice(tokens),
|
40 |
-
_generate_equation(size_left - 1, depth_left, latex, tokens)
|
41 |
-
])
|
42 |
-
return equation
|
43 |
-
|
44 |
-
post_scope_size = round(abs(random.gauss(0, size_left / 2)))
|
45 |
-
size_left -= post_scope_size + 1
|
46 |
-
|
47 |
-
if group is pairs:
|
48 |
-
pair = random.choice(pairs)
|
49 |
-
equation += ' '.join([
|
50 |
-
pair[0],
|
51 |
-
_generate_equation(size_left, depth_left - 1, latex, tokens),
|
52 |
-
pair[1],
|
53 |
-
_generate_equation(post_scope_size, depth_left, latex, tokens)
|
54 |
-
])
|
55 |
-
return equation
|
56 |
|
57 |
-
elif group is scopes:
|
58 |
-
scope_type, scope_group = random.choice(list(scopes.items()))
|
59 |
-
scope_operator = random.choice(scope_group)
|
60 |
-
equation += scope_operator
|
61 |
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
66 |
])
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
|
|
|
|
|
|
|
|
|
|
74 |
])
|
|
|
75 |
|
76 |
-
elif
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
_generate_equation(size_left // 2, depth_left - 1, latex, tokens),
|
81 |
-
special.right_bracket,
|
82 |
-
special.underscore,
|
83 |
-
special.left_bracket,
|
84 |
-
_generate_equation(size_left // 2, depth_left - 1, latex, tokens)
|
85 |
-
])
|
86 |
|
87 |
-
|
88 |
-
|
89 |
-
_generate_equation(post_scope_size, depth_left, latex, tokens)
|
90 |
-
])
|
91 |
-
return equation
|
92 |
|
|
|
|
|
|
|
93 |
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
:latex: -- dict with tokens to generate equation from
|
100 |
-
:size: -- approximate size of equation
|
101 |
-
:depth: -- max brackets and scope depth
|
102 |
-
"""
|
103 |
-
tokens = [token for group in ['chars', 'greek', 'functions', 'operators', 'spaces']
|
104 |
-
for token in latex[group]]
|
105 |
-
equation = _generate_equation(size, depth, latex, tokens)
|
106 |
-
return equation
|
107 |
|
|
|
108 |
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
"""
|
114 |
Generates a random tex file and corresponding image
|
115 |
-------
|
@@ -117,41 +84,47 @@ def generate_image(directory: str, latex: dict, filename: str, max_length=20, eq
|
|
117 |
:directory: -- dir where to save files
|
118 |
:latex: -- dict with parameters to generate tex
|
119 |
:filename: -- absolute filename for the generated files
|
120 |
-
:
|
121 |
-
:
|
122 |
-
:
|
123 |
-
:ghostscript: -- path to ghostscript
|
124 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
125 |
filepath = os.path.join(directory, filename)
|
126 |
-
|
127 |
-
latex = DotDict(latex)
|
128 |
-
template = string.Template(latex.template)
|
129 |
-
font, font_options = random.choice(latex.fonts)
|
130 |
-
font_option = random.choice([''] + font_options)
|
131 |
-
fontsize = random.choice(latex.fontsizes)
|
132 |
-
equation = generate_equation(latex, equation_length, depth=equation_depth)
|
133 |
-
tex = template.substitute(font=font, font_option=font_option, fontsize=fontsize, equation=equation)
|
134 |
-
|
135 |
-
with open(f"{filepath}.tex", mode='w') as file:
|
136 |
file.write(tex)
|
137 |
|
138 |
try:
|
139 |
pdflatex_process = subprocess.run(
|
140 |
-
f"{
|
141 |
stderr=subprocess.DEVNULL,
|
142 |
stdout=subprocess.DEVNULL,
|
143 |
timeout=1
|
144 |
)
|
145 |
except subprocess.TimeoutExpired:
|
146 |
-
|
147 |
return
|
148 |
|
149 |
if pdflatex_process.returncode != 0:
|
150 |
-
|
151 |
return
|
152 |
|
153 |
subprocess.run(
|
154 |
-
f"{
|
|
|
155 |
stderr=subprocess.DEVNULL,
|
156 |
stdout=subprocess.DEVNULL,
|
157 |
)
|
@@ -161,41 +134,33 @@ def _generate_image_wrapper(args):
|
|
161 |
return generate_image(*args)
|
162 |
|
163 |
|
164 |
-
def generate_data(examples_count) -> None:
|
165 |
"""
|
166 |
Clears a directory and generates a latex dataset in given directory
|
167 |
-
-------
|
168 |
-
params:
|
169 |
-
:examples_count: - how many latex - image examples to generate
|
170 |
"""
|
171 |
|
172 |
-
filenames = set(f"{i:0{len(str(examples_count - 1))}d}" for i in range(examples_count))
|
173 |
directory = os.path.abspath(DATA_DIR)
|
174 |
-
|
175 |
-
|
176 |
-
latex = json.load(file)
|
177 |
|
178 |
-
|
179 |
-
|
180 |
|
181 |
-
|
182 |
-
|
183 |
-
os.path.abspath(file) for file in os.listdir(os.getcwd()))
|
184 |
|
185 |
-
files_before = _get_current_relevant_files()
|
186 |
while filenames:
|
187 |
with Pool() as pool:
|
188 |
list(tqdm.tqdm(
|
189 |
-
pool.imap(_generate_image_wrapper,
|
|
|
|
|
190 |
"Generating images",
|
191 |
total=len(filenames)
|
192 |
))
|
193 |
-
|
194 |
-
|
195 |
-
|
196 |
-
|
197 |
-
|
198 |
-
|
199 |
-
not file.endswith('.png') and not file.endswith('.tex'))
|
200 |
-
if files_to_delete:
|
201 |
-
subprocess.run(['rm'] + files_to_delete)
|
|
|
7 |
import random
|
8 |
import tqdm
|
9 |
|
10 |
+
DATA_DIR = "data"
|
11 |
+
LATEX_PATH = "resources/latex.json"
|
12 |
+
PDFLATEX = "/external2/dkkoshman/venv/texlive/2022/bin/x86_64-linux/pdflatex"
|
13 |
+
GHOSTSCRIPT = "/external2/dkkoshman/venv/local/gs/bin/gs"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
14 |
|
|
|
|
|
|
|
|
|
15 |
|
16 |
+
def generate_equation(latex, size, max_depth):
|
17 |
+
"""
|
18 |
+
Generates a random latex equation
|
19 |
+
-------
|
20 |
+
params:
|
21 |
+
:latex: -- dict with tokens to generate equation from
|
22 |
+
:size: -- approximate size of equation
|
23 |
+
:max_depth: -- max brackets and scope depth
|
24 |
+
"""
|
25 |
+
|
26 |
+
tokens, pairs, scopes = latex["tokens"], latex["pairs"], latex["scope_manipulators"]
|
27 |
+
|
28 |
+
def _generate_equation_recursive(size_left=size, depth_used=0):
|
29 |
+
if size_left <= 0:
|
30 |
+
return ""
|
31 |
+
|
32 |
+
equation = ""
|
33 |
+
group, = random.choices([tokens, pairs, scopes],
|
34 |
+
weights=[max_depth + 1, max_depth > depth_used, max_depth > depth_used])
|
35 |
+
|
36 |
+
if group is tokens:
|
37 |
+
equation += " ".join([
|
38 |
+
random.choice(tokens),
|
39 |
+
_generate_equation_recursive(size_left - 1, depth_used)
|
40 |
])
|
41 |
+
return equation
|
42 |
+
|
43 |
+
post_scope_size = round(abs(random.gauss(0, size_left / 2)))
|
44 |
+
size_left -= post_scope_size + 1
|
45 |
+
|
46 |
+
if group is pairs:
|
47 |
+
pair = random.choice(pairs)
|
48 |
+
equation += " ".join([
|
49 |
+
pair[0],
|
50 |
+
_generate_equation_recursive(size_left, depth_used + 1),
|
51 |
+
pair[1],
|
52 |
+
_generate_equation_recursive(post_scope_size, depth_used)
|
53 |
])
|
54 |
+
return equation
|
55 |
|
56 |
+
elif group is scopes:
|
57 |
+
scope_type, scope_group = random.choice(list(scopes.items()))
|
58 |
+
scope_operator = random.choice(scope_group)
|
59 |
+
equation += scope_operator
|
|
|
|
|
|
|
|
|
|
|
|
|
60 |
|
61 |
+
if scope_type == "single":
|
62 |
+
equation += "{ " + _generate_equation_recursive(size_left, depth_used + 1)
|
|
|
|
|
|
|
63 |
|
64 |
+
elif scope_type == "double_no_delimiters":
|
65 |
+
equation += "{ " + _generate_equation_recursive(size_left // 2, depth_used + 1) + " } { " + \
|
66 |
+
_generate_equation_recursive(size_left // 2, depth_used + 1)
|
67 |
|
68 |
+
elif scope_type == "double_with_delimiters":
|
69 |
+
equation += "^ { " + _generate_equation_recursive(size_left // 2, depth_used + 1) + " } _ { " + \
|
70 |
+
_generate_equation_recursive(size_left // 2, depth_used + 1)
|
71 |
+
|
72 |
+
equation += _generate_equation_recursive(post_scope_size, depth_used) + " }"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
73 |
|
74 |
+
return equation
|
75 |
|
76 |
+
return _generate_equation_recursive()
|
77 |
+
|
78 |
+
|
79 |
+
def generate_image(directory, latex, filename, max_depth, equation_length, distribution_fraction):
|
80 |
"""
|
81 |
Generates a random tex file and corresponding image
|
82 |
-------
|
|
|
84 |
:directory: -- dir where to save files
|
85 |
:latex: -- dict with parameters to generate tex
|
86 |
:filename: -- absolute filename for the generated files
|
87 |
+
:max_depth: -- max nested level of tex scopes
|
88 |
+
:equation_length: -- max length of equation
|
89 |
+
:distribution_fraction: -- fraction of whole available tex tokens to use
|
|
|
90 |
"""
|
91 |
+
fracture = lambda sequence: sequence[:max(1, int(len(sequence) * distribution_fraction))]
|
92 |
+
for group in ["tokens", "pairs", "fonts", "font_sizes"]:
|
93 |
+
latex[group] = fracture(latex[group])
|
94 |
+
for key, value in list(latex["scope_manipulators"].items()):
|
95 |
+
latex["scope_manipulators"]['key'] = fracture(value)
|
96 |
+
|
97 |
+
size = random.randint((equation_length + 1) // 2, equation_length)
|
98 |
+
equation = generate_equation(latex, size=size, max_depth=max_depth)
|
99 |
+
|
100 |
+
font, font_options = random.choice(latex["fonts"])
|
101 |
+
font_option = random.choice([""] + font_options)
|
102 |
+
font_size = random.choice(latex["font_sizes"])
|
103 |
+
template = string.Template(latex["template"])
|
104 |
+
tex = template.substitute(font=font, font_option=font_option, fontsize=font_size, equation=equation)
|
105 |
+
|
106 |
filepath = os.path.join(directory, filename)
|
107 |
+
with open(f"{filepath}.tex", mode="w") as file:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
108 |
file.write(tex)
|
109 |
|
110 |
try:
|
111 |
pdflatex_process = subprocess.run(
|
112 |
+
f"{PDFLATEX} -output-directory={directory} {filepath}.tex".split(),
|
113 |
stderr=subprocess.DEVNULL,
|
114 |
stdout=subprocess.DEVNULL,
|
115 |
timeout=1
|
116 |
)
|
117 |
except subprocess.TimeoutExpired:
|
118 |
+
os.remove(filepath + ".tex")
|
119 |
return
|
120 |
|
121 |
if pdflatex_process.returncode != 0:
|
122 |
+
os.remove(filepath + ".tex")
|
123 |
return
|
124 |
|
125 |
subprocess.run(
|
126 |
+
f"{GHOSTSCRIPT} -sDEVICE=png16m -dTextAlphaBits=4 -r200 -dSAFER -dBATCH -dNOPAUSE"
|
127 |
+
f" -o {filepath}.png {filepath}.pdf".split(),
|
128 |
stderr=subprocess.DEVNULL,
|
129 |
stdout=subprocess.DEVNULL,
|
130 |
)
|
|
|
134 |
return generate_image(*args)
|
135 |
|
136 |
|
137 |
+
def generate_data(examples_count, max_depth, equation_length, distribution_fraction) -> None:
|
138 |
"""
|
139 |
Clears a directory and generates a latex dataset in given directory
|
|
|
|
|
|
|
140 |
"""
|
141 |
|
|
|
142 |
directory = os.path.abspath(DATA_DIR)
|
143 |
+
shutil.rmtree(DATA_DIR)
|
144 |
+
os.mkdir(DATA_DIR)
|
|
|
145 |
|
146 |
+
with open(LATEX_PATH) as file:
|
147 |
+
latex = json.load(file)
|
148 |
|
149 |
+
filenames = set(f"{i:0{len(str(examples_count - 1))}d}" for i in range(examples_count))
|
150 |
+
files_before = set(os.listdir())
|
|
|
151 |
|
|
|
152 |
while filenames:
|
153 |
with Pool() as pool:
|
154 |
list(tqdm.tqdm(
|
155 |
+
pool.imap(_generate_image_wrapper,
|
156 |
+
((directory, latex, filename, max_depth, equation_length, distribution_fraction) for filename
|
157 |
+
in sorted(filenames))),
|
158 |
"Generating images",
|
159 |
total=len(filenames)
|
160 |
))
|
161 |
+
filenames -= set(
|
162 |
+
os.path.splitext(filename)[0] for filename in os.listdir(directory) if filename.endswith(".png"))
|
163 |
+
|
164 |
+
for file in set(i.path for i in os.scandir(DATA_DIR)) | set(os.listdir()) - files_before:
|
165 |
+
if any(file.endswith(ext) for ext in [".aux", ".pdf", ".log", ".sh"]):
|
166 |
+
os.remove(file)
|
|
|
|
|
|
data_preprocessing.py
CHANGED
@@ -12,10 +12,7 @@ import tqdm
|
|
12 |
import random
|
13 |
import re
|
14 |
|
15 |
-
|
16 |
-
IMAGE_WIDTH = 1024
|
17 |
-
IMAGE_HEIGHT = 128
|
18 |
-
BATCH_SIZE = 16
|
19 |
NUM_WORKERS = 4
|
20 |
PERSISTENT_WORKERS = True # whether to shut down workers at the end of epoch
|
21 |
PIN_MEMORY = False # probably causes cuda oom error if True
|
@@ -60,22 +57,6 @@ class TexImageDataset(Dataset):
|
|
60 |
return {"image": image, "tex": tex}
|
61 |
|
62 |
|
63 |
-
def generate_normalize_transform(dataset: TexImageDataset):
|
64 |
-
"""Returns a normalize layer with mean and std computed after iterating over dataset"""
|
65 |
-
|
66 |
-
mean = 0
|
67 |
-
std = 0
|
68 |
-
for item in tqdm.tqdm(dataset, "Computing dataset image stats"):
|
69 |
-
image = item['image']
|
70 |
-
mean += image.mean()
|
71 |
-
std += image.std()
|
72 |
-
|
73 |
-
mean /= len(dataset)
|
74 |
-
std /= len(dataset)
|
75 |
-
normalize = T.Normalize(mean, std)
|
76 |
-
return normalize
|
77 |
-
|
78 |
-
|
79 |
class BatchCollator(object):
|
80 |
"""Image, tex batch collator"""
|
81 |
|
@@ -94,39 +75,30 @@ class BatchCollator(object):
|
|
94 |
return {'images': images, 'tex_ids': tex_ids, 'tex_attention_masks': attention_masks}
|
95 |
|
96 |
|
97 |
-
class StandardizeImageTransform(object):
|
98 |
-
"""Pad and crop image to a given size, grayscale and invert"""
|
99 |
-
|
100 |
-
def __init__(self, width=IMAGE_WIDTH, height=IMAGE_HEIGHT):
|
101 |
-
self.standardize = T.Compose((
|
102 |
-
T.Resize(height),
|
103 |
-
T.Grayscale(),
|
104 |
-
T.functional.invert,
|
105 |
-
T.CenterCrop((height, width)),
|
106 |
-
T.ConvertImageDtype(torch.float32)
|
107 |
-
))
|
108 |
-
|
109 |
-
def __call__(self, image):
|
110 |
-
image = self.standardize(image)
|
111 |
-
return image
|
112 |
-
|
113 |
-
|
114 |
class RandomizeImageTransform(object):
|
115 |
"""Standardize image and randomly augment"""
|
116 |
|
117 |
-
def __init__(self, width
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
130 |
|
131 |
def __call__(self, image):
|
132 |
image = self.transform(image)
|
@@ -148,7 +120,7 @@ class ExtractEquationFromTexTransform(object):
|
|
148 |
return equation
|
149 |
|
150 |
|
151 |
-
def generate_tex_tokenizer(dataloader
|
152 |
"""Returns a tokenizer trained on texs from given dataset"""
|
153 |
|
154 |
texs = list(tqdm.tqdm((batch['tex'] for batch in dataloader), "Training tokenizer", total=len(dataloader)))
|
@@ -156,7 +128,6 @@ def generate_tex_tokenizer(dataloader, vocab_size):
|
|
156 |
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
|
157 |
tokenizer = tokenizers.Tokenizer(tokenizers.models.BPE(unk_token="[UNK]"))
|
158 |
tokenizer_trainer = tokenizers.trainers.BpeTrainer(
|
159 |
-
vocab_size=vocab_size,
|
160 |
special_tokens=["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"]
|
161 |
)
|
162 |
tokenizer.pre_tokenizer = tokenizers.pre_tokenizers.Whitespace()
|
@@ -174,34 +145,30 @@ def generate_tex_tokenizer(dataloader, vocab_size):
|
|
174 |
|
175 |
|
176 |
class LatexImageDataModule(pl.LightningDataModule):
|
177 |
-
def __init__(self, batch_size
|
178 |
super().__init__()
|
179 |
-
|
180 |
-
|
|
|
|
|
|
|
|
|
181 |
|
182 |
-
self.train_dataset = TexImageDataset(
|
183 |
-
root_dir=DATA_DIR,
|
184 |
-
image_transform=RandomizeImageTransform(),
|
185 |
-
tex_transform=ExtractEquationFromTexTransform()
|
186 |
-
)
|
187 |
-
self.val_dataset = TexImageDataset(
|
188 |
-
root_dir=DATA_DIR,
|
189 |
-
image_transform=RandomizeImageTransform(),
|
190 |
-
tex_transform=ExtractEquationFromTexTransform()
|
191 |
-
)
|
192 |
-
self.test_dataset = TexImageDataset(
|
193 |
-
root_dir=DATA_DIR,
|
194 |
-
image_transform=RandomizeImageTransform(),
|
195 |
-
tex_transform=ExtractEquationFromTexTransform()
|
196 |
-
)
|
197 |
train_indices, val_indices, test_indices = self.train_val_test_split(len(self.train_dataset))
|
198 |
self.train_dataset = torch.utils.data.Subset(self.train_dataset, train_indices)
|
199 |
self.val_dataset = torch.utils.data.Subset(self.val_dataset, val_indices)
|
200 |
self.test_dataset = torch.utils.data.Subset(self.test_dataset, test_indices)
|
201 |
|
202 |
-
self.
|
203 |
-
|
204 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
205 |
self.collate_fn = BatchCollator(self.tex_tokenizer)
|
206 |
|
207 |
@staticmethod
|
@@ -213,8 +180,9 @@ class LatexImageDataModule(pl.LightningDataModule):
|
|
213 |
return indices[:train_split], indices[train_split: val_split], indices[val_split:]
|
214 |
|
215 |
def train_dataloader(self):
|
216 |
-
return DataLoader(self.train_dataset, batch_size=self.batch_size,
|
217 |
-
pin_memory=PIN_MEMORY, num_workers=NUM_WORKERS, persistent_workers=PERSISTENT_WORKERS
|
|
|
218 |
|
219 |
def val_dataloader(self):
|
220 |
return DataLoader(self.val_dataset, batch_size=self.batch_size, collate_fn=self.collate_fn,
|
|
|
12 |
import random
|
13 |
import re
|
14 |
|
15 |
+
TOKENIZER_PATH = "resources/tokenizer.pt"
|
|
|
|
|
|
|
16 |
NUM_WORKERS = 4
|
17 |
PERSISTENT_WORKERS = True # whether to shut down workers at the end of epoch
|
18 |
PIN_MEMORY = False # probably causes cuda oom error if True
|
|
|
57 |
return {"image": image, "tex": tex}
|
58 |
|
59 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
60 |
class BatchCollator(object):
|
61 |
"""Image, tex batch collator"""
|
62 |
|
|
|
75 |
return {'images': images, 'tex_ids': tex_ids, 'tex_attention_masks': attention_masks}
|
76 |
|
77 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
78 |
class RandomizeImageTransform(object):
|
79 |
"""Standardize image and randomly augment"""
|
80 |
|
81 |
+
def __init__(self, width, height, random_magnitude):
|
82 |
+
if random_magnitude > 0:
|
83 |
+
self.transform = T.Compose((
|
84 |
+
T.ColorJitter(brightness=random_magnitude / 10, contrast=random_magnitude / 10,
|
85 |
+
saturation=random_magnitude / 10, hue=min(0.5, random_magnitude / 10)),
|
86 |
+
T.Resize(height),
|
87 |
+
T.Grayscale(),
|
88 |
+
T.functional.invert,
|
89 |
+
T.CenterCrop((height, width)),
|
90 |
+
torch.Tensor.contiguous,
|
91 |
+
T.RandAugment(magnitude=random_magnitude),
|
92 |
+
T.ConvertImageDtype(torch.float32)
|
93 |
+
))
|
94 |
+
else:
|
95 |
+
self.transform = T.Compose((
|
96 |
+
T.Resize(height),
|
97 |
+
T.Grayscale(),
|
98 |
+
T.functional.invert,
|
99 |
+
T.CenterCrop((height, width)),
|
100 |
+
T.ConvertImageDtype(torch.float32)
|
101 |
+
))
|
102 |
|
103 |
def __call__(self, image):
|
104 |
image = self.transform(image)
|
|
|
120 |
return equation
|
121 |
|
122 |
|
123 |
+
def generate_tex_tokenizer(dataloader):
|
124 |
"""Returns a tokenizer trained on texs from given dataset"""
|
125 |
|
126 |
texs = list(tqdm.tqdm((batch['tex'] for batch in dataloader), "Training tokenizer", total=len(dataloader)))
|
|
|
128 |
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
|
129 |
tokenizer = tokenizers.Tokenizer(tokenizers.models.BPE(unk_token="[UNK]"))
|
130 |
tokenizer_trainer = tokenizers.trainers.BpeTrainer(
|
|
|
131 |
special_tokens=["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"]
|
132 |
)
|
133 |
tokenizer.pre_tokenizer = tokenizers.pre_tokenizers.Whitespace()
|
|
|
145 |
|
146 |
|
147 |
class LatexImageDataModule(pl.LightningDataModule):
|
148 |
+
def __init__(self, image_width, image_height, batch_size, random_magnitude):
|
149 |
super().__init__()
|
150 |
+
image_transform = RandomizeImageTransform(image_width, image_height, random_magnitude)
|
151 |
+
tex_transform = ExtractEquationFromTexTransform()
|
152 |
+
|
153 |
+
self.train_dataset = TexImageDataset(DATA_DIR, image_transform, tex_transform)
|
154 |
+
self.val_dataset = TexImageDataset(DATA_DIR, image_transform, tex_transform)
|
155 |
+
self.test_dataset = TexImageDataset(DATA_DIR, image_transform, tex_transform)
|
156 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
157 |
train_indices, val_indices, test_indices = self.train_val_test_split(len(self.train_dataset))
|
158 |
self.train_dataset = torch.utils.data.Subset(self.train_dataset, train_indices)
|
159 |
self.val_dataset = torch.utils.data.Subset(self.val_dataset, val_indices)
|
160 |
self.test_dataset = torch.utils.data.Subset(self.test_dataset, test_indices)
|
161 |
|
162 |
+
self.batch_size = batch_size
|
163 |
+
self.save_hyperparameters()
|
164 |
+
|
165 |
+
def prepare_data(self):
|
166 |
+
tokenizer = generate_tex_tokenizer(DataLoader(self.train_dataset, batch_size=32, num_workers=16))
|
167 |
+
print(f"Vocabulary size: {tokenizer.get_vocab_size()}")
|
168 |
+
torch.save(tokenizer, TOKENIZER_PATH)
|
169 |
+
|
170 |
+
def setup(self, stage=None):
|
171 |
+
self.tex_tokenizer = torch.load(TOKENIZER_PATH)
|
172 |
self.collate_fn = BatchCollator(self.tex_tokenizer)
|
173 |
|
174 |
@staticmethod
|
|
|
180 |
return indices[:train_split], indices[train_split: val_split], indices[val_split:]
|
181 |
|
182 |
def train_dataloader(self):
|
183 |
+
return DataLoader(self.train_dataset, batch_size=self.batch_size, collate_fn=self.collate_fn,
|
184 |
+
pin_memory=PIN_MEMORY, num_workers=NUM_WORKERS, persistent_workers=PERSISTENT_WORKERS,
|
185 |
+
shuffle=True)
|
186 |
|
187 |
def val_dataloader(self):
|
188 |
return DataLoader(self.val_dataset, batch_size=self.batch_size, collate_fn=self.collate_fn,
|
generate.py
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from data_generator import generate_data
|
2 |
+
|
3 |
+
import argparse
|
4 |
+
|
5 |
+
|
6 |
+
def parse_args():
|
7 |
+
parser = argparse.ArgumentParser(description="Clear old dataset and generate new one")
|
8 |
+
parser.add_argument("size", help="size of new dataset", type=int)
|
9 |
+
parser.add_argument("depth", help="max_depth scope depth of generated equation, no less than 1", type=int)
|
10 |
+
parser.add_argument("length", help="length of equation will be in range length/2..length", type=int)
|
11 |
+
parser.add_argument("fraction", help="fraction of tex vocab to sample tokens from, float in range 0..1", type=float)
|
12 |
+
args = parser.parse_args()
|
13 |
+
return args
|
14 |
+
|
15 |
+
|
16 |
+
def main():
|
17 |
+
args = parse_args()
|
18 |
+
generate_data(examples_count=args.size, max_depth=args.depth, equation_length=args.length,
|
19 |
+
distribution_fraction=args.fraction)
|
20 |
+
|
21 |
+
|
22 |
+
if __name__ == "__main__":
|
23 |
+
main()
|
model.py
CHANGED
@@ -11,10 +11,6 @@ class AddPositionalEncoding(nn.Module):
|
|
11 |
def __init__(self, d_model, max_sequence_len=5000):
|
12 |
super().__init__()
|
13 |
|
14 |
-
# pos - position in sequence, i - index of element embedding
|
15 |
-
# PE(pos, 2i) = sin(pos / 10000**(2i / d_model)) = sin(pos * e**(2i * (-log(10000))/d_model))
|
16 |
-
# PE(pos, 2i+1) = cos(pos / 10000**(2i / d_model)) = cos(pos * e**(2i * (-log(10000))/d_model))
|
17 |
-
|
18 |
positions = torch.arange(max_sequence_len)
|
19 |
even_embedding_indices = torch.arange(0, d_model, 2)
|
20 |
|
@@ -103,7 +99,7 @@ class Transformer(pl.LightningModule):
|
|
103 |
def __init__(self,
|
104 |
num_encoder_layers: int,
|
105 |
num_decoder_layers: int,
|
106 |
-
|
107 |
nhead: int,
|
108 |
image_width: int,
|
109 |
image_height: int,
|
@@ -114,7 +110,7 @@ class Transformer(pl.LightningModule):
|
|
114 |
):
|
115 |
super().__init__()
|
116 |
|
117 |
-
self.transformer = nn.Transformer(d_model=
|
118 |
nhead=nhead,
|
119 |
num_encoder_layers=num_encoder_layers,
|
120 |
num_decoder_layers=num_decoder_layers,
|
@@ -125,10 +121,10 @@ class Transformer(pl.LightningModule):
|
|
125 |
if p.dim() > 1:
|
126 |
nn.init.xavier_uniform_(p)
|
127 |
|
128 |
-
self.d_model =
|
129 |
-
self.src_tok_emb = ImageEmbedding(
|
130 |
-
self.tgt_tok_emb = TexEmbedding(
|
131 |
-
self.generator = nn.Linear(
|
132 |
# Make embedding and generator share weight because they do the same thing
|
133 |
self.tgt_tok_emb.embedding.weight = self.generator.weight
|
134 |
self.loss_fn = torch.nn.CrossEntropyLoss(ignore_index=pad_idx, label_smoothing=.1)
|
|
|
11 |
def __init__(self, d_model, max_sequence_len=5000):
|
12 |
super().__init__()
|
13 |
|
|
|
|
|
|
|
|
|
14 |
positions = torch.arange(max_sequence_len)
|
15 |
even_embedding_indices = torch.arange(0, d_model, 2)
|
16 |
|
|
|
99 |
def __init__(self,
|
100 |
num_encoder_layers: int,
|
101 |
num_decoder_layers: int,
|
102 |
+
d_model: int,
|
103 |
nhead: int,
|
104 |
image_width: int,
|
105 |
image_height: int,
|
|
|
110 |
):
|
111 |
super().__init__()
|
112 |
|
113 |
+
self.transformer = nn.Transformer(d_model=d_model,
|
114 |
nhead=nhead,
|
115 |
num_encoder_layers=num_encoder_layers,
|
116 |
num_decoder_layers=num_decoder_layers,
|
|
|
121 |
if p.dim() > 1:
|
122 |
nn.init.xavier_uniform_(p)
|
123 |
|
124 |
+
self.d_model = d_model
|
125 |
+
self.src_tok_emb = ImageEmbedding(d_model, image_width, image_height, patch_size=16, dropout=dropout)
|
126 |
+
self.tgt_tok_emb = TexEmbedding(d_model, tgt_vocab_size, dropout=dropout)
|
127 |
+
self.generator = nn.Linear(d_model, tgt_vocab_size)
|
128 |
# Make embedding and generator share weight because they do the same thing
|
129 |
self.tgt_tok_emb.embedding.weight = self.generator.weight
|
130 |
self.loss_fn = torch.nn.CrossEntropyLoss(ignore_index=pad_idx, label_smoothing=.1)
|
train.py
CHANGED
@@ -1,73 +1,57 @@
|
|
1 |
-
from
|
2 |
-
from data_preprocessing import LatexImageDataModule, IMAGE_WIDTH, IMAGE_HEIGHT
|
3 |
from model import Transformer
|
4 |
from utils import LogImageTexCallback
|
5 |
|
6 |
import argparse
|
|
|
7 |
from pytorch_lightning.callbacks import LearningRateMonitor
|
8 |
from pytorch_lightning.loggers import WandbLogger
|
9 |
-
from pytorch_lightning import Trainer
|
10 |
import torch
|
11 |
|
12 |
-
DATASET_PATH = "resources/dataset.pt"
|
13 |
TRAINER_DIR = "resources/pl_trainer_checkpoints"
|
14 |
-
TUNER_DIR = "resources/pl_tuner_checkpoints"
|
15 |
-
BEST_MODEL_CHECKPOINT = "best_model.ckpt"
|
16 |
|
17 |
|
|
|
|
|
|
|
|
|
18 |
def parse_args():
|
19 |
-
parser = argparse.ArgumentParser()
|
20 |
-
|
21 |
-
|
22 |
-
)
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
)
|
27 |
-
parser.add_argument(
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
"-d", "-deterministic", help="whether to seed all rngs for reproducibility, default False", default=False,
|
37 |
-
action="store_true", dest="deterministic"
|
38 |
-
)
|
39 |
-
# parser.add_argument(
|
40 |
-
# "-t", "-tune", help="whether to tune model for batch size before training, default False", default=False,
|
41 |
-
# action="store_true", dest="tune"
|
42 |
-
# )
|
43 |
|
44 |
args = parser.parse_args()
|
|
|
|
|
|
|
45 |
return args
|
46 |
|
47 |
|
48 |
-
# TODO: update python, maybe model doesnt train bc of ignore special index in CrossEntropyLoss?
|
49 |
-
# crop image, adjust brightness, lr warmup?, make tex tokens always decodable,
|
50 |
-
# take loss that doesn't punish so much for offsets, take a look at weights,
|
51 |
-
|
52 |
def main():
|
53 |
args = parse_args()
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
if args.new_dataset is not None:
|
59 |
-
generate_data(args.new_dataset)
|
60 |
-
datamodule = LatexImageDataModule()
|
61 |
-
torch.save(datamodule, DATASET_PATH)
|
62 |
-
else:
|
63 |
-
datamodule = torch.load(DATASET_PATH)
|
64 |
-
|
65 |
if args.log:
|
66 |
logger = WandbLogger(f"img2tex", log_model=True)
|
67 |
-
callbacks = [
|
68 |
-
|
69 |
-
LearningRateMonitor(logging_interval='step')
|
70 |
-
]
|
71 |
else:
|
72 |
logger = None
|
73 |
callbacks = []
|
@@ -79,24 +63,22 @@ def main():
|
|
79 |
strategy="ddp",
|
80 |
enable_progress_bar=True,
|
81 |
default_root_dir=TRAINER_DIR,
|
82 |
-
callbacks=callbacks
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
|
|
91 |
tgt_vocab_size=datamodule.tex_tokenizer.get_vocab_size(),
|
92 |
-
pad_idx=datamodule.tex_tokenizer.token_to_id("[PAD]")
|
93 |
-
dim_feedforward=512,
|
94 |
-
dropout=0.1,
|
95 |
-
)
|
96 |
|
97 |
trainer.fit(transformer, datamodule=datamodule)
|
98 |
-
trainer.test(datamodule=datamodule)
|
99 |
-
trainer.save_checkpoint(
|
100 |
|
101 |
|
102 |
if __name__ == "__main__":
|
|
|
1 |
+
from data_preprocessing import LatexImageDataModule
|
|
|
2 |
from model import Transformer
|
3 |
from utils import LogImageTexCallback
|
4 |
|
5 |
import argparse
|
6 |
+
import os
|
7 |
from pytorch_lightning.callbacks import LearningRateMonitor
|
8 |
from pytorch_lightning.loggers import WandbLogger
|
9 |
+
from pytorch_lightning import Trainer
|
10 |
import torch
|
11 |
|
|
|
12 |
TRAINER_DIR = "resources/pl_trainer_checkpoints"
|
|
|
|
|
13 |
|
14 |
|
15 |
+
# TODO: update python, maybe model doesnt train bc of ignore special index in CrossEntropyLoss?
|
16 |
+
# crop image, adjust brightness, make tex tokens always decodable,
|
17 |
+
# save only datamodule state?, ensemble last checkpoints, early stopping
|
18 |
+
|
19 |
def parse_args():
|
20 |
+
parser = argparse.ArgumentParser(allow_abbrev=True, formatter_class=argparse.RawTextHelpFormatter)
|
21 |
+
|
22 |
+
parser.add_argument("-m", "-max-epochs", help="limit the number of training epochs", type=int, dest="max_epochs")
|
23 |
+
parser.add_argument("-g", "-gpus", metavar="GPUS", type=int, choices=list(range(torch.cuda.device_count())),
|
24 |
+
help="ids of gpus to train on, if not provided, then trains on cpu", nargs="+", dest="gpus")
|
25 |
+
parser.add_argument("-l", "-log", help="whether to save logs of run to w&b logger, default False", default=False,
|
26 |
+
action="store_true", dest="log")
|
27 |
+
parser.add_argument("-width", help="width of images, default 1024", default=1024, type=int)
|
28 |
+
parser.add_argument("-height", help="height of images, default 128", default=128, type=int)
|
29 |
+
parser.add_argument("-r", "-randomize", default=5, type=int, dest="random_magnitude", choices=range(10),
|
30 |
+
help="add random augments to images of provided magnitude in range 0..9, default 5")
|
31 |
+
parser.add_argument("-b", "-batch-size", help="batch size, default 16", default=16,
|
32 |
+
type=int, dest="batch_size")
|
33 |
+
transformer_args = [("num_encoder_layers", 6), ("num_decoder_layers", 6), ("d_model", 512), ("nhead", 8),
|
34 |
+
("dim_feedforward", 2048), ("dropout", 0.1)]
|
35 |
+
parser.add_argument("-t", "-transformer-args", dest="transformer_args", nargs='+', default=[],
|
36 |
+
help="transformer init args:\n" + "\n".join(f"{k}\t{v}" for k, v in transformer_args))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
37 |
|
38 |
args = parser.parse_args()
|
39 |
+
for i, parameter in enumerate(args.transformer_args):
|
40 |
+
transformer_args[i][1] = parameter
|
41 |
+
args.transformer_args = dict(transformer_args)
|
42 |
return args
|
43 |
|
44 |
|
|
|
|
|
|
|
|
|
45 |
def main():
|
46 |
args = parse_args()
|
47 |
+
datamodule = LatexImageDataModule(image_width=args.width, image_height=args.height,
|
48 |
+
batch_size=args.batch_size, random_magnitude=args.random_magnitude)
|
49 |
+
datamodule.prepare_data()
|
50 |
+
datamodule.setup()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
51 |
if args.log:
|
52 |
logger = WandbLogger(f"img2tex", log_model=True)
|
53 |
+
callbacks = [LogImageTexCallback(logger, datamodule.tex_tokenizer),
|
54 |
+
LearningRateMonitor(logging_interval='step')]
|
|
|
|
|
55 |
else:
|
56 |
logger = None
|
57 |
callbacks = []
|
|
|
63 |
strategy="ddp",
|
64 |
enable_progress_bar=True,
|
65 |
default_root_dir=TRAINER_DIR,
|
66 |
+
callbacks=callbacks)
|
67 |
+
|
68 |
+
transformer = Transformer(num_encoder_layers=args.transformer_args['num_encoder_layers'],
|
69 |
+
num_decoder_layers=args.transformer_args['num_decoder_layers'],
|
70 |
+
d_model=args.transformer_args['d_model'],
|
71 |
+
nhead=args.transformer_args['nhead'],
|
72 |
+
dim_feedforward=args.transformer_args['dim_feedforward'],
|
73 |
+
dropout=args.transformer_args['dropout'],
|
74 |
+
image_width=datamodule.hparams['image_width'],
|
75 |
+
image_height=datamodule.hparams['image_height'],
|
76 |
tgt_vocab_size=datamodule.tex_tokenizer.get_vocab_size(),
|
77 |
+
pad_idx=datamodule.tex_tokenizer.token_to_id("[PAD]"))
|
|
|
|
|
|
|
78 |
|
79 |
trainer.fit(transformer, datamodule=datamodule)
|
80 |
+
trainer.test(datamodule=datamodule, ckpt_path='best')
|
81 |
+
trainer.save_checkpoint(os.path.join(TRAINER_DIR, "best_model.ckpt"))
|
82 |
|
83 |
|
84 |
if __name__ == "__main__":
|
utils.py
CHANGED
@@ -22,57 +22,84 @@ class LogImageTexCallback(Callback):
|
|
22 |
image = self.tensor_to_PIL(image)
|
23 |
tex_true = self.tex_tokenizer.decode(list(batch['tex_ids'][sample_id].to('cpu', torch.int)),
|
24 |
skip_special_tokens=True)
|
25 |
-
self.logger.log_image(key="samples", images=[image],
|
26 |
-
caption=[f"True: {tex_true}\nPredicted: {tex_predicted}\nIds: {tex_ids}"])
|
27 |
|
28 |
-
# if args.new_dataset:
|
29 |
-
# datamodule.batch_size = 1
|
30 |
-
# transformer_for_tuning = TransformerTuner(**transformer.hparams).cuda()
|
31 |
-
# tuner = Trainer(accelerator="gpu" if args.gpus else "cpu",
|
32 |
-
# gpus=args.gpus,
|
33 |
-
# strategy=TRAINER_STRATEGY,
|
34 |
-
# enable_progress_bar=True,
|
35 |
-
# enable_checkpointing=False,
|
36 |
-
# auto_scale_batch_size=True,
|
37 |
-
# num_sanity_val_steps=0,
|
38 |
-
# logger=False
|
39 |
-
# )
|
40 |
-
# tuner.tune(transformer_for_tuning, datamodule=datamodule)
|
41 |
-
# torch.save(datamodule, DATASET_PATH)
|
42 |
-
|
43 |
-
class _TransformerTuner(Transformer):
|
44 |
-
"""
|
45 |
-
When using trainer.tune, batches from dataloader get passed directly to forward,
|
46 |
-
so this subclass takes care of that
|
47 |
-
"""
|
48 |
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
tgt_output = tgt[:, 1:]
|
54 |
-
src_mask = None
|
55 |
-
tgt_mask = self.transformer.generate_square_subsequent_mask(tgt_input.shape[1]).to(self.device,
|
56 |
-
torch.ByteTensor.dtype)
|
57 |
-
memory_mask = None
|
58 |
-
src_padding_mask = None
|
59 |
-
tgt_padding_mask = batch['tex_attention_masks'][:, :-1]
|
60 |
-
tgt_padding_mask = tgt_padding_mask.masked_fill(
|
61 |
-
tgt_padding_mask == 0, float('-inf')
|
62 |
-
).masked_fill(
|
63 |
-
tgt_padding_mask == 1, 0
|
64 |
-
)
|
65 |
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
73 |
|
74 |
-
def validation_step(self, batch, batch_idx):
|
75 |
-
return self(batch, batch_idx)
|
76 |
|
77 |
@torch.inference_mode()
|
78 |
def decode(transformer, tex_tokenizer, image):
|
@@ -87,4 +114,4 @@ def decode(transformer, tex_tokenizer, image):
|
|
87 |
next_id = outs[0, :, -1].argmax().item()
|
88 |
tex_ids.append(next_id)
|
89 |
tex = tex_tokenizer.decode(tex_ids, skip_special_tokens=True)
|
90 |
-
return tex, tex_ids
|
|
|
22 |
image = self.tensor_to_PIL(image)
|
23 |
tex_true = self.tex_tokenizer.decode(list(batch['tex_ids'][sample_id].to('cpu', torch.int)),
|
24 |
skip_special_tokens=True)
|
25 |
+
self.logger.log_image(key="samples", images=[image], caption=[f"True: {tex_true}\nPredicted: {tex_predicted}"])
|
|
|
26 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
27 |
|
28 |
+
# parser.add_argument(
|
29 |
+
# "-t", "-tune", help="whether to tune model for batch size before training, default False", default=False,
|
30 |
+
# action="store_true", dest="tune"
|
31 |
+
# )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
32 |
|
33 |
+
# if args.new_dataset:
|
34 |
+
# datamodule.batch_size = 1
|
35 |
+
# transformer_for_tuning = TransformerTuner(**transformer.hparams).cuda()
|
36 |
+
# tuner = Trainer(accelerator="gpu" if args.gpus else "cpu",
|
37 |
+
# gpus=args.gpus,
|
38 |
+
# strategy=TRAINER_STRATEGY,
|
39 |
+
# enable_progress_bar=True,
|
40 |
+
# enable_checkpointing=False,
|
41 |
+
# auto_scale_batch_size=True,
|
42 |
+
# num_sanity_val_steps=0,
|
43 |
+
# logger=False
|
44 |
+
# )
|
45 |
+
# tuner.tune(transformer_for_tuning, datamodule=datamodule)
|
46 |
+
# torch.save(datamodule, DATASET_PATH)
|
47 |
+
# TUNER_DIR = "resources/pl_tuner_checkpoints"
|
48 |
+
# from pytorch_lightning import seed_everything
|
49 |
+
# parser.add_argument(
|
50 |
+
# "-d", "-deterministic", help="whether to seed all rngs for reproducibility, default False", default=False,
|
51 |
+
# action="store_true", dest="deterministic"
|
52 |
+
# )
|
53 |
+
# if args.deterministic:
|
54 |
+
# seed_everything(42, workers=True)
|
55 |
+
# def generate_normalize_transform(dataset: TexImageDataset):
|
56 |
+
# """Returns a normalize layer with mean and std computed after iterating over dataset"""
|
57 |
+
#
|
58 |
+
# mean = 0
|
59 |
+
# std = 0
|
60 |
+
# for item in tqdm.tqdm(dataset, "Computing dataset image stats"):
|
61 |
+
# image = item['image']
|
62 |
+
# mean += image.mean()
|
63 |
+
# std += image.std()
|
64 |
+
#
|
65 |
+
# mean /= len(dataset)
|
66 |
+
# std /= len(dataset)
|
67 |
+
# normalize = T.Normalize(mean, std)
|
68 |
+
# return normalize
|
69 |
+
# class _TransformerTuner(Transformer):
|
70 |
+
# """
|
71 |
+
# When using trainer.tune, batches from dataloader get passed directly to forward,
|
72 |
+
# so this subclass takes care of that
|
73 |
+
# """
|
74 |
+
#
|
75 |
+
# def forward(self, batch, batch_idx):
|
76 |
+
# src = batch['images']
|
77 |
+
# tgt = batch['tex_ids']
|
78 |
+
# tgt_input = tgt[:, :-1]
|
79 |
+
# tgt_output = tgt[:, 1:]
|
80 |
+
# src_mask = None
|
81 |
+
# tgt_mask = self.transformer.generate_square_subsequent_mask(tgt_input.shape[1]).to(self.device,
|
82 |
+
# torch.ByteTensor.dtype)
|
83 |
+
# memory_mask = None
|
84 |
+
# src_padding_mask = None
|
85 |
+
# tgt_padding_mask = batch['tex_attention_masks'][:, :-1]
|
86 |
+
# tgt_padding_mask = tgt_padding_mask.masked_fill(
|
87 |
+
# tgt_padding_mask == 0, float('-inf')
|
88 |
+
# ).masked_fill(
|
89 |
+
# tgt_padding_mask == 1, 0
|
90 |
+
# )
|
91 |
+
#
|
92 |
+
# src = self.src_tok_emb(src)
|
93 |
+
# tgt_input = self.tgt_tok_emb(tgt_input)
|
94 |
+
# outs = self.transformer(src, tgt_input, src_mask, tgt_mask, memory_mask, src_padding_mask, tgt_padding_mask)
|
95 |
+
# outs = self.generator(outs)
|
96 |
+
#
|
97 |
+
# loss = self.loss_fn(einops.rearrange(outs, 'b n prob -> b prob n'), tgt_output.long())
|
98 |
+
# return loss
|
99 |
+
#
|
100 |
+
# def validation_step(self, batch, batch_idx):
|
101 |
+
# return self(batch, batch_idx)
|
102 |
|
|
|
|
|
103 |
|
104 |
@torch.inference_mode()
|
105 |
def decode(transformer, tex_tokenizer, image):
|
|
|
114 |
next_id = outs[0, :, -1].argmax().item()
|
115 |
tex_ids.append(next_id)
|
116 |
tex = tex_tokenizer.decode(tex_ids, skip_special_tokens=True)
|
117 |
+
return tex, tex_ids
|