ML2TransformerApp / data_generator.py
dkoshman
data_preprocessing, base train script
6e82d4a
raw
history blame
No virus
6.23 kB
import json
import os
import random
import string
import subprocess
from multiprocessing import Pool
from typing import Iterable
class DotDict(dict):
    """
    dot.notation access to dictionary attributes
    -------
    Attribute reads map to `dict.get`, so a missing key yields None instead of
    raising AttributeError; attribute writes and deletes map to item assignment
    and item deletion (deleting a missing key raises KeyError).
    Attribute lookup only falls back to the items when normal lookup fails, so
    keys that clash with dict method names (e.g. 'items') need `d['items']`.
    Every dict value present at construction time is wrapped in a DotDict, so
    chained access such as `d.a.b` works. Dicts stored inside lists or other
    containers, and dicts assigned after construction, are left as they are.
    """
    __getattr__ = dict.get
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Wrap nested dicts no matter how the items were supplied (mapping,
        # iterable of pairs or keyword arguments). Iterate over a snapshot so
        # the dict is never mutated while its live view is being walked.
        for key, value in list(self.items()):
            if isinstance(value, dict):
                self[key] = DotDict(value)
def _generate_equation(size_left, depth_left, latex, tokens):
if size_left <= 0:
return ""
equation = ""
pairs, scopes, special = latex.pairs, latex.scopes, latex.special
weights = [3, depth_left > 0, depth_left > 0]
group, = random.choices([tokens, pairs, scopes], weights=weights)
if group is tokens:
equation += ' '.join([
random.choice(tokens),
_generate_equation(size_left - 1, depth_left, latex, tokens)
])
return equation
post_scope_size = round(abs(random.gauss(0, size_left / 2)))
size_left -= post_scope_size + 1
if group is pairs:
pair = random.choice(pairs)
equation += ' '.join([
pair[0],
_generate_equation(size_left, depth_left - 1, latex, tokens),
pair[1],
_generate_equation(post_scope_size, depth_left, latex, tokens)
])
return equation
elif group is scopes:
scope_type, scope_group = random.choice(list(scopes.items()))
scope_operator = random.choice(scope_group)
equation += scope_operator
if scope_type == 'single':
equation += ' '.join([
special.left_bracket,
_generate_equation(size_left, depth_left - 1, latex, tokens)
])
elif scope_type == 'double_no_delimiters':
equation += ' '.join([
special.left_bracket,
_generate_equation(size_left // 2, depth_left - 1, latex, tokens),
special.right_bracket + special.left_bracket,
_generate_equation(size_left // 2, depth_left - 1, latex, tokens)
])
elif scope_type == 'double_with_delimiters':
equation += ' '.join([
special.caret,
special.left_bracket,
_generate_equation(size_left // 2, depth_left - 1, latex, tokens),
special.right_bracket,
special.underscore,
special.left_bracket,
_generate_equation(size_left // 2, depth_left - 1, latex, tokens)
])
equation += ' '.join([
special.right_bracket,
_generate_equation(post_scope_size, depth_left, latex, tokens)
])
return equation
def generate_equation(latex: DotDict, size, depth=3):
    """
    Builds one random latex equation
    -------
    params:
    :latex: -- dict with the token groups 'chars', 'greek', 'functions', 'operators'
               and 'spaces', plus the entries `_generate_equation` reads
    :size: -- approximate size of equation
    :depth: -- max brackets and scope depth
    """
    # Flatten the plain-token groups, in this fixed order, into one vocabulary.
    vocabulary = []
    for group_name in ('chars', 'greek', 'functions', 'operators', 'spaces'):
        vocabulary.extend(latex[group_name])
    return _generate_equation(size, depth, latex, vocabulary)
def _remove_files(directory, names):
    """Deletes the given file names from directory, ignoring files that are already gone"""
    for name in names:
        try:
            os.remove(os.path.join(directory, name))
        except FileNotFoundError:
            # Best-effort cleanup: a missing file is already in the desired state.
            pass


def generate_image(directory: str, latex_path: str, filename: str, max_length=20):
    """
    Generates a random tex file and corresponding image
    -------
    params:
    :directory: -- dir where to save files
    :latex_path: -- path to latex json
    :filename: -- name for the generated files
    :max_length: -- max size of equation
    -------
    On success `<filename>.tex` and `<filename>.png` are left in `directory`.
    If pdflatex fails, the files this call created are removed, pdflatex's
    stderr and the tex source are printed, and the function returns None.
    raises AssertionError if ghostscript fails to render the png.
    """
    # TODO ARGPARSE, path parse
    # os.path.join works whether or not `directory` ends with a separator.
    filepath = os.path.join(directory, filename)
    with open(latex_path, encoding='utf-8') as file:
        latex = DotDict(json.load(file))

    template = string.Template(latex.template)
    font, font_options = random.choice(latex.fonts)
    font_option = random.choice([''] + font_options)
    fontsize = random.choice(latex.fontsizes)
    equation = generate_equation(latex, max_length)
    tex = template.substitute(font=font, font_option=font_option, fontsize=fontsize, equation=equation)

    files_before = set(os.listdir(directory))
    with open(f"{filepath}.tex", mode='w') as file:
        file.write(tex)

    # Argument lists (not split strings) keep paths containing spaces intact.
    pr1 = subprocess.run(
        ["pdflatex", f"-output-directory={directory}", f"{filepath}.tex"],
        stderr=subprocess.PIPE,
    )
    files_after = set(os.listdir(directory))
    # Only files that are new AND belong to this job may be cleaned up: other
    # workers write into the same directory concurrently, and their files
    # must not be deleted from under them.
    created = {
        name for name in files_after - files_before
        if os.path.splitext(name)[0] == filename
    }

    if pr1.returncode != 0:
        _remove_files(directory, created)
        print(pr1.stderr.decode(), tex)
        return

    pr2 = subprocess.run(
        [
            "gs", "-sDEVICE=png16m", "-dTextAlphaBits=4", "-r200",
            "-dSAFER", "-dBATCH", "-dNOPAUSE",
            "-o", f"{filepath}.png", f"{filepath}.pdf",
        ],
        stderr=subprocess.PIPE,
    )
    # Keep the tex source and the rendered image, drop pdflatex by-products.
    _remove_files(directory, created - {filename + '.png', filename + '.tex'})
    if pr2.returncode != 0:
        # Explicit raise instead of `assert` so the check survives `python -O`.
        raise AssertionError(
            f"gs failed for {filepath}.pdf: {pr2.stderr.decode(errors='replace')}"
        )
def generate_dataset(
    filenames: Iterable[str],
    directory: str = "/external2/dkkoshman/repos/ML2TransformerApp/data/",
    latex_path: str = "/external2/dkkoshman/repos/ML2TransformerApp/resources/latex.json",
    overwrite: bool = False
) -> None:
    """
    Generates a latex dataset in given directory
    -------
    params:
    :filenames: - iterable of filenames to create, without extension
    :directory: - where to create
    :latex_path: - full path to latex json
    :overwrite: - whether to overwrite existing files
    -------
    Generation is repeated until a png exists for every requested filename,
    because a single `generate_image` call may fail without producing one.
    """
    def existing_images():
        # Strip only the final '.png' so filenames that themselves contain
        # dots are still recognised as done.
        return {
            os.path.splitext(file)[0]
            for file in os.listdir(directory)
            if file.endswith('.png')
        }

    filenames = set(filenames)
    if not overwrite:
        filenames -= existing_images()

    while filenames:
        with Pool() as pool:
            pool.starmap(generate_image, ((directory, latex_path, name) for name in filenames))
        filenames -= existing_images()