doevent commited on
Commit
b6c3e81
1 Parent(s): 5d3ac8c

Upload data/flickr30k_dataset.py

Browse files
Files changed (1) hide show
  1. data/flickr30k_dataset.py +93 -0
data/flickr30k_dataset.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+
4
+ from torch.utils.data import Dataset
5
+ from torchvision.datasets.utils import download_url
6
+
7
+ from PIL import Image
8
+
9
+ from data.utils import pre_caption
10
+
11
+ class flickr30k_train(Dataset):
12
+ def __init__(self, transform, image_root, ann_root, max_words=30, prompt=''):
13
+ '''
14
+ image_root (string): Root directory of images (e.g. flickr30k/)
15
+ ann_root (string): directory to store the annotation file
16
+ '''
17
+ url = 'https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_train.json'
18
+ filename = 'flickr30k_train.json'
19
+
20
+ download_url(url,ann_root)
21
+
22
+ self.annotation = json.load(open(os.path.join(ann_root,filename),'r'))
23
+ self.transform = transform
24
+ self.image_root = image_root
25
+ self.max_words = max_words
26
+ self.prompt = prompt
27
+
28
+ self.img_ids = {}
29
+ n = 0
30
+ for ann in self.annotation:
31
+ img_id = ann['image_id']
32
+ if img_id not in self.img_ids.keys():
33
+ self.img_ids[img_id] = n
34
+ n += 1
35
+
36
+ def __len__(self):
37
+ return len(self.annotation)
38
+
39
+ def __getitem__(self, index):
40
+
41
+ ann = self.annotation[index]
42
+
43
+ image_path = os.path.join(self.image_root,ann['image'])
44
+ image = Image.open(image_path).convert('RGB')
45
+ image = self.transform(image)
46
+
47
+ caption = self.prompt+pre_caption(ann['caption'], self.max_words)
48
+
49
+ return image, caption, self.img_ids[ann['image_id']]
50
+
51
+
52
+ class flickr30k_retrieval_eval(Dataset):
53
+ def __init__(self, transform, image_root, ann_root, split, max_words=30):
54
+ '''
55
+ image_root (string): Root directory of images (e.g. flickr30k/)
56
+ ann_root (string): directory to store the annotation file
57
+ split (string): val or test
58
+ '''
59
+ urls = {'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_val.json',
60
+ 'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_test.json'}
61
+ filenames = {'val':'flickr30k_val.json','test':'flickr30k_test.json'}
62
+
63
+ download_url(urls[split],ann_root)
64
+
65
+ self.annotation = json.load(open(os.path.join(ann_root,filenames[split]),'r'))
66
+ self.transform = transform
67
+ self.image_root = image_root
68
+
69
+ self.text = []
70
+ self.image = []
71
+ self.txt2img = {}
72
+ self.img2txt = {}
73
+
74
+ txt_id = 0
75
+ for img_id, ann in enumerate(self.annotation):
76
+ self.image.append(ann['image'])
77
+ self.img2txt[img_id] = []
78
+ for i, caption in enumerate(ann['caption']):
79
+ self.text.append(pre_caption(caption,max_words))
80
+ self.img2txt[img_id].append(txt_id)
81
+ self.txt2img[txt_id] = img_id
82
+ txt_id += 1
83
+
84
+ def __len__(self):
85
+ return len(self.annotation)
86
+
87
+ def __getitem__(self, index):
88
+
89
+ image_path = os.path.join(self.image_root, self.annotation[index]['image'])
90
+ image = Image.open(image_path).convert('RGB')
91
+ image = self.transform(image)
92
+
93
+ return image, index