Spaces:
Runtime error
Runtime error
File size: 3,520 Bytes
e33424f 41c9661 e33424f 41c9661 e33424f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 |
import os
import re
import tokenizers
import torch
import torchvision
import torchvision.transforms as T
import tqdm
import PIL
from torch.utils.data import Dataset, DataLoader
# Hard-coded absolute path; nothing in the visible part of this file reads it.
# NOTE(review): presumably meant to be passed as `root_dir` to TexImageDataset
# (it ends with a path separator) — confirm against the callers.
directory = "/external2/dkkoshman/repos/ML2TransformerApp/data/"
class TexImageDataset(Dataset):
    """Image to tex dataset.

    Every ``<name>.png`` found in ``root_dir`` is paired with ``<name>.tex``
    from the same directory. All samples are read (and preprocessed) once,
    in the constructor, and kept in memory in ``self.data``.
    """

    def __init__(self, root_dir, image_preprocessing=None, tex_preprocessing=None):
        """
        Args:
            root_dir (string): Directory with all the images and tex files.
                A trailing path separator is optional.
            image_preprocessing: optional callable applied to each image
                returned by ``torchvision.io.read_image``
            tex_preprocessing: optional callable applied to each tex string

        Raises:
            FileNotFoundError: if a ``.png`` file has no matching ``.tex`` file.
        """
        # Process-wide setting; kept exactly as the original constructor had it.
        torch.multiprocessing.set_sharing_strategy('file_system')
        self.root_dir = root_dir
        self.data = []
        for image_path, tex_path in tqdm.tqdm(self._find_samples(root_dir)):
            # Explicit encoding so decoding does not depend on the platform locale.
            with open(tex_path, encoding='utf-8') as file:
                tex = file.read()
            if tex_preprocessing is not None:
                tex = tex_preprocessing(tex)
            image = torchvision.io.read_image(image_path)
            if image_preprocessing is not None:
                image = image_preprocessing(image)
            self.data.append((image, tex))

    @staticmethod
    def _find_samples(root_dir):
        """Return ``(image_path, tex_path)`` pairs for each ``.png`` in root_dir, sorted by stem."""
        # Compare the real extension: a bare ``endswith('png')`` would also
        # accept names such as ``notpng``.
        stems = sorted(
            stem
            for stem, extension in map(os.path.splitext, os.listdir(root_dir))
            if extension == '.png'
        )
        # os.path.join works whether or not root_dir ends with a separator.
        return [
            (os.path.join(root_dir, stem + '.png'), os.path.join(root_dir, stem + '.tex'))
            for stem in stems
        ]

    def __len__(self):
        """Return the number of loaded samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return sample ``idx`` as a dict with keys ``"image"`` and ``"tex"``."""
        image, tex = self.data[idx]
        return {"image": image, "tex": tex}
class StandardizeImage:
    """Resize, convert to grayscale, invert and center-crop an image.

    The steps run in that order. The final crop targets ``(height, width)``;
    per the torchvision documentation, ``CenterCrop`` zero-pads inputs that
    are smaller than the target, which is presumably why inversion happens
    before the crop. No value normalization is performed here.
    """

    def __init__(self, width=1024, height=128):
        steps = (
            T.Resize(height),
            T.Grayscale(),
            T.functional.invert,
            T.CenterCrop((height, width)),
        )
        self.transform = T.Compose(steps)

    def __call__(self, image):
        """Apply the composed pipeline to ``image`` and return the result."""
        return self.transform(image)
class RandomTransformImage:
    """Standardize an image with random augmentation around it.

    Brightness jitter is applied to the incoming image first, then the
    supplied ``standardize`` callable, and finally ``RandAugment`` on the
    contiguous result.
    """

    def __init__(self, standardize, random_magnitude=5):
        # Brightness factor range is (1/m, 1 + 1/m) for magnitude m.
        spread = 1 / random_magnitude
        self.brighten = T.ColorJitter(brightness=(spread, 1 + spread))
        self.standardize = standardize
        self.rand_aug = T.RandAugment(magnitude=random_magnitude)

    def __call__(self, image):
        """Return the jittered, standardized and randomly augmented image."""
        jittered = self.brighten(image)
        standardized = self.standardize(jittered).contiguous()
        return self.rand_aug(standardized)
def generate_tex_tokenizer(texs, vocab_size=300):
    """Return a BPE tokenizer trained on tex strings from the dataset.

    Args:
        texs: iterable of tex strings used as the training corpus
        vocab_size: vocabulary size handed to the BPE trainer; defaults to
            300, the value that was previously hard-coded

    Returns:
        A ``tokenizers.Tokenizer`` that splits on whitespace before BPE,
        appends ``[SEP]`` to every encoded sequence and pads with ``[PAD]``.
    """
    tokenizer = tokenizers.Tokenizer(tokenizers.models.BPE(unk_token="[UNK]"))
    tokenizer_trainer = tokenizers.trainers.BpeTrainer(
        vocab_size=vocab_size,
        special_tokens=["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"]
    )
    tokenizer.pre_tokenizer = tokenizers.pre_tokenizers.Whitespace()
    tokenizer.train_from_iterator(texs, trainer=tokenizer_trainer)
    # The special-token ids are only known after training, so the
    # post-processor and padding are configured last.
    tokenizer.post_processor = tokenizers.processors.TemplateProcessing(
        single="$A [SEP]",
        special_tokens=[("[SEP]", tokenizer.token_to_id("[SEP]"))]
    )
    tokenizer.enable_padding(pad_id=tokenizer.token_to_id("[PAD]"), pad_token="[PAD]")
    return tokenizer
|