from dataset.randaugment import RandomAugment
from torch.utils.data import DataLoader
from .vqa import vqa_dataset
import torch
from torch import nn
from torchvision import transforms
from PIL import Image


def create_dataset(dataset, config, data_dir='/data/mshukor/data'):
    """Build the train/test datasets for the named task.

    Args:
        dataset: Name of the dataset to build. Only ``'vqa'`` is supported.
        config: Mapping read for the keys ``'image_res'``, ``'train_file'``,
            ``'test_file'``, ``'vqa_root'``, ``'vg_root'`` and ``'answer_list'``.
        data_dir: Not used by this function; kept for interface compatibility.

    Returns:
        ``(train_dataset, test_dataset)`` for ``'vqa'``.

    Raises:
        ValueError: If ``dataset`` is not a supported name. Previously this
            case silently returned ``None``, which made callers fail later
            with an unrelated tuple-unpacking error.
    """
    if dataset != 'vqa':
        raise ValueError(f"Unsupported dataset: {dataset!r}")

    # Per-channel mean/std; these appear to be the CLIP preprocessing
    # statistics — NOTE(review): confirm against the backbone being used.
    normalize = transforms.Normalize(
        (0.48145466, 0.4578275, 0.40821073),
        (0.26862954, 0.26130258, 0.27577711),
    )

    # Training: random crop covering 50-100% of the image, horizontal flip,
    # and 2 random augmentation ops at magnitude 7 applied on the PIL image.
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(
            config['image_res'], scale=(0.5, 1.0), interpolation=Image.BICUBIC
        ),
        transforms.RandomHorizontalFlip(),
        RandomAugment(
            2, 7, isPIL=True,
            augs=['Identity', 'AutoContrast', 'Equalize', 'Brightness', 'Sharpness',
                  'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate'],
        ),
        transforms.ToTensor(),
        normalize,
    ])

    # Evaluation: deterministic square resize, no augmentation.
    test_transform = transforms.Compose([
        transforms.Resize(
            (config['image_res'], config['image_res']), interpolation=Image.BICUBIC
        ),
        transforms.ToTensor(),
        normalize,
    ])

    train_dataset = vqa_dataset(
        config['train_file'], train_transform, config['vqa_root'], config['vg_root'],
        split='train',
    )
    vqa_test_dataset = vqa_dataset(
        config['test_file'], test_transform, config['vqa_root'], config['vg_root'],
        split='test', answer_list=config['answer_list'],
    )
    return train_dataset, vqa_test_dataset


def create_loader(datasets, samplers, batch_size, num_workers, is_trains, collate_fns):
    """Build one ``DataLoader`` per dataset from parallel per-loader settings.

    Every argument is a sequence. The sequences are walked in lockstep with
    ``zip``, so iteration stops at the shortest one and any extra entries in
    longer sequences are ignored.

    Args:
        datasets: Datasets to wrap.
        samplers: Sampler per dataset, or ``None`` for the default ordering.
        batch_size: Batch size per dataset.
        num_workers: Worker-process count per dataset.
        is_trains: Per dataset, whether it is a training loader. Training
            loaders drop the last incomplete batch and shuffle when no sampler
            is given.
        collate_fns: Collate function per dataset, or ``None`` for the
            ``DataLoader`` default.

    Returns:
        List of ``DataLoader`` objects in the same order as ``datasets``.
    """
    loaders = []
    for dataset, sampler, bs, n_worker, is_train, collate_fn in zip(
        datasets, samplers, batch_size, num_workers, is_trains, collate_fns
    ):
        if is_train:
            # DataLoader rejects shuffle=True combined with an explicit
            # sampler, so shuffle only when no sampler controls the order.
            shuffle = sampler is None
            drop_last = True
        else:
            shuffle = False
            drop_last = False
        loaders.append(
            DataLoader(
                dataset,
                batch_size=bs,
                num_workers=n_worker,
                pin_memory=True,
                sampler=sampler,
                shuffle=shuffle,
                collate_fn=collate_fn,
                drop_last=drop_last,
            )
        )
    return loaders


def create_sampler(datasets, shuffles, num_tasks, global_rank):
    """Build one ``DistributedSampler`` per dataset.

    Args:
        datasets: Datasets to shard.
        shuffles: Per dataset, the ``shuffle`` flag passed to the sampler.
            Paired with ``datasets`` via ``zip``, so the shorter sequence
            bounds the result.
        num_tasks: Total number of replicas (``num_replicas``).
        global_rank: Rank of the current process (``rank``).

    Returns:
        List of samplers in the same order as ``datasets``.
    """
    return [
        torch.utils.data.DistributedSampler(
            dataset, num_replicas=num_tasks, rank=global_rank, shuffle=shuffle
        )
        for dataset, shuffle in zip(datasets, shuffles)
    ]