"""Smoke test for the TeX/image data pipeline.

Builds the dataset, splits it 90/10 into train/test, builds a TeX tokenizer from
the training split, and prints the TeX field of one collated training batch.
"""
from data_preprocessing import TexImageDataset, RandomizeImageTransform, ExtractEquationFromTexTransform, \
    generate_tex_tokenizer, BatchCollator
import torch
from torch.utils.data import DataLoader
import tqdm

if __name__ == '__main__':
    image_transform = RandomizeImageTransform()
    tex_transform = ExtractEquationFromTexTransform()
    dataset = TexImageDataset('data', image_transform=image_transform, tex_transform=tex_transform)
    dataset.subjoin_image_normalize_transform()

    # random_split requires integer lengths that sum exactly to len(dataset).
    # Flooring both the 90% and the 10% share loses samples whenever
    # len(dataset) is not a multiple of 10, which makes random_split raise
    # ValueError. Taking the test size as the remainder always sums correctly.
    train_size = len(dataset) * 9 // 10
    test_size = len(dataset) - train_size
    train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])

    # First pass, with PyTorch's default collate, to gather TeX strings for tokenizer training.
    # NOTE(review): assumes dataset items are dicts with a 'tex' string, so each
    # batch['tex'] is a list of strings and `texs` is a list of per-batch lists.
    # Confirm that generate_tex_tokenizer accepts batches rather than a flat list.
    # This pass also loads and transforms every image, which is wasted work if
    # only the TeX is needed.
    train_dataloader = DataLoader(train_dataset, batch_size=16, num_workers=16)
    texs = list(tqdm.tqdm((batch['tex'] for batch in train_dataloader), total=len(train_dataloader)))
    tokenizer = generate_tex_tokenizer(texs)

    # Second pass: shuffled loader whose collator uses the tokenizer built above.
    collate_fn = BatchCollator(tokenizer)
    train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=16,
                                  collate_fn=collate_fn)
    batch = next(iter(train_dataloader))
    # NOTE(review): 'texs' is presumably the key BatchCollator emits; it differs
    # from the dataset's 'tex' key, so verify against BatchCollator.
    print(batch['texs'])