doevent commited on
Commit
5d3ac8c
1 Parent(s): da5787d

Upload data/coco_karpathy_dataset.py

Browse files
Files changed (1) hide show
  1. data/coco_karpathy_dataset.py +126 -0
data/coco_karpathy_dataset.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+
4
+ from torch.utils.data import Dataset
5
+ from torchvision.datasets.utils import download_url
6
+
7
+ from PIL import Image
8
+
9
+ from data.utils import pre_caption
10
+
11
+ class coco_karpathy_train(Dataset):
12
+ def __init__(self, transform, image_root, ann_root, max_words=30, prompt=''):
13
+ '''
14
+ image_root (string): Root directory of images (e.g. coco/images/)
15
+ ann_root (string): directory to store the annotation file
16
+ '''
17
+ url = 'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_train.json'
18
+ filename = 'coco_karpathy_train.json'
19
+
20
+ download_url(url,ann_root)
21
+
22
+ self.annotation = json.load(open(os.path.join(ann_root,filename),'r'))
23
+ self.transform = transform
24
+ self.image_root = image_root
25
+ self.max_words = max_words
26
+ self.prompt = prompt
27
+
28
+ self.img_ids = {}
29
+ n = 0
30
+ for ann in self.annotation:
31
+ img_id = ann['image_id']
32
+ if img_id not in self.img_ids.keys():
33
+ self.img_ids[img_id] = n
34
+ n += 1
35
+
36
+ def __len__(self):
37
+ return len(self.annotation)
38
+
39
+ def __getitem__(self, index):
40
+
41
+ ann = self.annotation[index]
42
+
43
+ image_path = os.path.join(self.image_root,ann['image'])
44
+ image = Image.open(image_path).convert('RGB')
45
+ image = self.transform(image)
46
+
47
+ caption = self.prompt+pre_caption(ann['caption'], self.max_words)
48
+
49
+ return image, caption, self.img_ids[ann['image_id']]
50
+
51
+
52
+ class coco_karpathy_caption_eval(Dataset):
53
+ def __init__(self, transform, image_root, ann_root, split):
54
+ '''
55
+ image_root (string): Root directory of images (e.g. coco/images/)
56
+ ann_root (string): directory to store the annotation file
57
+ split (string): val or test
58
+ '''
59
+ urls = {'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val.json',
60
+ 'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test.json'}
61
+ filenames = {'val':'coco_karpathy_val.json','test':'coco_karpathy_test.json'}
62
+
63
+ download_url(urls[split],ann_root)
64
+
65
+ self.annotation = json.load(open(os.path.join(ann_root,filenames[split]),'r'))
66
+ self.transform = transform
67
+ self.image_root = image_root
68
+
69
+ def __len__(self):
70
+ return len(self.annotation)
71
+
72
+ def __getitem__(self, index):
73
+
74
+ ann = self.annotation[index]
75
+
76
+ image_path = os.path.join(self.image_root,ann['image'])
77
+ image = Image.open(image_path).convert('RGB')
78
+ image = self.transform(image)
79
+
80
+ img_id = ann['image'].split('/')[-1].strip('.jpg').split('_')[-1]
81
+
82
+ return image, int(img_id)
83
+
84
+
85
+ class coco_karpathy_retrieval_eval(Dataset):
86
+ def __init__(self, transform, image_root, ann_root, split, max_words=30):
87
+ '''
88
+ image_root (string): Root directory of images (e.g. coco/images/)
89
+ ann_root (string): directory to store the annotation file
90
+ split (string): val or test
91
+ '''
92
+ urls = {'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val.json',
93
+ 'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test.json'}
94
+ filenames = {'val':'coco_karpathy_val.json','test':'coco_karpathy_test.json'}
95
+
96
+ download_url(urls[split],ann_root)
97
+
98
+ self.annotation = json.load(open(os.path.join(ann_root,filenames[split]),'r'))
99
+ self.transform = transform
100
+ self.image_root = image_root
101
+
102
+ self.text = []
103
+ self.image = []
104
+ self.txt2img = {}
105
+ self.img2txt = {}
106
+
107
+ txt_id = 0
108
+ for img_id, ann in enumerate(self.annotation):
109
+ self.image.append(ann['image'])
110
+ self.img2txt[img_id] = []
111
+ for i, caption in enumerate(ann['caption']):
112
+ self.text.append(pre_caption(caption,max_words))
113
+ self.img2txt[img_id].append(txt_id)
114
+ self.txt2img[txt_id] = img_id
115
+ txt_id += 1
116
+
117
+ def __len__(self):
118
+ return len(self.annotation)
119
+
120
+ def __getitem__(self, index):
121
+
122
+ image_path = os.path.join(self.image_root, self.annotation[index]['image'])
123
+ image = Image.open(image_path).convert('RGB')
124
+ image = self.transform(image)
125
+
126
+ return image, index