Spaces:
Runtime error
Runtime error
import einops | |
import os | |
import tokenizers | |
import torch | |
import torchvision | |
import torchvision.transforms as T | |
from torch.utils.data import Dataset | |
import tqdm | |
import re | |
class TexImageDataset(Dataset):
    """Image and tex dataset.

    Pairs every ``<name>.png`` found directly in ``root_dir`` with the
    ``<name>.tex`` file next to it. Tex sources are read (and passed through
    ``tex_transform``) eagerly in the constructor; images are read from disk
    on every ``__getitem__`` call.
    """

    def __init__(self, root_dir, image_transform=None, tex_transform=None):
        """
        Args:
            root_dir (string): Directory with all the images and tex files.
            image_transform: callable image preprocessing
            tex_transform: callable tex preprocessing

        Raises:
            OSError: if ``root_dir`` cannot be listed or a ``.png`` file has no
                readable ``.tex`` file with the same stem.
        """
        # Module-level torch setting; it is changed as a side effect of
        # constructing a dataset.
        torch.multiprocessing.set_sharing_strategy('file_system')
        self.root_dir = root_dir
        self.filenames = self._list_image_stems(root_dir)
        self.image_transform = image_transform
        self.tex_transform = tex_transform
        # Both names refer to the tokenizer trained by
        # subjoin_tex_tokenize_transform; they are kept in sync so that code
        # reading either attribute sees the same object.
        self.tokenizer = None
        self.tex_tokenizer = None
        self.texs = [
            self._load_tex(filename)
            for filename in tqdm.tqdm(self.filenames, "Preloading tex files")
        ]

    @staticmethod
    def _list_image_stems(root_dir):
        """Return the sorted, de-duplicated stems of the ``.png`` files in ``root_dir``."""
        return sorted({
            os.path.splitext(filename)[0]
            for filename in os.listdir(root_dir)
            if filename.endswith('.png')
        })

    def _load_tex(self, filename):
        """Read ``<filename>.tex`` from ``root_dir`` and apply ``tex_transform`` if set."""
        tex_path = os.path.join(self.root_dir, filename + '.tex')
        # Explicit encoding so the result does not depend on the platform default.
        with open(tex_path, encoding='utf-8') as file:
            tex = file.read()
        if self.tex_transform:
            tex = self.tex_transform(tex)
        return tex

    def __len__(self):
        """Return the number of image/tex pairs."""
        return len(self.filenames)

    def __getitem__(self, idx):
        """Return ``{"image": ..., "tex": ...}`` for the pair at position ``idx``.

        The image is read with ``torchvision.io.read_image`` and passed through
        ``image_transform`` when one is set; the tex is the value preloaded in
        the constructor.
        """
        filename = self.filenames[idx]
        image_path = os.path.join(self.root_dir, filename + '.png')
        image = torchvision.io.read_image(image_path)
        if self.image_transform:
            image = self.image_transform(image)
        return {"image": image, "tex": self.texs[idx]}

    def subjoin_image_normalize_transform(self):
        """Appends a normalize layer with mean and std computed after iterating over dataset

        The statistics are the average of the per-image means and the average
        of the per-image standard deviations, computed on images as returned
        by ``__getitem__`` (i.e. after the current ``image_transform``).

        Raises:
            ZeroDivisionError: if the dataset is empty.
        """
        # NOTE(review): assumes the current image_transform yields
        # floating-point images, since mean()/std() are taken directly -
        # confirm against callers.
        mean = 0
        std = 0
        for item in tqdm.tqdm(self):
            image = item['image']
            mean += image.mean()
            std += image.std()
        mean /= len(self)
        std /= len(self)
        normalize = T.Normalize(mean, std)
        if self.image_transform:
            self.image_transform = T.Compose((self.image_transform, normalize))
        else:
            self.image_transform = normalize

    def subjoin_tex_tokenize_transform(self, texs, vocab_size=300):
        """Returns a tokenizer trained on given tex strings

        The tokenizer is a whitespace-pre-tokenized BPE model that appends
        ``[SEP]`` to every encoded sequence and pads with ``[PAD]``. It is also
        stored on the dataset as ``self.tokenizer`` and ``self.tex_tokenizer``.

        Args:
            texs: iterable of tex strings to train on.
            vocab_size: vocabulary size passed to the BPE trainer.
        """
        # NOTE(review): the original carried this line disabled; kept for reference.
        # os.environ['TOKENIZERS_PARALLELISM'] = 'false'
        tokenizer = tokenizers.Tokenizer(tokenizers.models.BPE(unk_token="[UNK]"))
        tokenizer_trainer = tokenizers.trainers.BpeTrainer(
            vocab_size=vocab_size,
            special_tokens=["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"]
        )
        tokenizer.pre_tokenizer = tokenizers.pre_tokenizers.Whitespace()
        tokenizer.train_from_iterator(texs, trainer=tokenizer_trainer)
        # The special-token ids are only known after training, so the
        # post-processor and padding are configured afterwards.
        tokenizer.post_processor = tokenizers.processors.TemplateProcessing(
            single="$A [SEP]",
            special_tokens=[("[SEP]", tokenizer.token_to_id("[SEP]"))]
        )
        tokenizer.enable_padding(pad_id=tokenizer.token_to_id("[PAD]"), pad_token="[PAD]")
        self.tokenizer = tokenizer
        self.tex_tokenizer = tokenizer
        return tokenizer
