Spaces:
Runtime error
Runtime error
File size: 6,320 Bytes
c308f77 fb8db0f 6e82d4a e33424f fb8db0f e33424f fb8db0f e33424f 6e82d4a e33424f e949d7b e33424f 6e82d4a e33424f 6e82d4a e33424f 6e82d4a e33424f 6e82d4a fb8db0f 6e82d4a e33424f 6e82d4a e33424f 6e82d4a e949d7b fb8db0f 6e82d4a fb8db0f e33424f 6e82d4a e949d7b 6e82d4a e949d7b 6e82d4a e33424f 41a34cd 96feb73 1b4da0d 96feb73 1b4da0d 96feb73 e33424f 6e82d4a e33424f 6e82d4a ae308b4 41a34cd fb8db0f ae308b4 57273ba fb8db0f ae308b4 fb8db0f ae308b4 fb8db0f 41a34cd 2a394f6 fb8db0f c308f77 1b4da0d c308f77 41a34cd 29bcc5f 41a34cd 29bcc5f 41a34cd c308f77 2a394f6 fb8db0f c308f77 fb8db0f c308f77 fb8db0f c308f77 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 |
from constants import DATA_DIR, TOKENIZER_PATH, NUM_DATALOADER_WORKERS, PERSISTENT_WORKERS, PIN_MEMORY
import einops
import os
import pytorch_lightning as pl
import tokenizers
import torch
import torchvision
import torchvision.transforms as T
from torch.utils.data import Dataset, DataLoader
import tqdm
import re
class TexImageDataset(Dataset):
    """Dataset of paired ``<name>.png`` images and ``<name>.tex`` sources.

    Samples are discovered by listing ``root_dir`` for ``.png`` files; the
    ``.tex`` file with the same base name is read when a sample is fetched.
    """

    def __init__(self, root_dir, image_transform=None, tex_transform=None):
        """
        Args:
            root_dir (string): Directory with all the images and tex files.
            image_transform: optional callable applied to the tensor returned
                by ``torchvision.io.read_image``
            tex_transform: optional callable applied to the raw tex string
        """
        # Process-wide side effect: changes torch's tensor sharing strategy.
        # NOTE(review): presumably chosen to avoid exhausting file descriptors
        # with many DataLoader workers -- confirm.
        torch.multiprocessing.set_sharing_strategy('file_system')
        self.root_dir = root_dir
        # Base names (extension stripped) of every .png, de-duplicated and
        # sorted so that indices are stable between runs.
        self.filenames = sorted({
            os.path.splitext(filename)[0]
            for filename in os.listdir(root_dir)
            if filename.endswith('.png')
        })
        self.image_transform = image_transform
        self.tex_transform = tex_transform

    def __len__(self):
        """Return the number of ``.png`` files found in ``root_dir``."""
        return len(self.filenames)

    def __getitem__(self, idx):
        """Load the sample at position ``idx``.

        Returns:
            dict with keys ``"image"`` (the, optionally transformed, decoded
            image) and ``"tex"`` (the, optionally transformed, tex source).

        Raises:
            FileNotFoundError: if the ``.tex`` file for this sample is missing.
        """
        filename = self.filenames[idx]
        image_path = os.path.join(self.root_dir, filename + '.png')
        tex_path = os.path.join(self.root_dir, filename + '.tex')

        # Explicit encoding: without it the result depends on the locale of
        # the machine (and of each DataLoader worker) reading the file.
        with open(tex_path, encoding='utf-8') as file:
            tex = file.read()
        if self.tex_transform:
            tex = self.tex_transform(tex)

        image = torchvision.io.read_image(image_path)
        if self.image_transform:
            image = self.image_transform(image)

        return {"image": image, "tex": tex}
