import os
import torch
from torchvision import datasets

class CUBDataset(datasets.ImageFolder):
    """
    Wrapper for the CUB-200-2011 dataset.

    CUBDataset.__getitem__() returns a tuple of an image and its corresponding label
    (or, when bboxes=True, a tensor of the label and relative bounding-box coordinates).
    Dataset layout per https://github.com/slipnitskaya/caltech-birds-advanced-classification
    """
    def __init__(self,
                 root,
                 transform=None,
                 target_transform=None,
                 loader=datasets.folder.default_loader,
                 is_valid_file=None,
                 train=True,
                 bboxes=False):
        img_root = os.path.join(root, 'images')

        super(CUBDataset, self).__init__(
            root=img_root,
            transform=None,
            target_transform=None,
            loader=loader,
            is_valid_file=is_valid_file,
        )

        self.redefine_class_to_idx()

        # transforms are kept aside and applied manually in __getitem__,
        # after the (optional) bounding-box target has been built
        self.transform_ = transform
        self.target_transform_ = target_transform
        self.train = train
        # obtain sample ids filtered by split
        path_to_splits = os.path.join(root, 'train_test_split.txt')
        indices_to_use = list()
        with open(path_to_splits, 'r') as in_file:
            for line in in_file:
                # each line is "<image_id> <is_training_image>"
                idx, use_train = line.strip('\n').split(' ', 1)
                if bool(int(use_train)) == self.train:
                    indices_to_use.append(int(idx))
        # obtain filenames of images
        path_to_index = os.path.join(root, 'images.txt')
        filenames_to_use = list()
        with open(path_to_index, 'r') as in_file:
            for line in in_file:
                # each line is "<image_id> <relative_image_path>"
                idx, fn = line.strip('\n').split(' ', 1)
                if fn not in filenames_to_use and int(idx) in indices_to_use:
                    filenames_to_use.append(fn)

        # map "<class_dir>/<file_name>" back to the ImageFolder sample index
        img_paths_cut = {'/'.join(img_path.rsplit('/', 2)[-2:]): idx
                         for idx, (img_path, lb) in enumerate(self.imgs)}
        imgs_to_use = [self.imgs[img_paths_cut[fn]] for fn in filenames_to_use]

        _, targets_to_use = list(zip(*imgs_to_use))

        self.imgs = self.samples = imgs_to_use
        self.targets = targets_to_use
        if bboxes:
            # load bounding-box coordinates for the images in this split
            path_to_bboxes = os.path.join(root, 'bounding_boxes.txt')
            bounding_boxes = list()
            with open(path_to_bboxes, 'r') as in_file:
                for line in in_file:
                    # each line is "<image_id> <x> <y> <width> <height>"
                    idx, x, y, w, h = map(float, line.strip('\n').split(' '))
                    if int(idx) in indices_to_use:
                        bounding_boxes.append((x, y, w, h))

            self.bboxes = bounding_boxes
        else:
            self.bboxes = None
    def __getitem__(self, index):
        # generate one sample
        sample, target = super(CUBDataset, self).__getitem__(index)

        if self.bboxes is not None:
            # squeeze coordinates of the bounding box into the range [0, 1];
            # the hard-coded 500/375 constants appear to assume a
            # Resize(500) + CenterCrop(375) preprocessing pipeline, and the
            # expression reduces to dividing by the original image width
            width, height = sample.width, sample.height
            x, y, w, h = self.bboxes[index]

            scale_resize = 500 / width
            scale_resize_crop = scale_resize * (375 / 500)

            x_rel = scale_resize_crop * x / 375
            y_rel = scale_resize_crop * y / 375
            w_rel = scale_resize_crop * w / 375
            h_rel = scale_resize_crop * h / 375

            target = torch.tensor([target, x_rel, y_rel, w_rel, h_rel])

        if self.transform_ is not None:
            sample = self.transform_(sample)
        if self.target_transform_ is not None:
            target = self.target_transform_(target)

        return sample, target
    def redefine_class_to_idx(self):
        # turn directory names such as "001.Black_footed_Albatross" into
        # human-readable class names such as "Black-footed Albatross"
        adjusted_dict = {}
        for k, v in self.class_to_idx.items():
            k = k.split('.')[-1].replace('_', ' ')
            split_key = k.split(' ')
            if len(split_key) > 2:
                k = '-'.join(split_key[:-1]) + " " + split_key[-1]
            adjusted_dict[k] = v
        self.class_to_idx = adjusted_dict
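

# Minimal usage sketch, assuming the standard CUB-200-2011 directory layout.
# The root path below is a placeholder, and the Resize(500) / CenterCrop(375)
# sizes are an assumption chosen to match the constants used in __getitem__.
if __name__ == "__main__":
    from torch.utils.data import DataLoader
    from torchvision import transforms

    transform = transforms.Compose([
        transforms.Resize(500),
        transforms.CenterCrop(375),
        transforms.ToTensor(),
    ])

    # `root` is expected to contain images/, images.txt, train_test_split.txt
    # and bounding_boxes.txt as distributed with CUB-200-2011
    train_set = CUBDataset(root='CUB_200_2011', transform=transform,
                           train=True, bboxes=True)
    loader = DataLoader(train_set, batch_size=16, shuffle=True)

    images, targets = next(iter(loader))
    # with bboxes=True each target row is [class_idx, x_rel, y_rel, w_rel, h_rel]
    print(images.shape, targets.shape)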