class BatchCollator(object):
    """Collates ``{"image", "tex"}`` items into one batch of images and encoded tex."""

    def __init__(self, tokenizer):
        # Must provide ``encode_batch`` returning encodings with ``ids`` and
        # ``attention_mask`` attributes (see ``__call__``).
        self.tokenizer = tokenizer

    def __call__(self, batch):
        """Build the batch dictionary.

        Args:
            batch: sequence of items, each with an ``'image'`` and a ``'tex'`` entry.

        Returns:
            dict with ``'images'``, ``'tex_ids'`` and ``'tex_attention_masks'``.
            The last two are built with ``torch.Tensor(...)`` and therefore have
            torch's default (floating-point) dtype.
        """
        # The list of images is handed to einops as-is; the identity pattern
        # leaves the axes unchanged and names them b, c, h, w.
        stacked_images = einops.rearrange(
            [sample['image'] for sample in batch],
            'b c h w -> b c h w',
        )
        encodings = self.tokenizer.encode_batch([sample['tex'] for sample in batch])
        # NOTE(review): presumably the tokenizer has padding enabled so that
        # all encodings share one length - confirm against the caller.
        ids = torch.Tensor([enc.ids for enc in encodings])
        masks = torch.Tensor([enc.attention_mask for enc in encodings])
        return {'images': stacked_images, 'tex_ids': ids, 'tex_attention_masks': masks}
class StandardizeImageTransform(object):
    """Pad and crop image to a given size, grayscale and invert"""

    def __init__(self, width=1024, height=128):
        """
        Args:
            width: width passed to the center crop.
            height: size passed to the resize and height passed to the center crop.
        """
        crop_size = (height, width)
        steps = (
            T.Resize(height),
            T.Grayscale(),
            T.functional.invert,
            T.CenterCrop(crop_size),
            T.ConvertImageDtype(torch.float32),
        )
        self.standardize = T.Compose(steps)

    def __call__(self, image):
        """Run ``image`` through the composed pipeline and return the result."""
        return self.standardize(image)
class RandomizeImageTransform(object):
    """Standardize image and randomly augment"""

    def __init__(self, width=1024, height=128, random_magnitude=5):
        """
        Args:
            width: width passed to the center crop.
            height: size passed to the resize and height passed to the center crop.
            random_magnitude: magnitude passed to ``RandAugment``; one tenth of
                it is used as the ``ColorJitter`` brightness.
        """
        jitter_brightness = random_magnitude / 10
        crop_size = (height, width)
        steps = (
            T.ColorJitter(brightness=jitter_brightness),
            T.Resize(height),
            T.Grayscale(),
            T.functional.invert,
            T.CenterCrop(crop_size),
            # NOTE(review): presumably RandAugment needs a contiguous tensor
            # after the crop - confirm.
            torch.Tensor.contiguous,
            T.RandAugment(magnitude=random_magnitude),
            T.ConvertImageDtype(torch.float32),
        )
        self.transform = T.Compose(steps)

    def __call__(self, image):
        """Run ``image`` through the composed pipeline and return the result."""
        return self.transform(image)
class ExtractEquationFromTexTransform(object):
    r"""Extracts ...\[ equation \]... from tex file

    The match is greedy and spans newlines: it runs from the first ``\[`` to
    the last ``\]`` in the input. The extracted text is stripped and runs of
    spaces are collapsed to a single space (newlines and tabs are kept).
    """

    def __init__(self):
        self.equation_pattern = re.compile(r'\\\[(?P<equation>.*)\\\]', flags=re.DOTALL)
        self.spaces = re.compile(r' +')

    def __call__(self, tex):
        r"""Return the normalized equation found in ``tex``.

        Args:
            tex: full tex source as a string.

        Returns:
            The text between ``\[`` and ``\]``, stripped, with repeated spaces
            collapsed.

        Raises:
            ValueError: if ``tex`` contains no ``\[ ... \]`` block.
        """
        match = self.equation_pattern.search(tex)
        if match is None:
            # Fail with a message that names the problem instead of an
            # AttributeError on None.
            raise ValueError(r"no \[ ... \] equation found in tex source")
        equation = match.group('equation').strip()
        return self.spaces.sub(' ', equation)