Spaces:
Runtime error
Runtime error
import os | |
import re | |
import tokenizers | |
import torch | |
import torchvision | |
import torchvision.transforms as T | |
import tqdm | |
import PIL | |
from torch.utils.data import Dataset, DataLoader | |
# Machine-specific absolute path, not referenced elsewhere in this part of the file.
# NOTE(review): presumably meant to be passed as `root_dir` to TexImageDataset — confirm against callers.
directory = "/external2/dkkoshman/repos/ML2TransformerApp/data/"
class TexImageDataset(Dataset):
    """Image to tex dataset.

    Every ``<name>.png`` found in ``root_dir`` is paired with ``<name>.tex``
    from the same directory. All pairs are read and preprocessed eagerly in
    the constructor and kept in memory in ``self.data``.
    """

    def __init__(self, root_dir, image_preprocessing=None, tex_preprocessing=None):
        """
        Args:
            root_dir (string): Directory with all the images and tex files.
                A trailing path separator is optional.
            image_preprocessing: optional callable applied to each image
                returned by ``torchvision.io.read_image``
            tex_preprocessing: optional callable applied to the text of each
                tex file

        Raises:
            OSError: if ``root_dir`` cannot be listed or a ``.tex`` file that
                should accompany a ``.png`` cannot be opened.
        """
        # Process-wide side effect: changes torch's tensor sharing strategy.
        # NOTE(review): presumably needed for DataLoader workers — confirm.
        torch.multiprocessing.set_sharing_strategy('file_system')
        self.root_dir = root_dir

        # Match on the real extension ('.png'), not just a 'png' suffix, and
        # sort the stems so sample order is deterministic.
        filenames = sorted(
            os.path.splitext(filename)[0]
            for filename in os.listdir(root_dir)
            if filename.endswith('.png')
        )

        self.data = []
        for filename in tqdm.tqdm(filenames):
            # os.path.join works whether or not root_dir ends with a separator.
            tex_path = os.path.join(root_dir, filename + '.tex')
            image_path = os.path.join(root_dir, filename + '.png')

            with open(tex_path) as file:
                tex = file.read()
            if tex_preprocessing is not None:
                tex = tex_preprocessing(tex)

            image = torchvision.io.read_image(image_path)
            if image_preprocessing is not None:
                image = image_preprocessing(image)

            self.data.append((image, tex))

    def __len__(self):
        """Return the number of (image, tex) pairs loaded."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the sample at ``idx`` as ``{"image": ..., "tex": ...}``."""
        image, tex = self.data[idx]
        return {"image": image, "tex": tex}
class StandardizeImage(object):
    """Bring an image to a fixed ``height`` x ``width`` canvas.

    The steps, in order: ``T.Resize(height)``, conversion to grayscale,
    inversion, and a center crop to ``(height, width)``.
    """

    def __init__(self, width=1024, height=128):
        """
        Args:
            width: width passed to the final center crop
            height: size passed to the resize step and height of the final
                center crop
        """
        target_size = (height, width)
        steps = (
            T.Resize(height),
            T.Grayscale(),
            T.functional.invert,
            T.CenterCrop(target_size),
        )
        self.transform = T.Compose(steps)

    def __call__(self, image):
        """Run the composed transform on ``image`` and return the result."""
        return self.transform(image)
class RandomTransformImage(object):
    """Randomly augment an image around a standardization step.

    Brightness jitter is applied first, then the supplied ``standardize``
    callable, then ``T.RandAugment``.
    """

    def __init__(self, standardize, random_magnitude=5):
        """
        Args:
            standardize: callable applied between the brightness jitter and
                RandAugment
            random_magnitude: magnitude passed to ``T.RandAugment``; its
                reciprocal also sets the brightness jitter range
        """
        jitter_step = 1 / random_magnitude
        self.brighten = T.ColorJitter(brightness=(jitter_step, 1 + jitter_step))
        self.standardize = standardize
        self.rand_aug = T.RandAugment(magnitude=random_magnitude)

    def __call__(self, image):
        """Return ``image`` after jitter, standardization and RandAugment."""
        standardized = self.standardize(self.brighten(image))
        # .contiguous() is called on the standardized result before RandAugment.
        return self.rand_aug(standardized.contiguous())
def generate_tex_tokenizer(texs):
    """Train a BPE tokenizer on the given tex strings and return it.

    The returned tokenizer splits on whitespace before BPE, appends a
    ``[SEP]`` token to every single-sequence encoding, and has padding with
    ``[PAD]`` enabled.

    Args:
        texs: iterable of tex strings used as the training corpus

    Returns:
        The trained ``tokenizers.Tokenizer``.
    """
    special_tokens = ["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"]

    bpe_model = tokenizers.models.BPE(unk_token="[UNK]")
    tokenizer = tokenizers.Tokenizer(bpe_model)
    trainer = tokenizers.trainers.BpeTrainer(
        vocab_size=300,
        special_tokens=special_tokens,
    )

    # The pre-tokenizer must be in place before training.
    tokenizer.pre_tokenizer = tokenizers.pre_tokenizers.Whitespace()
    tokenizer.train_from_iterator(texs, trainer=trainer)

    # Token ids are looked up from the trained tokenizer, so this comes after training.
    sep_id = tokenizer.token_to_id("[SEP]")
    tokenizer.post_processor = tokenizers.processors.TemplateProcessing(
        single="$A [SEP]",
        special_tokens=[("[SEP]", sep_id)],
    )

    pad_id = tokenizer.token_to_id("[PAD]")
    tokenizer.enable_padding(pad_id=pad_id, pad_token="[PAD]")
    return tokenizer