Spaces:
Running
Running
Upload data/__init__.py
Browse files- data/__init__.py +101 -0
data/__init__.py
ADDED
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch.utils.data import DataLoader
|
3 |
+
from torchvision import transforms
|
4 |
+
from torchvision.transforms.functional import InterpolationMode
|
5 |
+
|
6 |
+
from data.coco_karpathy_dataset import coco_karpathy_train, coco_karpathy_caption_eval, coco_karpathy_retrieval_eval
|
7 |
+
from data.nocaps_dataset import nocaps_eval
|
8 |
+
from data.flickr30k_dataset import flickr30k_train, flickr30k_retrieval_eval
|
9 |
+
from data.vqa_dataset import vqa_dataset
|
10 |
+
from data.nlvr_dataset import nlvr_dataset
|
11 |
+
from data.pretrain_dataset import pretrain_dataset
|
12 |
+
from transform.randaugment import RandomAugment
|
13 |
+
|
14 |
+
def create_dataset(dataset, config, min_scale=0.5):
    """Build the dataset object(s) for the task named by ``dataset``.

    Args:
        dataset: Task name. One of ``'pretrain'``, ``'caption_coco'``,
            ``'nocaps'``, ``'retrieval_coco'``, ``'retrieval_flickr'``,
            ``'vqa'`` or ``'nlvr'``.
        config: Mapping of settings. ``config['image_size']`` is always read;
            the remaining keys depend on the task (``'train_file'`` and
            ``'laion_path'`` for pretraining; ``'image_root'`` and
            ``'ann_root'`` for the COCO / NoCaps / Flickr / NLVR tasks, plus
            ``'prompt'`` for ``'caption_coco'``; ``'ann_root'``,
            ``'vqa_root'``, ``'vg_root'`` and ``'train_files'`` for VQA).
        min_scale: Lower bound of the ``scale`` range passed to
            ``RandomResizedCrop`` in the training transform.

    Returns:
        * ``'pretrain'``: a single dataset.
        * ``'nocaps'``: ``(val_dataset, test_dataset)``.
        * ``'vqa'``: ``(train_dataset, test_dataset)``.
        * all other tasks: ``(train_dataset, val_dataset, test_dataset)``.

    Raises:
        ValueError: If ``dataset`` is not one of the supported task names.
    """
    # Reject unknown names before doing any work; otherwise the if/elif chain
    # below would fall through and silently return None.
    supported = (
        'pretrain', 'caption_coco', 'nocaps', 'retrieval_coco',
        'retrieval_flickr', 'vqa', 'nlvr',
    )
    if dataset not in supported:
        raise ValueError(
            "Unknown dataset %r; expected one of: %s" % (dataset, ", ".join(supported))
        )

    # Per-channel (mean, std) used to normalise image tensors in both pipelines.
    normalize = transforms.Normalize(
        (0.48145466, 0.4578275, 0.40821073),
        (0.26862954, 0.26130258, 0.27577711),
    )

    # Training pipeline: random crop / flip / RandomAugment, then tensor + normalise.
    transform_train = transforms.Compose([
        transforms.RandomResizedCrop(
            config['image_size'],
            scale=(min_scale, 1.0),
            interpolation=InterpolationMode.BICUBIC,
        ),
        transforms.RandomHorizontalFlip(),
        RandomAugment(
            2, 5, isPIL=True,
            augs=['Identity', 'AutoContrast', 'Brightness', 'Sharpness', 'Equalize',
                  'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate'],
        ),
        transforms.ToTensor(),
        normalize,
    ])
    # Evaluation pipeline: deterministic resize to a square, then tensor + normalise.
    transform_test = transforms.Compose([
        transforms.Resize(
            (config['image_size'], config['image_size']),
            interpolation=InterpolationMode.BICUBIC,
        ),
        transforms.ToTensor(),
        normalize,
    ])

    if dataset == 'pretrain':
        return pretrain_dataset(config['train_file'], config['laion_path'], transform_train)

    elif dataset == 'caption_coco':
        train_dataset = coco_karpathy_train(transform_train, config['image_root'], config['ann_root'], prompt=config['prompt'])
        val_dataset = coco_karpathy_caption_eval(transform_test, config['image_root'], config['ann_root'], 'val')
        test_dataset = coco_karpathy_caption_eval(transform_test, config['image_root'], config['ann_root'], 'test')
        return train_dataset, val_dataset, test_dataset

    elif dataset == 'nocaps':
        val_dataset = nocaps_eval(transform_test, config['image_root'], config['ann_root'], 'val')
        test_dataset = nocaps_eval(transform_test, config['image_root'], config['ann_root'], 'test')
        return val_dataset, test_dataset

    elif dataset == 'retrieval_coco':
        train_dataset = coco_karpathy_train(transform_train, config['image_root'], config['ann_root'])
        val_dataset = coco_karpathy_retrieval_eval(transform_test, config['image_root'], config['ann_root'], 'val')
        test_dataset = coco_karpathy_retrieval_eval(transform_test, config['image_root'], config['ann_root'], 'test')
        return train_dataset, val_dataset, test_dataset

    elif dataset == 'retrieval_flickr':
        train_dataset = flickr30k_train(transform_train, config['image_root'], config['ann_root'])
        val_dataset = flickr30k_retrieval_eval(transform_test, config['image_root'], config['ann_root'], 'val')
        test_dataset = flickr30k_retrieval_eval(transform_test, config['image_root'], config['ann_root'], 'test')
        return train_dataset, val_dataset, test_dataset

    elif dataset == 'vqa':
        train_dataset = vqa_dataset(transform_train, config['ann_root'], config['vqa_root'], config['vg_root'],
                                    train_files=config['train_files'], split='train')
        test_dataset = vqa_dataset(transform_test, config['ann_root'], config['vqa_root'], config['vg_root'], split='test')
        return train_dataset, test_dataset

    elif dataset == 'nlvr':
        train_dataset = nlvr_dataset(transform_train, config['image_root'], config['ann_root'], 'train')
        val_dataset = nlvr_dataset(transform_test, config['image_root'], config['ann_root'], 'val')
        test_dataset = nlvr_dataset(transform_test, config['image_root'], config['ann_root'], 'test')
        return train_dataset, val_dataset, test_dataset
