# Copyright (c) OpenMMLab. All rights reserved.
import copy

import numpy as np
from mmengine.dataset import BaseDataset, force_full_init

from mmpretrain.registry import DATASETS


@DATASETS.register_module()
class KFoldDataset:
    """A wrapper of dataset for K-Fold cross-validation.

    K-Fold cross-validation divides all the samples in groups of samples,
    called folds, of almost equal sizes. And we use k-1 of folds to do training
    and use the fold left to do validation.

    Args:
        dataset (:obj:`mmengine.dataset.BaseDataset` | dict): The dataset to be
            divided. A dict is treated as a config and built through the
            ``DATASETS`` registry; its optional ``lazy_init`` key decides
            whether this wrapper defers :meth:`full_init`. A dataset instance
            is always initialized eagerly.
        fold (int): The fold used to do validation. Defaults to 0.
        num_splits (int): The number of all folds. Defaults to 5.
        test_mode (bool): Use the training dataset or validation dataset.
            Defaults to False.
        seed (int, optional): The seed to shuffle the dataset before splitting.
            If None, not shuffle the dataset. Defaults to None.

    Raises:
        TypeError: If ``dataset`` is neither a dict nor a ``BaseDataset``.
    """

    def __init__(self,
                 dataset,
                 fold=0,
                 num_splits=5,
                 test_mode=False,
                 seed=None):
        # Default to eager initialization. Only a config dict can request a
        # lazy wrapper; without this default, passing a dataset instance left
        # ``lazy_init`` unbound and crashed with ``UnboundLocalError`` below.
        lazy_init = False
        if isinstance(dataset, dict):
            self.dataset = DATASETS.build(dataset)
            # Init the dataset wrapper lazily according to the dataset setting.
            lazy_init = dataset.get('lazy_init', False)
        elif isinstance(dataset, BaseDataset):
            self.dataset = dataset
        else:
            raise TypeError(f'Unsupported dataset type {type(dataset)}.')

        self._metainfo = getattr(self.dataset, 'metainfo', {})
        self.fold = fold
        self.num_splits = num_splits
        self.test_mode = test_mode
        self.seed = seed

        self._fully_initialized = False
        if not lazy_init:
            self.full_init()

    @property
    def metainfo(self) -> dict:
        """Get the meta information of ``self.dataset``.

        Returns:
            dict: Meta information of the dataset.
        """
        # Prevent `self._metainfo` from being modified by outside.
        return copy.deepcopy(self._metainfo)

    def full_init(self):
        """Fully initialize the dataset and keep only the selected fold.

        The wrapped dataset is fully initialized, its indices are optionally
        shuffled with ``self.seed``, and ``self.dataset`` is then replaced by
        the subset returned from ``get_subset``: the ``fold``-th slice when
        ``test_mode`` is True, everything except that slice otherwise.
        Calling this method again after it has completed is a no-op.
        """
        if self._fully_initialized:
            return

        self.dataset.full_init()

        ori_len = len(self.dataset)
        indices = list(range(ori_len))
        if self.seed is not None:
            # A seeded generator makes the split reproducible across runs.
            rng = np.random.default_rng(self.seed)
            rng.shuffle(indices)

        # Integer arithmetic on the total length keeps fold sizes within one
        # sample of each other, and consecutive folds share their boundaries.
        test_start = ori_len * self.fold // self.num_splits
        test_end = ori_len * (self.fold + 1) // self.num_splits
        if self.test_mode:
            indices = indices[test_start:test_end]
        else:
            indices = indices[:test_start] + indices[test_end:]

        # Remember where every kept sample came from in the whole dataset.
        self._ori_indices = indices
        self.dataset = self.dataset.get_subset(indices)

        self._fully_initialized = True

    @force_full_init
    def _get_ori_dataset_idx(self, idx: int) -> int:
        """Convert an index of this fold to an index of the whole dataset.

        Args:
            idx (int): Index into this ``KFoldDataset``.

        Returns:
            int: The original index in the whole dataset.
        """
        return self._ori_indices[idx]

    @force_full_init
    def get_data_info(self, idx: int) -> dict:
        """Get annotation by index.

        Args:
            idx (int): Index into this ``KFoldDataset``.

        Returns:
            dict: The idx-th annotation of the datasets.
        """
        return self.dataset.get_data_info(idx)

    @force_full_init
    def __len__(self):
        """Return the number of samples in the selected fold subset."""
        return len(self.dataset)

    @force_full_init
    def __getitem__(self, idx):
        """Return the ``idx``-th item of the selected fold subset."""
        return self.dataset[idx]

    @force_full_init
    def get_cat_ids(self, idx):
        """Forward ``get_cat_ids(idx)`` to the selected fold subset."""
        return self.dataset.get_cat_ids(idx)

    @force_full_init
    def get_gt_labels(self):
        """Forward ``get_gt_labels()`` to the selected fold subset."""
        return self.dataset.get_gt_labels()

    @property
    def CLASSES(self):
        """Return all categories names, or None if ``classes`` is not set."""
        return self._metainfo.get('classes', None)

    @property
    def class_to_idx(self):
        """Map mapping class name to class index.

        Note:
            Requires the ``classes`` meta info to be set; with ``CLASSES``
            being None the enumeration raises ``TypeError``.

        Returns:
            dict: mapping from class name to class index.
        """
        return {cat: i for i, cat in enumerate(self.CLASSES)}

    def __repr__(self):
        """Print the basic information of the dataset.

        Returns:
            str: Formatted string.
        """
        head = 'Dataset ' + self.__class__.__name__
        body = []
        type_ = 'test' if self.test_mode else 'training'
        body.append(f'Type: \t{type_}')
        body.append(f'Seed: \t{self.seed}')

        def ordinal(n):
            # Copy from https://codegolf.stackexchange.com/a/74047
            # The slice start is ``n % 10`` for a 1st/2nd/3rd ending outside
            # the teens and 0 otherwise; stepping by 4 through 'tsnrhtdd'
            # then yields 'st', 'nd', 'rd' or 'th'.
            suffix = 'tsnrhtdd'[(n // 10 % 10 != 1) * (n % 10 < 4) * n % 10::4]
            return f'{n}{suffix}'

        body.append(
            f'Fold: \t{ordinal(self.fold+1)} of {self.num_splits}-fold')
        if self._fully_initialized:
            body.append(f'Number of samples: \t{self.__len__()}')
        else:
            body.append("Haven't been initialized")

        if self.CLASSES is not None:
            body.append(f'Number of categories: \t{len(self.CLASSES)}')
        else:
            body.append('The `CLASSES` meta info is not set.')

        body.append(
            f'Original dataset type:\t{self.dataset.__class__.__name__}')

        lines = [head] + [' ' * 4 + line for line in body]
        return '\n'.join(lines)