class BatchCollator:
    """Collate dataset samples into one batch of images and tokenized tex."""

    def __init__(self, tokenizer):
        # Only `encode_batch` is called on the tokenizer by this class.
        self.tokenizer = tokenizer

    def __call__(self, batch):
        """Merge a list of ``{'image', 'tex'}`` samples into batch tensors.

        Returns:
            dict with ``'images'``, ``'tex_ids'`` and ``'tex_attention_masks'``.
        """
        image_list = [sample['image'] for sample in batch]
        # Identity pattern: the call is used to turn the list of per-sample
        # images into a single 4-axis array (assumes equal shapes -- TODO confirm).
        stacked_images = einops.rearrange(image_list, 'b c h w -> b c h w')

        raw_texs = [sample['tex'] for sample in batch]
        encodings = self.tokenizer.encode_batch(raw_texs)

        # Building a tensor from nested lists requires every encoding to have
        # the same length (presumably the tokenizer pads -- verify).
        id_rows = [enc.ids for enc in encodings]
        mask_rows = [enc.attention_mask for enc in encodings]

        return {
            'images': stacked_images,
            'tex_ids': torch.Tensor(id_rows),
            'tex_attention_masks': torch.Tensor(mask_rows),
        }
class RandomizeImageTransform(object):
    """Standardize image and randomly augment.

    The pipeline resizes, converts to grayscale, inverts, center-crops to
    ``(height, width)`` and converts to ``float32``. When ``random_magnitude``
    is non-zero, colour jitter is applied first and ``RandAugment`` is applied
    just before the dtype conversion.
    """

    def __init__(self, width, height, random_magnitude):
        """
        Args:
            width: target image width; also the ``max_size`` cap for the resize
            height: target image height
            random_magnitude: augmentation strength; ``0`` disables both
                random augmentations
        """
        augment = random_magnitude != 0
        steps = []

        if augment:
            # The transform must be part of the pipeline itself. Returning it
            # from a lambda would hand a transform object (not an image) to
            # the next step.
            jitter = random_magnitude / 10
            steps.append(T.ColorJitter(brightness=jitter,
                                       contrast=jitter,
                                       saturation=jitter,
                                       hue=min(0.5, jitter)))

        steps.extend([
            T.Resize(height, max_size=width),
            T.Grayscale(),
            T.functional.invert,
            T.CenterCrop((height, width)),
            torch.Tensor.contiguous,
        ])

        if augment:
            steps.append(T.RandAugment(magnitude=random_magnitude))

        steps.append(T.ConvertImageDtype(torch.float32))
        self.transform = T.Compose(steps)

    def __call__(self, image):
        """Run ``image`` through the composed pipeline and return the result."""
        return self.transform(image)
class ExtractEquationFromTexTransform:
    r"""Pull the text between ``\[`` and ``\]`` out of a tex document."""

    def __init__(self):
        # Greedy and DOTALL: spans from the first `\[` to the last `\]`,
        # across newlines.
        self.equation_pattern = re.compile(r'\\\[(?P<equation>.*)\\\]', flags=re.DOTALL)
        # Runs of one or more plain space characters (not tabs/newlines).
        self.spaces = re.compile(r' +')

    def __call__(self, tex):
        r"""Return the equation body, trimmed, with space runs collapsed.

        Raises:
            AttributeError: if ``tex`` contains no ``\[ ... \]`` pair.
        """
        found = self.equation_pattern.search(tex)
        body = found.group('equation').strip()
        return self.spaces.sub(' ', body)
