import bisect
import math
from collections import defaultdict

import numpy as np
from mmcv.utils import print_log
from torch.utils.data.dataset import ConcatDataset as _ConcatDataset

from .builder import DATASETS
from .coco import CocoDataset
class ConcatDataset(_ConcatDataset):
    """A wrapper of concatenated dataset.

    Same as :obj:`torch.utils.data.dataset.ConcatDataset`, but also
    concatenates the group flag for image aspect ratio.

    Args:
        datasets (list[:obj:`Dataset`]): A list of datasets.
        separate_eval (bool): Whether to evaluate the results
            separately if it is used as validation dataset.
            Defaults to True.
    """

    def __init__(self, datasets, separate_eval=True):
        super(ConcatDataset, self).__init__(datasets)
        # The parent class stores the datasets as a list in ``self.datasets``.
        self.CLASSES = self.datasets[0].CLASSES
        self.separate_eval = separate_eval
        if not separate_eval:
            # Fail early instead of at evaluation time.
            self._check_joint_eval_supported()
        if hasattr(self.datasets[0], 'flag'):
            # Every dataset is expected to carry a ``flag`` when the first one
            # does; the concatenated flag lines up with the global indices.
            self.flag = np.concatenate([ds.flag for ds in self.datasets])

    def _check_joint_eval_supported(self):
        """Raise if the datasets cannot be evaluated as one merged dataset.

        Raises:
            NotImplementedError: If any dataset is a ``CocoDataset`` or the
                datasets are not all of exactly the same type.
        """
        if any(isinstance(ds, CocoDataset) for ds in self.datasets):
            raise NotImplementedError(
                'Evaluating concatenated CocoDataset as a whole is not'
                ' supported! Please set "separate_eval=True"')
        if len({type(ds) for ds in self.datasets}) != 1:
            raise NotImplementedError(
                'All the datasets should have same types')

    def _get_ori_dataset_idx(self, idx):
        """Map a global index to ``(dataset_idx, sample_idx)``.

        Args:
            idx (int): Global index; negative values count from the end.

        Returns:
            tuple[int, int]: Index of the sub-dataset and the index of the
                sample inside that sub-dataset.

        Raises:
            ValueError: If a negative ``idx`` has an absolute value larger
                than the dataset length.
        """
        if idx < 0:
            if -idx > len(self):
                raise ValueError(
                    'absolute value of index should not exceed dataset length')
            idx = len(self) + idx
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        if dataset_idx == 0:
            sample_idx = idx
        else:
            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
        return dataset_idx, sample_idx

    def get_cat_ids(self, idx):
        """Get category ids of concatenated dataset by index.

        Args:
            idx (int): Index of data.

        Returns:
            list[int]: All categories in the image of specified index.
        """
        dataset_idx, sample_idx = self._get_ori_dataset_idx(idx)
        return self.datasets[dataset_idx].get_cat_ids(sample_idx)

    def evaluate(self, results, logger=None, **kwargs):
        """Evaluate the results.

        Args:
            results (list[list | tuple]): Testing results of the dataset.
            logger (logging.Logger | str | None): Logger used for printing
                related information during evaluation. Default: None.

        Returns:
            dict[str: float]: AP results of the total dataset, or of each
                separate dataset (keys prefixed with ``"{dataset_idx}_"``)
                if ``self.separate_eval=True``.

        Raises:
            NotImplementedError: If joint evaluation is requested for
                datasets that do not support it.
        """
        assert len(results) == self.cumulative_sizes[-1], \
            ('Dataset and results have different sizes: '
             f'{self.cumulative_sizes[-1]} v.s. {len(results)}')

        # Check whether all the datasets support evaluation
        for dataset in self.datasets:
            assert hasattr(dataset, 'evaluate'), \
                f'{type(dataset)} does not implement evaluate function'

        if self.separate_eval:
            total_eval_results = dict()
            start_idx = 0
            # cumulative_sizes[i] is the (exclusive) end index of dataset i.
            for dataset_idx, (end_idx, dataset) in enumerate(
                    zip(self.cumulative_sizes, self.datasets)):
                results_per_dataset = results[start_idx:end_idx]
                print_log(
                    f'\nEvaluating {dataset.ann_file} with '
                    f'{len(results_per_dataset)} images now',
                    logger=logger)
                eval_results_per_dataset = dataset.evaluate(
                    results_per_dataset, logger=logger, **kwargs)
                for k, v in eval_results_per_dataset.items():
                    total_eval_results[f'{dataset_idx}_{k}'] = v
                start_idx = end_idx
            return total_eval_results

        self._check_joint_eval_supported()
        # Temporarily give the first dataset the annotations of all datasets
        # so it can score the full result list; always restore them, even if
        # evaluation raises, so the dataset is not left corrupted.
        first_dataset = self.datasets[0]
        original_data_infos = first_dataset.data_infos
        first_dataset.data_infos = [
            info for dataset in self.datasets for info in dataset.data_infos
        ]
        try:
            eval_results = first_dataset.evaluate(
                results, logger=logger, **kwargs)
        finally:
            first_dataset.data_infos = original_data_infos
        return eval_results
