|
import torch |
|
from torch.utils.data import DataLoader |
|
from torchvision import transforms |
|
from torchvision.transforms.functional import InterpolationMode |
|
|
|
from data.coco_karpathy_dataset import coco_karpathy_train, coco_karpathy_caption_eval, coco_karpathy_retrieval_eval |
|
from data.nocaps_dataset import nocaps_eval |
|
from data.flickr30k_dataset import flickr30k_train, flickr30k_retrieval_eval |
|
from data.vqa_dataset import vqa_dataset |
|
from data.nlvr_dataset import nlvr_dataset |
|
from data.pretrain_dataset import pretrain_dataset |
|
from transform.randaugment import RandomAugment |
|
|
|
def _build_transforms(config, min_scale):
    """Return the ``(transform_train, transform_test)`` pair shared by all datasets.

    Args:
        config: mapping providing ``'image_size'`` (target side length for
            both the random crop and the test-time resize).
        min_scale: lower bound of the ``RandomResizedCrop`` area scale range.
    """
    # NOTE(review): these mean/std values appear to be the CLIP image
    # normalization statistics — confirm against the model being trained.
    normalize = transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
                                     (0.26862954, 0.26130258, 0.27577711))

    transform_train = transforms.Compose([
        transforms.RandomResizedCrop(config['image_size'], scale=(min_scale, 1.0),
                                     interpolation=InterpolationMode.BICUBIC),
        transforms.RandomHorizontalFlip(),
        RandomAugment(2, 5, isPIL=True, augs=['Identity', 'AutoContrast', 'Brightness', 'Sharpness', 'Equalize',
                                              'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate']),
        transforms.ToTensor(),
        normalize,
    ])
    transform_test = transforms.Compose([
        transforms.Resize((config['image_size'], config['image_size']),
                          interpolation=InterpolationMode.BICUBIC),
        transforms.ToTensor(),
        normalize,
    ])
    return transform_train, transform_test


def create_dataset(dataset, config, min_scale=0.5):
    """Construct the dataset split(s) for the named task.

    Args:
        dataset: task name; one of ``'pretrain'``, ``'caption_coco'``,
            ``'nocaps'``, ``'retrieval_coco'``, ``'retrieval_flickr'``,
            ``'vqa'`` or ``'nlvr'``.
        config: mapping of paths/options read per task (e.g. ``'image_size'``,
            ``'image_root'``, ``'ann_root'``; see each branch for its keys).
        min_scale: lower bound of the random-resized-crop scale used for the
            training transform.

    Returns:
        ``'pretrain'``: a single dataset.
        ``'nocaps'`` and ``'vqa'``: a 2-tuple (val/test or train/test).
        All others: a ``(train, val, test)`` tuple.

    Raises:
        ValueError: if ``dataset`` is not one of the supported task names
            (previously this silently returned ``None``).
    """
    supported = ('pretrain', 'caption_coco', 'nocaps', 'retrieval_coco',
                 'retrieval_flickr', 'vqa', 'nlvr')
    # Validate before touching config so an unknown name fails fast and clearly.
    if dataset not in supported:
        raise ValueError(
            "Unknown dataset %r; expected one of %s" % (dataset, ', '.join(supported)))

    transform_train, transform_test = _build_transforms(config, min_scale)

    if dataset == 'pretrain':
        # Bind to a new name instead of shadowing the `dataset` argument.
        pretrain = pretrain_dataset(config['train_file'], config['laion_path'], transform_train)
        return pretrain

    elif dataset == 'caption_coco':
        train_dataset = coco_karpathy_train(transform_train, config['image_root'], config['ann_root'], prompt=config['prompt'])
        val_dataset = coco_karpathy_caption_eval(transform_test, config['image_root'], config['ann_root'], 'val')
        test_dataset = coco_karpathy_caption_eval(transform_test, config['image_root'], config['ann_root'], 'test')
        return train_dataset, val_dataset, test_dataset

    elif dataset == 'nocaps':
        val_dataset = nocaps_eval(transform_test, config['image_root'], config['ann_root'], 'val')
        test_dataset = nocaps_eval(transform_test, config['image_root'], config['ann_root'], 'test')
        return val_dataset, test_dataset

    elif dataset == 'retrieval_coco':
        train_dataset = coco_karpathy_train(transform_train, config['image_root'], config['ann_root'])
        val_dataset = coco_karpathy_retrieval_eval(transform_test, config['image_root'], config['ann_root'], 'val')
        test_dataset = coco_karpathy_retrieval_eval(transform_test, config['image_root'], config['ann_root'], 'test')
        return train_dataset, val_dataset, test_dataset

    elif dataset == 'retrieval_flickr':
        train_dataset = flickr30k_train(transform_train, config['image_root'], config['ann_root'])
        val_dataset = flickr30k_retrieval_eval(transform_test, config['image_root'], config['ann_root'], 'val')
        test_dataset = flickr30k_retrieval_eval(transform_test, config['image_root'], config['ann_root'], 'test')
        return train_dataset, val_dataset, test_dataset

    elif dataset == 'vqa':
        train_dataset = vqa_dataset(transform_train, config['ann_root'], config['vqa_root'], config['vg_root'],
                                    train_files=config['train_files'], split='train')
        test_dataset = vqa_dataset(transform_test, config['ann_root'], config['vqa_root'], config['vg_root'], split='test')
        return train_dataset, test_dataset

    else:  # 'nlvr' — the only remaining supported name after validation
        train_dataset = nlvr_dataset(transform_train, config['image_root'], config['ann_root'], 'train')
        val_dataset = nlvr_dataset(transform_test, config['image_root'], config['ann_root'], 'val')
        test_dataset = nlvr_dataset(transform_test, config['image_root'], config['ann_root'], 'test')
        return train_dataset, val_dataset, test_dataset