def generate_tex_tokenizer(dataloader):
    """Train and return a BPE tokenizer on the tex entries of ``dataloader``.

    Args:
        dataloader: sized iterable of batches, each indexable by ``'tex'``.

    Returns:
        A ``tokenizers.Tokenizer`` that wraps single sequences as
        ``[CLS] ... [SEP]`` and has ``[PAD]`` padding enabled.
    """
    # Materialize every batch's tex up front, with a progress bar.
    tex_batches = (batch['tex'] for batch in dataloader)
    progress = tqdm.tqdm(tex_batches, "Training tokenizer", total=len(dataloader))
    texs = list(progress)

    bpe_model = tokenizers.models.BPE(unk_token="[UNK]")
    tokenizer = tokenizers.Tokenizer(bpe_model)
    tokenizer.pre_tokenizer = tokenizers.pre_tokenizers.Whitespace()

    trainer = tokenizers.trainers.BpeTrainer(
        special_tokens=["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"]
    )
    tokenizer.train_from_iterator(texs, trainer=trainer)

    # Ids are only known once training has built the vocabulary.
    cls_id = tokenizer.token_to_id("[CLS]")
    sep_id = tokenizer.token_to_id("[SEP]")
    tokenizer.post_processor = tokenizers.processors.TemplateProcessing(
        single="[CLS] $A [SEP]",
        special_tokens=[("[CLS]", cls_id), ("[SEP]", sep_id)],
    )

    pad_id = tokenizer.token_to_id("[PAD]")
    tokenizer.enable_padding(pad_id=pad_id, pad_token="[PAD]")
    return tokenizer
class LatexImageDataModule(pl.LightningDataModule):
    """Lightning data module serving train/val/test splits of the tex-image dataset."""

    def __init__(self, image_width, image_height, batch_size, random_magnitude):
        """
        Args:
            image_width: width passed to the image transform
            image_height: height passed to the image transform
            batch_size: batch size used by the train/val/test dataloaders
            random_magnitude: augmentation strength passed to the image transform
        """
        super().__init__()
        dataset = TexImageDataset(root_dir=DATA_DIR,
                                  image_transform=RandomizeImageTransform(image_width, image_height,
                                                                          random_magnitude),
                                  tex_transform=ExtractEquationFromTexTransform())

        # 5% validation, 5% test, everything else for training. The train size
        # is the remainder so the three lengths always sum to len(dataset);
        # `random_split` rejects integer lengths that do not. For sizes that
        # are a multiple of 20 this equals the former 18/20 share.
        total = len(dataset)
        val_size = total // 20
        test_size = total // 20
        train_size = total - val_size - test_size
        self.train_dataset, self.val_dataset, self.test_dataset = torch.utils.data.random_split(
            dataset, [train_size, val_size, test_size])

        self.batch_size = batch_size
        self.save_hyperparameters()

    def train_tokenizer(self):
        """Train a tokenizer on the training split, save it to TOKENIZER_PATH and return it."""
        tokenizer = generate_tex_tokenizer(DataLoader(self.train_dataset, batch_size=32, num_workers=16))
        torch.save(tokenizer, TOKENIZER_PATH)
        return tokenizer

    def _shared_dataloader(self, dataset, **kwargs):
        """Build a DataLoader over ``dataset`` using the tokenizer saved at TOKENIZER_PATH.

        Extra keyword arguments are forwarded to ``DataLoader``.
        """
        # NOTE(review): torch.save/torch.load pickle the whole tokenizer
        # object, so only load a TOKENIZER_PATH file this project wrote.
        # Newer torch releases default torch.load to weights_only=True, which
        # may refuse this object -- confirm against the pinned torch version.
        tex_tokenizer = torch.load(TOKENIZER_PATH)
        collate_fn = BatchCollator(tex_tokenizer)
        return DataLoader(dataset, batch_size=self.batch_size, collate_fn=collate_fn, pin_memory=PIN_MEMORY,
                          num_workers=NUM_DATALOADER_WORKERS, persistent_workers=PERSISTENT_WORKERS, **kwargs)

    def train_dataloader(self):
        """Return the shuffled dataloader over the training split."""
        return self._shared_dataloader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        """Return the dataloader over the validation split."""
        return self._shared_dataloader(self.val_dataset)

    def test_dataloader(self):
        """Return the dataloader over the test split."""
        return self._shared_dataloader(self.test_dataset)
|