|
70 |
+
|
71 |
+
|
72 |
+
def create_sampler(datasets, shuffles, num_tasks, global_rank):
    """Create one ``DistributedSampler`` per dataset.

    Args:
        datasets: Iterable of datasets to shard.
        shuffles: Iterable of booleans, paired positionally with ``datasets``,
            forwarded as each sampler's ``shuffle`` argument.
        num_tasks: Forwarded as ``num_replicas`` to every sampler.
        global_rank: Forwarded as ``rank`` to every sampler.

    Returns:
        A list of samplers in the same order as ``datasets``.
    """
    return [
        torch.utils.data.DistributedSampler(
            ds,
            num_replicas=num_tasks,
            rank=global_rank,
            shuffle=do_shuffle,
        )
        for ds, do_shuffle in zip(datasets, shuffles)
    ]
|
78 |
+
|
79 |
+
|
80 |
+
def create_loader(datasets, samplers, batch_size, num_workers, is_trains, collate_fns):
    """Create one ``DataLoader`` per dataset.

    All arguments are parallel iterables; the i-th entry of each configures
    the i-th loader.

    Args:
        datasets: Datasets to wrap.
        samplers: Sampler for each dataset, or ``None`` for no explicit sampler.
        batch_size: Batch size for each loader.
        num_workers: Worker count for each loader.
        is_trains: Flag per loader. Training loaders drop the last incomplete
            batch and shuffle only when no sampler is supplied; other loaders
            never shuffle and keep the last batch.
        collate_fns: ``collate_fn`` for each loader (may be ``None``).

    Returns:
        A list of loaders in the same order as ``datasets``.
    """
    loaders = []
    settings = zip(datasets, samplers, batch_size, num_workers, is_trains, collate_fns)
    for data, smp, size, workers, training, collate in settings:
        training = bool(training)
        loaders.append(
            DataLoader(
                data,
                batch_size=size,
                num_workers=workers,
                pin_memory=True,
                sampler=smp,
                # shuffle and an explicit sampler are mutually exclusive, so
                # only shuffle training data when no sampler is given.
                shuffle=training and smp is None,
                collate_fn=collate,
                drop_last=training,
            )
        )
    return loaders
|
101 |
+
|