"""Tex/image dataset, preprocessing transforms, tokenizer training and the
Lightning data module that ties them together."""
import os
import re

import einops
import pytorch_lightning as pl
import tokenizers
import torch
import torchvision
import torchvision.transforms as T
import tqdm
from torch.utils.data import DataLoader, Dataset

from constants import (
    DATA_DIR,
    NUM_DATALOADER_WORKERS,
    PERSISTENT_WORKERS,
    PIN_MEMORY,
    TOKENIZER_PATH,
)


class TexImageDataset(Dataset):
    """Image and tex dataset.

    Every ``<name>.png`` in ``root_dir`` is paired with ``<name>.tex`` from
    the same directory; samples are ``{"image": ..., "tex": ...}`` dicts.
    """

    def __init__(self, root_dir, image_transform=None, tex_transform=None):
        """
        Args:
            root_dir (string): Directory with all the images and tex files.
            image_transform: callable image preprocessing
            tex_transform: callable tex preprocessing
        """
        # Process-wide setting: share tensors between dataloader workers via
        # the file system rather than file descriptors (presumably to avoid
        # running out of open file descriptors with many workers).
        torch.multiprocessing.set_sharing_strategy('file_system')
        self.root_dir = root_dir
        # Sorted so the index -> sample mapping is deterministic across runs.
        self.filenames = sorted({
            os.path.splitext(filename)[0]
            for filename in os.listdir(root_dir)
            if filename.endswith('.png')
        })
        self.image_transform = image_transform
        self.tex_transform = tex_transform

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, idx):
        """Load, transform and return the ``idx``-th image/tex pair."""
        filename = self.filenames[idx]
        image_path = os.path.join(self.root_dir, filename + '.png')
        tex_path = os.path.join(self.root_dir, filename + '.tex')

        # Explicit encoding: the default text encoding is platform dependent.
        with open(tex_path, encoding='utf-8') as file:
            tex = file.read()
        if self.tex_transform:
            tex = self.tex_transform(tex)

        image = torchvision.io.read_image(image_path)
        if self.image_transform:
            image = self.image_transform(image)

        return {"image": image, "tex": tex}


class BatchCollator(object):
    """Image, tex batch collator.

    Stacks the per-sample images into one batch tensor and tokenizes the tex
    strings with ``tokenizer``.
    """

    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    def __call__(self, batch):
        # The identity pattern stacks a list of (c, h, w) images into a
        # (b, c, h, w) tensor; all images must therefore share one shape.
        images = einops.rearrange([item['image'] for item in batch], 'b c h w -> b c h w')
        # Assumes the tokenizer pads (see generate_tex_tokenizer); otherwise
        # the ragged id lists cannot be turned into a tensor.
        encodings = self.tokenizer.encode_batch([item['tex'] for item in batch])
        # Token ids are integer indices, so keep them as int64 (float32
        # indices cannot be used for embedding lookups or as class targets).
        tex_ids = torch.tensor([encoding.ids for encoding in encodings], dtype=torch.long)
        attention_masks = torch.tensor([encoding.attention_mask for encoding in encodings],
                                       dtype=torch.float32)
        return {'images': images, 'tex_ids': tex_ids, 'tex_attention_masks': attention_masks}


class RandomizeImageTransform(object):
    """Standardize image and randomly augment.

    Pipeline: [colour jitter] -> resize -> grayscale -> invert -> centre crop
    to (height, width) -> [RandAugment] -> float32. The bracketed random
    augmentations are only included when ``random_magnitude != 0``.
    """

    def __init__(self, width, height, random_magnitude):
        augment = random_magnitude != 0
        steps = []
        if augment:
            strength = random_magnitude / 10
            # hue is capped at 0.5, the largest value ColorJitter accepts.
            steps.append(T.ColorJitter(brightness=strength, contrast=strength,
                                       saturation=strength, hue=min(0.5, strength)))
        steps += [
            # Smaller edge -> height, with the longer edge capped at width.
            T.Resize(height, max_size=width),
            T.Grayscale(),
            T.functional.invert,
            # Crops (or zero-pads) to exactly (height, width).
            T.CenterCrop((height, width)),
            torch.Tensor.contiguous,
        ]
        if augment:
            # Applied while the image is still uint8, before dtype conversion.
            steps.append(T.RandAugment(magnitude=random_magnitude))
        steps.append(T.ConvertImageDtype(torch.float32))
        self.transform = T.Compose(steps)

    def __call__(self, image):
        return self.transform(image)