class RepeatDataset(object):
    """A wrapper of repeated dataset.

    The length of repeated dataset will be `times` larger than the original
    dataset. This is useful when the data loading time is long but the dataset
    is small. Using RepeatDataset can reduce the data loading time between
    epochs.

    Args:
        dataset (:obj:`Dataset`): The dataset to be repeated.
        times (int): Repeat times.
    """

    def __init__(self, dataset, times):
        self.dataset = dataset
        self.times = times
        self.CLASSES = dataset.CLASSES
        if hasattr(self.dataset, 'flag'):
            # Repeat the group flags so they stay aligned with the repeated
            # indices.
            self.flag = np.tile(self.dataset.flag, times)

        self._ori_len = len(self.dataset)

    def _get_ori_idx(self, idx):
        """Map an index of the repeated dataset to the wrapped dataset.

        Negative indices count from the end, as for Python sequences.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
                Raising here (instead of silently wrapping) lets the
                ``__getitem__``-based iteration protocol terminate.
        """
        length = len(self)
        if not -length <= idx < length:
            raise IndexError(
                f'index {idx} out of range for RepeatDataset of length '
                f'{length}')
        # len(self) is a multiple of _ori_len, so Python's modulo maps both
        # non-negative and in-range negative indices correctly.
        return idx % self._ori_len

    def __getitem__(self, idx):
        return self.dataset[self._get_ori_idx(idx)]

    def get_cat_ids(self, idx):
        """Get category ids of repeat dataset by index.

        Args:
            idx (int): Index of data.

        Returns:
            list[int]: All categories in the image of specified index.
        """
        return self.dataset.get_cat_ids(self._get_ori_idx(idx))

    def __len__(self):
        """Length after repetition."""
        return self.times * self._ori_len
