import os

import tokenizers
import torch
import torchvision
import torchvision.transforms as T
import tqdm
from torch.utils.data import Dataset, DataLoader

directory = "/external2/dkkoshman/repos/ML2TransformerApp/data/"


class TexImageDataset(Dataset):
    """Image to tex dataset."""

    def __init__(self, root_dir, image_preprocessing=None, tex_preprocessing=None):
        """
        Args:
            root_dir (string): Directory with all the image and tex files.
            image_preprocessing (callable, optional): Preprocessing applied to each image.
            tex_preprocessing (callable, optional): Preprocessing applied to each tex string.
        """
        # The 'file_system' sharing strategy avoids "too many open files"
        # errors when this dataset is used with multi-worker DataLoaders.
        torch.multiprocessing.set_sharing_strategy('file_system')
        self.root_dir = root_dir
        filenames = sorted(
            set(
                os.path.splitext(filename)[0]
                for filename in os.listdir(root_dir)
                if filename.endswith('.png')
            )
        )
        # Eagerly load and preprocess every (image, tex) pair into memory,
        # so preprocessing runs once at construction time.
        self.data = []
        for filename in tqdm.tqdm(filenames):
            tex_path = os.path.join(self.root_dir, filename + '.tex')
            image_path = os.path.join(self.root_dir, filename + '.png')
            with open(tex_path) as file:
                tex = file.read()
            if tex_preprocessing:
                tex = tex_preprocessing(tex)
            image = torchvision.io.read_image(image_path)
            if image_preprocessing:
                image = image_preprocessing(image)
            self.data.append((image, tex))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        image, tex = self.data[idx]
        return {"image": image, "tex": tex}


class StandardizeImage(object):
    """Resize, grayscale and invert an image, then center-crop (padding if
    needed) to a fixed width and height."""

    def __init__(self, width=1024, height=128):
        self.transform = T.Compose((
            T.Resize(height),
            T.Grayscale(),
            # Invert so the background is zero-valued; CenterCrop then pads
            # smaller images with matching black borders.
            T.functional.invert,
            T.CenterCrop((height, width))
        ))

    def __call__(self, image):
        return self.transform(image)


class RandomTransformImage(object):
    """Standardize an image and randomly augment it."""

    def __init__(self, standardize, random_magnitude=5):
        self.brighten = T.ColorJitter(
            brightness=(1 / random_magnitude, 1 + 1 / random_magnitude)
        )
        self.standardize = standardize
        self.rand_aug = T.RandAugment(magnitude=random_magnitude)

    def __call__(self, image):
        image = self.brighten(image)
        image = self.standardize(image)
        image = image.contiguous()  # RandAugment expects a contiguous tensor
        image = self.rand_aug(image)
        return image


def generate_tex_tokenizer(dataset):
    """Returns a BPE tokenizer trained on the tex strings from the dataset."""
    tokenizer = tokenizers.Tokenizer(tokenizers.models.BPE(unk_token="[UNK]"))
    tokenizer_trainer = tokenizers.trainers.BpeTrainer(
        vocab_size=300,
        special_tokens=["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"]
    )
    tokenizer.pre_tokenizer = tokenizers.pre_tokenizers.Whitespace()
    tokenizer.train_from_iterator((item['tex'] for item in dataset), trainer=tokenizer_trainer)
    # Append a [SEP] token to every encoded sequence.
    tokenizer.post_processor = tokenizers.processors.TemplateProcessing(
        single="$A [SEP]",
        special_tokens=[("[SEP]", tokenizer.token_to_id("[SEP]"))]
    )
    tokenizer.enable_padding(pad_id=tokenizer.token_to_id("[PAD]"), pad_token="[PAD]")
    return tokenizer
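

# --- Usage sketch ---
# The helper and __main__ block below are illustrative additions, not part of
# the original module: they show one plausible way to batch this dataset for
# training. `make_collate_fn` is a hypothetical helper name; it relies on the
# padding enabled in generate_tex_tokenizer so every sequence in a batch
# shares the same length and can be stacked into one tensor.

def make_collate_fn(tokenizer):
    """Stack images and tokenize tex strings into a padded id tensor."""
    def collate_batch(samples):
        images = torch.stack([sample["image"] for sample in samples])
        encodings = tokenizer.encode_batch([sample["tex"] for sample in samples])
        token_ids = torch.tensor([encoding.ids for encoding in encodings])
        return {"images": images, "tex_ids": token_ids}
    return collate_batch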
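

# Assumes `directory` contains matched <name>.png / <name>.tex pairs. Note
# that because TexImageDataset preprocesses eagerly, RandomTransformImage is
# applied once at construction rather than re-sampled every epoch.
if __name__ == "__main__":
    standardize = StandardizeImage(width=1024, height=128)
    augment = RandomTransformImage(standardize, random_magnitude=5)
    dataset = TexImageDataset(directory, image_preprocessing=augment)
    tokenizer = generate_tex_tokenizer(dataset)
    loader = DataLoader(
        dataset,
        batch_size=16,
        shuffle=True,
        collate_fn=make_collate_fn(tokenizer),
    )
    batch = next(iter(loader))
    print(batch["images"].shape, batch["tex_ids"].shape)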