class ExtractEquationFromTexTransform(object):
    r"""Extracts ...\[ equation \]... from tex file"""

    def __init__(self):
        # Greedy with DOTALL: captures everything between the first "\[" and
        # the last "\]", including newlines.
        self.equation_pattern = re.compile(r'\\\[(?P<equation>.*)\\\]', flags=re.DOTALL)
        self.spaces = re.compile(r' +')

    def __call__(self, tex):
        """Return the stripped display equation with runs of spaces collapsed.

        Raises:
            ValueError: if ``tex`` contains no ``\\[ ... \\]`` block.
        """
        match = self.equation_pattern.search(tex)
        if match is None:
            raise ValueError(r'No \[ ... \] equation found in tex source')
        equation = match.group('equation').strip()
        return self.spaces.sub(' ', equation)


def generate_tex_tokenizer(dataloader):
    """Returns a BPE tokenizer trained on texs from given dataloader.

    ``dataloader`` must yield dicts with a ``'tex'`` entry (presumably the
    default-collated list of strings per batch). The returned tokenizer wraps
    each sequence as ``[CLS] ... [SEP]`` and pads batches with ``[PAD]``.
    """
    # Each element is one batch of tex strings; train_from_iterator accepts
    # an iterator of lists of strings.
    tex_batches = list(tqdm.tqdm((batch['tex'] for batch in dataloader),
                                 "Training tokenizer", total=len(dataloader)))

    tokenizer = tokenizers.Tokenizer(tokenizers.models.BPE(unk_token="[UNK]"))
    tokenizer_trainer = tokenizers.trainers.BpeTrainer(
        special_tokens=["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"]
    )
    tokenizer.pre_tokenizer = tokenizers.pre_tokenizers.Whitespace()
    tokenizer.train_from_iterator(tex_batches, trainer=tokenizer_trainer)

    tokenizer.post_processor = tokenizers.processors.TemplateProcessing(
        single="[CLS] $A [SEP]",
        special_tokens=[
            ("[CLS]", tokenizer.token_to_id("[CLS]")),
            ("[SEP]", tokenizer.token_to_id("[SEP]")),
        ]
    )
    tokenizer.enable_padding(pad_id=tokenizer.token_to_id("[PAD]"), pad_token="[PAD]")
    return tokenizer


class LatexImageDataModule(pl.LightningDataModule):
    """Lightning data module over the tex/image pairs in ``DATA_DIR``.

    The data is randomly split roughly 90% / 5% / 5% into train / val / test.
    """

    def __init__(self, image_width, image_height, batch_size, random_magnitude):
        super().__init__()
        # NOTE(review): the same (possibly randomizing) image transform is
        # shared by all three splits, so val/test images are augmented too.
        dataset = TexImageDataset(
            root_dir=DATA_DIR,
            image_transform=RandomizeImageTransform(image_width, image_height, random_magnitude),
            tex_transform=ExtractEquationFromTexTransform(),
        )
        # random_split needs integer lengths that sum exactly to len(dataset);
        # the rounding remainder goes to the training split.
        val_size = test_size = len(dataset) // 20
        train_size = len(dataset) - val_size - test_size
        # NOTE(review): the split is not seeded, so separate runs (e.g. the
        # tokenizer-training run and a training run) get different splits.
        self.train_dataset, self.val_dataset, self.test_dataset = torch.utils.data.random_split(
            dataset, [train_size, val_size, test_size])
        self.batch_size = batch_size
        self.save_hyperparameters()

    def train_tokenizer(self):
        """Train a tokenizer on the training split and pickle it to TOKENIZER_PATH.

        Note that iterating the dataset also loads and transforms every image.
        """
        tokenizer = generate_tex_tokenizer(DataLoader(self.train_dataset, batch_size=32, num_workers=16))
        torch.save(tokenizer, TOKENIZER_PATH)
        return tokenizer

    @staticmethod
    def _load_tokenizer():
        """Load the tokenizer pickled by ``train_tokenizer``."""
        try:
            # torch >= 2.6 defaults to weights_only=True, which refuses to
            # unpickle arbitrary objects such as a tokenizers.Tokenizer. The
            # file is produced by train_tokenizer; never point TOKENIZER_PATH
            # at untrusted files, since this performs full unpickling.
            return torch.load(TOKENIZER_PATH, weights_only=False)
        except TypeError:
            # torch versions without the ``weights_only`` argument.
            return torch.load(TOKENIZER_PATH)

    def _shared_dataloader(self, dataset, **kwargs):
        """Build a DataLoader for ``dataset`` that collates with the saved tokenizer."""
        collate_fn = BatchCollator(self._load_tokenizer())
        return DataLoader(dataset, batch_size=self.batch_size, collate_fn=collate_fn,
                          pin_memory=PIN_MEMORY, num_workers=NUM_DATALOADER_WORKERS,
                          persistent_workers=PERSISTENT_WORKERS, **kwargs)

    def train_dataloader(self):
        return self._shared_dataloader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        return self._shared_dataloader(self.val_dataset)

    def test_dataloader(self):
        return self._shared_dataloader(self.test_dataset)