|
|
|
|
|
def create_sampler(datasets, shuffles, num_tasks, global_rank):
    """Wrap each dataset in a ``DistributedSampler`` for the current rank.

    Args:
        datasets: iterable of datasets to shard across processes.
        shuffles: per-dataset shuffle flags, paired with ``datasets`` by
            position; pairing stops at the shorter of the two.
        num_tasks: total number of replicas (``num_replicas``).
        global_rank: rank of this process (``rank``).

    Returns:
        A list with one sampler per ``(dataset, shuffle)`` pair.
    """
    return [
        torch.utils.data.DistributedSampler(
            ds, num_replicas=num_tasks, rank=global_rank, shuffle=flag
        )
        for ds, flag in zip(datasets, shuffles)
    ]
|
|
|
|
|
def create_loader(datasets, samplers, batch_size, num_workers, is_trains, collate_fns, pin_memory=True):
    """Build one ``DataLoader`` per dataset from parallel per-loader settings.

    All six positional sequences are indexed in lockstep: entry ``i`` of each
    configures loader ``i``.

    Args:
        datasets: datasets to load from.
        samplers: per-loader sampler, or ``None`` to let the loader order data.
        batch_size: per-loader batch size.
        num_workers: per-loader worker-process count.
        is_trains: per-loader flag. Training loaders drop the last incomplete
            batch and shuffle only when no sampler is supplied. Evaluation
            loaders never shuffle and keep the last batch.
        collate_fns: per-loader collate function (``None`` for the default).
        pin_memory: forwarded to every ``DataLoader``. Defaults to ``True``,
            matching the previous hard-coded value.

    Returns:
        A list of ``DataLoader`` objects, one per entry.

    Raises:
        ValueError: if the per-loader sequences have different lengths.
            Previously ``zip`` silently dropped the surplus entries.
    """
    names = ('datasets', 'samplers', 'batch_size', 'num_workers', 'is_trains', 'collate_fns')
    per_loader = [list(seq) for seq in (datasets, samplers, batch_size, num_workers, is_trains, collate_fns)]
    if len({len(seq) for seq in per_loader}) > 1:
        sizes = ', '.join('%s=%d' % (name, len(seq)) for name, seq in zip(names, per_loader))
        raise ValueError('create_loader: per-loader arguments differ in length (%s)' % sizes)

    loaders = []
    for dataset, sampler, bs, n_worker, is_train, collate_fn in zip(*per_loader):
        if is_train:
            # DataLoader forbids shuffle=True together with an explicit sampler,
            # so shuffle only when no sampler controls the ordering.
            shuffle = (sampler is None)
            drop_last = True
        else:
            shuffle = False
            drop_last = False
        loader = DataLoader(
            dataset,
            batch_size=bs,
            num_workers=n_worker,
            pin_memory=pin_memory,
            sampler=sampler,
            shuffle=shuffle,
            collate_fn=collate_fn,
            drop_last=drop_last,
        )
        loaders.append(loader)
    return loaders
|
|
|
|