|
from dataset.randaugment import RandomAugment |
|
from torch.utils.data import DataLoader |
|
|
|
from .vqa import vqa_dataset |
|
|
|
import torch |
|
from torch import nn |
|
from torchvision import transforms |
|
|
|
from PIL import Image |
|
|
|
|
|
def create_dataset(dataset, config, data_dir='/data/mshukor/data'):
    """Build the train and test datasets for the requested task.

    Args:
        dataset: Name of the task. Only ``'vqa'`` is supported.
        config: Mapping that is read for the keys ``'image_res'``,
            ``'train_file'``, ``'test_file'``, ``'vqa_root'``, ``'vg_root'``
            and ``'answer_list'``.
        data_dir: Not used by this function; kept in the signature so that
            existing callers passing it keep working.

    Returns:
        A ``(train_dataset, vqa_test_dataset)`` tuple of ``vqa_dataset``
        instances, the first built with the augmenting train transform and
        the second with the deterministic test transform.

    Raises:
        ValueError: If ``dataset`` is not a supported task name.
    """
    # Fail fast with a clear message instead of silently returning None,
    # which would only surface later as a confusing unpacking error.
    if dataset != 'vqa':
        raise ValueError(
            "Unsupported dataset {!r}; expected 'vqa'.".format(dataset)
        )

    image_res = config['image_res']

    # Per-channel (mean, std) used to normalize image tensors.
    normalize = transforms.Normalize(
        (0.48145466, 0.4578275, 0.40821073),
        (0.26862954, 0.26130258, 0.27577711),
    )

    # Training: random crop + flip + two random augmentations, then tensor
    # conversion and normalization.
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(image_res, scale=(0.5, 1.0), interpolation=Image.BICUBIC),
        transforms.RandomHorizontalFlip(),
        RandomAugment(2, 7, isPIL=True, augs=['Identity', 'AutoContrast', 'Equalize', 'Brightness', 'Sharpness',
                                              'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate']),
        transforms.ToTensor(),
        normalize,
    ])

    # Evaluation: fixed square resize only, no random augmentation.
    test_transform = transforms.Compose([
        transforms.Resize((image_res, image_res), interpolation=Image.BICUBIC),
        transforms.ToTensor(),
        normalize,
    ])

    train_dataset = vqa_dataset(
        config['train_file'], train_transform, config['vqa_root'], config['vg_root'], split='train'
    )
    vqa_test_dataset = vqa_dataset(
        config['test_file'], test_transform, config['vqa_root'], config['vg_root'], split='test',
        answer_list=config['answer_list']
    )
    return train_dataset, vqa_test_dataset
|
|
|
|
|
def create_loader(datasets, samplers, batch_size, num_workers, is_trains, collate_fns):
    """Create one ``DataLoader`` per dataset from parallel argument sequences.

    The six arguments are iterated in lockstep with ``zip``; the i-th entry
    of each configures the i-th loader.

    Args:
        datasets: Datasets to wrap.
        samplers: Sampler for each dataset, or ``None`` for no sampler.
        batch_size: Batch size for each loader.
        num_workers: Worker-process count for each loader.
        is_trains: Truthy entries mark training loaders.
        collate_fns: Collate function for each loader (may be ``None``).

    Returns:
        A list of ``DataLoader`` objects, in input order.
    """
    per_loader_args = zip(datasets, samplers, batch_size, num_workers, is_trains, collate_fns)
    return [
        DataLoader(
            ds,
            batch_size=size,
            num_workers=workers,
            pin_memory=True,
            sampler=smp,
            # Only training loaders shuffle, and only when no sampler was
            # supplied; otherwise ordering is left to the sampler.
            shuffle=bool(training) and smp is None,
            collate_fn=collate,
            # Training loaders drop the final incomplete batch; evaluation
            # loaders keep every sample.
            drop_last=bool(training),
        )
        for ds, smp, size, workers, training, collate in per_loader_args
    ]
|
|
|
|
|
def create_sampler(datasets, shuffles, num_tasks, global_rank):
    """Create one ``DistributedSampler`` per dataset.

    Args:
        datasets: Datasets to be sharded across processes.
        shuffles: Per-dataset flag passed as the sampler's ``shuffle``.
        num_tasks: Passed as ``num_replicas`` to every sampler.
        global_rank: Passed as ``rank`` to every sampler.

    Returns:
        A list of samplers, one per ``(dataset, shuffle)`` pair, in order.
    """
    sampler_cls = torch.utils.data.DistributedSampler
    return [
        sampler_cls(ds, num_replicas=num_tasks, rank=global_rank, shuffle=do_shuffle)
        for ds, do_shuffle in zip(datasets, shuffles)
    ]