ML2TransformerApp / data_preprocessing.py
dkoshman
two line change
41c9661
raw
history blame
No virus
3.52 kB
import os
import re
import tokenizers
import torch
import torchvision
import torchvision.transforms as T
import tqdm
import PIL
from torch.utils.data import Dataset, DataLoader
# Absolute, machine-specific location of the dataset folder (note the trailing
# slash). Nothing in the visible part of this module reads it, so it is
# presumably consumed by importing code -- verify before changing or removing.
directory = "/external2/dkkoshman/repos/ML2TransformerApp/data/"
class TexImageDataset(Dataset):
    """Image to tex dataset, loaded eagerly into memory.

    Every ``<stem>.png`` found directly inside ``root_dir`` is paired with the
    ``<stem>.tex`` file of the same stem. Both files are read (and optionally
    preprocessed) once, in ``__init__``; indexing afterwards only looks up the
    cached pair.
    """

    def __init__(self, root_dir, image_preprocessing=None, tex_preprocessing=None):
        """
        Args:
            root_dir (string): Directory with all the images and tex files.
                A trailing path separator is optional.
            image_preprocessing: optional callable applied to each image
                tensor right after it is read from disk.
            tex_preprocessing: optional callable applied to each tex string
                right after it is read from disk.

        Raises:
            OSError: if ``root_dir`` cannot be listed or a ``.png`` file has
                no readable ``.tex`` counterpart.
        """
        # Process-wide setting. Presumably chosen to avoid exhausting file
        # descriptors when DataLoader workers share many tensors -- confirm.
        torch.multiprocessing.set_sharing_strategy('file_system')

        self.root_dir = root_dir
        self.data = []

        for stem in tqdm.tqdm(self._sample_stems(root_dir)):
            # os.path.join instead of string concatenation, so that root_dir
            # works with or without a trailing separator.
            tex_path = os.path.join(root_dir, stem + '.tex')
            image_path = os.path.join(root_dir, stem + '.png')

            # Explicit encoding: do not depend on the platform's locale.
            with open(tex_path, encoding='utf-8') as file:
                tex = file.read()
            if tex_preprocessing:
                tex = tex_preprocessing(tex)

            image = torchvision.io.read_image(image_path)
            if image_preprocessing:
                image = image_preprocessing(image)

            self.data.append((image, tex))

    @staticmethod
    def _sample_stems(root_dir):
        """Return the sorted extension-less names of all ``*.png`` files in root_dir."""
        # Match on '.png' (with the dot) so that names merely ending in the
        # letters "png" are not mistaken for images.
        return sorted(
            os.path.splitext(name)[0]
            for name in os.listdir(root_dir)
            if name.endswith('.png')
        )

    def __len__(self):
        """Return the number of (image, tex) pairs held in memory."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the sample at ``idx`` as ``{"image": ..., "tex": ...}``."""
        image, tex = self.data[idx]
        return {"image": image, "tex": tex}
class StandardizeImage:
    """Callable that brings an image to a fixed ``height`` x ``width`` canvas.

    The steps run in this order: resize with ``T.Resize(height)`` (per
    torchvision, an int size matches the shorter edge), convert to grayscale,
    invert the pixel values, then center-crop to ``(height, width)``
    (torchvision pads the image when it is smaller than the crop).
    """

    def __init__(self, width=1024, height=128):
        """Build the fixed transform pipeline for the requested output size."""
        output_size = (height, width)
        steps = (
            T.Resize(height),
            T.Grayscale(),
            T.functional.invert,
            T.CenterCrop(output_size),
        )
        self.transform = T.Compose(steps)

    def __call__(self, image):
        """Apply the pipeline to ``image`` and return the result."""
        return self.transform(image)
class RandomTransformImage:
    """Brightness-jitter, standardize and then randomly augment an image."""

    def __init__(self, standardize, random_magnitude=5):
        """
        Args:
            standardize: callable applied to the image after the brightness
                jitter and before RandAugment.
            random_magnitude: RandAugment magnitude; its reciprocal also sets
                the brightness jitter range (must be non-zero).
        """
        # Brightness factor range is [1/m, 1 + 1/m].
        reciprocal = 1 / random_magnitude
        brightness_range = (reciprocal, 1 + reciprocal)

        self.brighten = T.ColorJitter(brightness=brightness_range)
        self.standardize = standardize
        self.rand_aug = T.RandAugment(magnitude=random_magnitude)

    def __call__(self, image):
        """Return a jittered, standardized and augmented version of ``image``."""
        standardized = self.standardize(self.brighten(image))
        # The image is made contiguous in memory before it is handed to RandAugment.
        return self.rand_aug(standardized.contiguous())
def generate_tex_tokenizer(texs, vocab_size=300):
    """Return a BPE tokenizer trained on tex strings from the dataset.

    Args:
        texs: iterable of tex strings used as the training corpus.
        vocab_size: target vocabulary size passed to the BPE trainer
            (defaults to 300, the value that used to be hard-coded).

    Returns:
        A ``tokenizers.Tokenizer`` that splits on whitespace before applying
        BPE, appends ``[SEP]`` to every encoded sequence, and has padding
        with ``[PAD]`` enabled.
    """
    unk_token, sep_token, pad_token = "[UNK]", "[SEP]", "[PAD]"

    tokenizer = tokenizers.Tokenizer(tokenizers.models.BPE(unk_token=unk_token))
    tokenizer.pre_tokenizer = tokenizers.pre_tokenizers.Whitespace()

    tokenizer_trainer = tokenizers.trainers.BpeTrainer(
        vocab_size=vocab_size,
        special_tokens=[unk_token, "[CLS]", sep_token, pad_token, "[MASK]"],
    )
    tokenizer.train_from_iterator(texs, trainer=tokenizer_trainer)

    # The post-processor and padding are configured only after training,
    # because the special-token ids are looked up from the trained vocabulary.
    tokenizer.post_processor = tokenizers.processors.TemplateProcessing(
        single="$A [SEP]",
        special_tokens=[(sep_token, tokenizer.token_to_id(sep_token))],
    )
    tokenizer.enable_padding(
        pad_id=tokenizer.token_to_id(pad_token),
        pad_token=pad_token,
    )
    return tokenizer