# Modified from detectron2's RepeatFactorTrainingSampler (detectron2/data/samplers/distributed_sampler.py)  # noqa
class ClassBalancedDataset(object):
    """A wrapper of repeated dataset with repeat factor.

    Suitable for training on class imbalanced datasets like LVIS. Following
    the sampling strategy in the `paper <https://arxiv.org/abs/1908.03195>`_,
    in each epoch, an image may appear multiple times based on its
    "repeat factor".
    The repeat factor for an image is a function of the frequency of the
    rarest category labeled in that image. The "frequency of category c" in
    [0, 1] is defined by the fraction of images in the training set (without
    repeats) in which category c appears.
    The dataset needs to implement :func:`self.get_cat_ids` to support
    ClassBalancedDataset.

    The repeat factor is computed as follows.

    1. For each category c, compute the fraction of images
       that contain it: :math:`f(c)`
    2. For each category c, compute the category-level repeat factor:
       :math:`r(c) = max(1, sqrt(t/f(c)))`, where ``t`` is
       ``oversample_thr``.
    3. For each image I, compute the image-level repeat factor:
       :math:`r(I) = max_{c in I} r(c)`

    Args:
        dataset (:obj:`CustomDataset`): The dataset to be repeated.
        oversample_thr (float): frequency threshold below which data is
            repeated. For categories with ``f_c >= oversample_thr``, there is
            no oversampling. For categories with ``f_c < oversample_thr``, the
            degree of oversampling follows the square-root inverse frequency
            heuristic above.
        filter_empty_gt (bool, optional): If set true, images without bounding
            boxes will not be oversampled. Otherwise, they will be categorized
            as the pure background class and involved into the oversampling.
            Default: True.
    """

    def __init__(self, dataset, oversample_thr, filter_empty_gt=True):
        self.dataset = dataset
        self.oversample_thr = oversample_thr
        self.filter_empty_gt = filter_empty_gt
        self.CLASSES = dataset.CLASSES

        repeat_factors = self._get_repeat_factors(dataset, oversample_thr)
        # Fractional repeat factors are rounded up, so every image appears
        # at least once and rare-category images appear ceil(r(I)) times.
        repeat_indices = []
        for dataset_idx, repeat_factor in enumerate(repeat_factors):
            repeat_indices.extend([dataset_idx] * math.ceil(repeat_factor))
        self.repeat_indices = repeat_indices

        flags = []
        if hasattr(self.dataset, 'flag'):
            # Repeat each image's group flag exactly as often as the image.
            for flag, repeat_factor in zip(self.dataset.flag, repeat_factors):
                flags.extend([flag] * int(math.ceil(repeat_factor)))
            assert len(flags) == len(repeat_indices)
        self.flag = np.asarray(flags, dtype=np.uint8)

    def _get_repeat_factors(self, dataset, repeat_thr):
        """Get repeat factor for each image in the dataset.

        Args:
            dataset (:obj:`CustomDataset`): The dataset.
            repeat_thr (float): The threshold of frequency. If an image
                contains the categories whose frequency below the threshold,
                it would be repeated.

        Returns:
            list[float]: The repeat factors for each image in the dataset.
        """
        num_images = len(dataset)

        # Collect each image's category set once; it is used by steps 1 and 3.
        image_cat_ids = []
        for idx in range(num_images):
            cat_ids = set(dataset.get_cat_ids(idx))
            if not cat_ids and not self.filter_empty_gt:
                # Count an image without annotations as containing a virtual
                # background category whose id is len(self.CLASSES).
                cat_ids = {len(self.CLASSES)}
            image_cat_ids.append(cat_ids)

        # 1. For each category c, compute the fraction of images
        # that contain it: f(c)
        category_freq = defaultdict(int)
        for cat_ids in image_cat_ids:
            for cat_id in cat_ids:
                category_freq[cat_id] += 1
        for k, v in category_freq.items():
            category_freq[k] = v / num_images

        # 2. For each category c, compute the category-level repeat factor:
        # r(c) = max(1, sqrt(t/f(c)))
        category_repeat = {
            cat_id: max(1.0, math.sqrt(repeat_thr / cat_freq))
            for cat_id, cat_freq in category_freq.items()
        }

        # 3. For each image I, compute the image-level repeat factor:
        # r(I) = max_{c in I} r(c); images with no categories keep 1.
        repeat_factors = []
        for cat_ids in image_cat_ids:
            repeat_factor = 1
            if cat_ids:
                repeat_factor = max(
                    category_repeat[cat_id] for cat_id in cat_ids)
            repeat_factors.append(repeat_factor)

        return repeat_factors

    def __getitem__(self, idx):
        ori_index = self.repeat_indices[idx]
        return self.dataset[ori_index]

    def __len__(self):
        """Length after repetition."""
        return len(self.repeat_indices)