# Copyright (c) OpenMMLab. All rights reserved.
from typing import List
from mmengine import get_file_backend, list_from_file
from mmengine.logging import MMLogger
from mmpretrain.registry import DATASETS
from .base_dataset import BaseDataset
from .categories import CUB_CATEGORIES


@DATASETS.register_module()
class CUB(BaseDataset):
"""The CUB-200-2011 Dataset.

    Support the `CUB-200-2011 <http://www.vision.caltech.edu/visipedia/CUB-200-2011.html>`_ Dataset.
    Compared with the `CUB-200 <http://www.vision.caltech.edu/visipedia/CUB-200.html>`_ Dataset,
    `CUB-200-2011` contains many more images. After downloading and decompression,
    the dataset directory structure is as follows.

    CUB dataset directory: ::

        CUB_200_2011
        ├── images
        │   ├── class_x
        │   │   ├── xx1.jpg
        │   │   ├── xx2.jpg
        │   │   └── ...
        │   ├── class_y
        │   │   ├── yy1.jpg
        │   │   ├── yy2.jpg
        │   │   └── ...
        │   └── ...
        ├── images.txt
        ├── image_class_labels.txt
        ├── train_test_split.txt
        └── ....

Args:
        data_root (str): The root directory for the CUB-200-2011 dataset.
        split (str, optional): The dataset split; supports "train" and "test".
            Defaults to "train".

Examples:
>>> from mmpretrain.datasets import CUB
>>> train_dataset = CUB(data_root='data/CUB_200_2011', split='train')
>>> train_dataset
Dataset CUB
Number of samples: 5994
Number of categories: 200
Root of dataset: data/CUB_200_2011
>>> test_dataset = CUB(data_root='data/CUB_200_2011', split='test')
>>> test_dataset
Dataset CUB
Number of samples: 5794
Number of categories: 200
Root of dataset: data/CUB_200_2011
""" # noqa: E501

    METAINFO = {'classes': CUB_CATEGORIES}

    def __init__(self,
data_root: str,
split: str = 'train',
test_mode: bool = False,
**kwargs):
splits = ['train', 'test']
        assert split in splits, \
            f"The split must be one of {splits}, but got '{split}'"
self.split = split
        # Backward compatibility: `test_mode` used to select the subset, but
        # the `split` argument now takes precedence; warn on a conflict.
if split == 'train' and test_mode:
logger = MMLogger.get_current_instance()
logger.warning('split="train" but test_mode=True. '
'The training set will be used.')
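
        # The annotation file names below are fixed by the official
        # CUB-200-2011 release, so they are hard-coded rather than exposed
        # as arguments.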
ann_file = 'images.txt'
data_prefix = 'images'
image_class_labels_file = 'image_class_labels.txt'
train_test_split_file = 'train_test_split.txt'
self.backend = get_file_backend(data_root, enable_singleton=True)
self.image_class_labels_file = self.backend.join_path(
data_root, image_class_labels_file)
self.train_test_split_file = self.backend.join_path(
data_root, train_test_split_file)
        super().__init__(
ann_file=ann_file,
data_root=data_root,
data_prefix=data_prefix,
test_mode=test_mode,
**kwargs)

    def _load_data_from_txt(self, filepath):
"""load data from CUB txt file, the every line of the file is idx and a
data item."""
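        # For illustration, with a hypothetical ``images.txt`` line such as
        #   "1 001.Black_footed_Albatross/albatross_0001.jpg"
        # the result maps {0: '001.Black_footed_Albatross/albatross_0001.jpg'}
        # (indices in the CUB files are 1-based and become 0-based keys).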
pairs = list_from_file(filepath)
data_dict = dict()
for pair in pairs:
idx, data_item = pair.split()
            # Indices in the CUB files start from 1; subtract 1 so that
            # they start from 0.
data_dict[int(idx) - 1] = data_item
return data_dict

    def load_data_list(self):
"""Load images and ground truth labels."""
sample_dict = self._load_data_from_txt(self.ann_file)
label_dict = self._load_data_from_txt(self.image_class_labels_file)
split_dict = self._load_data_from_txt(self.train_test_split_file)
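        # In train_test_split.txt, '1' marks a training image and '0' marks
        # a test image.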
        assert sample_dict.keys() == label_dict.keys() == split_dict.keys(),\
            f'sample_ids should be the same in files {self.ann_file}, ' \
            f'{self.image_class_labels_file} and {self.train_test_split_file}'
data_list = []
for sample_id in sample_dict.keys():
if split_dict[sample_id] == '1' and self.split == 'test':
# skip train samples when split='test'
continue
elif split_dict[sample_id] == '0' and self.split == 'train':
# skip test samples when split='train'
continue
img_path = self.backend.join_path(self.img_prefix,
sample_dict[sample_id])
gt_label = int(label_dict[sample_id]) - 1
info = dict(img_path=img_path, gt_label=gt_label)
data_list.append(info)
return data_list

    def extra_repr(self) -> List[str]:
"""The extra repr information of the dataset."""
body = [
f'Root of dataset: \t{self.data_root}',
]
return body
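

# A minimal usage sketch, not part of the module itself: it assumes the
# dataset has been unpacked under ``data/CUB_200_2011`` (a hypothetical
# path). Building through the registry is equivalent to instantiating
# ``CUB`` directly:
#
#     from mmpretrain.registry import DATASETS
#     cfg = dict(type='CUB', data_root='data/CUB_200_2011', split='test')
#     test_dataset = DATASETS.build(cfg)
#     print(test_dataset)  # Dataset CUB, 5794 samples, 200 categories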