import itertools
import logging
from collections import defaultdict
from typing import Optional

import torch.utils.data
from torch.utils.data.sampler import Sampler

from detectron2.config import CfgNode, configurable
from detectron2.data.build import (
    build_batch_data_loader,
    load_proposals_into_dataset,
    trivial_batch_collator,
)
from detectron2.data.catalog import DatasetCatalog
from detectron2.data.common import DatasetFromList, MapDataset
from detectron2.data.dataset_mapper import DatasetMapper
from detectron2.data.samplers import InferenceSampler, TrainingSampler
from detectron2.utils import comm
from detectron2.utils.comm import get_world_size


def _compute_num_images_per_worker(cfg: CfgNode):
    """
    Return how many images each distributed process handles per iteration.

    Args:
        cfg (CfgNode): config providing ``SOLVER.IMS_PER_BATCH``.

    Returns:
        int: ``SOLVER.IMS_PER_BATCH // world_size``.

    Raises:
        AssertionError: if the total batch size is not divisible by the world
            size, or is smaller than it.
    """
    num_workers = get_world_size()
    images_per_batch = cfg.SOLVER.IMS_PER_BATCH
    assert (
        images_per_batch % num_workers == 0
    ), "SOLVER.IMS_PER_BATCH ({}) must be divisible by the number of workers ({}).".format(
        images_per_batch, num_workers
    )
    assert (
        images_per_batch >= num_workers
    ), "SOLVER.IMS_PER_BATCH ({}) must be larger than the number of workers ({}).".format(
        images_per_batch, num_workers
    )
    return images_per_batch // num_workers


def filter_images_with_only_crowd_annotations(dataset_dicts, dataset_names):
    """
    Filter out images with no annotations or only crowd annotations
    (i.e., images without any non-crowd annotation).
    A common training-time preprocessing on the COCO dataset.

    Each entry of an image's "annotations" may be either an annotation dict or
    a list of annotation dicts; both layouts are inspected.

    Args:
        dataset_dicts (list[dict]): annotations in Detectron2 Dataset format.
        dataset_names: not used by this function; kept for call compatibility.

    Returns:
        list[dict]: the same format, but filtered.
    """
    num_before = len(dataset_dicts)

    def valid(anns):
        # An image is kept as soon as one non-crowd instance is found.
        for ann in anns:
            instances = ann if isinstance(ann, list) else (ann,)
            if any(instance.get("iscrowd", 0) == 0 for instance in instances):
                return True
        return False

    dataset_dicts = [x for x in dataset_dicts if valid(x["annotations"])]
    num_after = len(dataset_dicts)
    logger = logging.getLogger(__name__)
    logger.info(
        "Removed {} images with no usable annotations. {} images left.".format(
            num_before - num_after, num_after
        )
    )
    return dataset_dicts


def get_detection_dataset_dicts(dataset_names, filter_empty=True, proposal_files=None):
    """
    Load and prepare dataset dicts for instance detection/segmentation and
    semantic segmentation.

    Args:
        dataset_names (str or list[str]): a dataset name or a list of dataset names
        filter_empty (bool): whether to filter out images without instance annotations
        proposal_files (list[str]): if given, a list of object proposal files
            that match each dataset in `dataset_names`.

    Returns:
        list[dict]: a list of dicts following the standard dataset dict format.

    Raises:
        AssertionError: if no names are given, a dataset is empty, the number of
            proposal files does not match, or nothing is left after filtering.
    """
    if isinstance(dataset_names, str):
        dataset_names = [dataset_names]
    assert len(dataset_names)
    dataset_dicts = [DatasetCatalog.get(dataset_name) for dataset_name in dataset_names]
    for dataset_name, dicts in zip(dataset_names, dataset_dicts):
        assert len(dicts), "Dataset '{}' is empty!".format(dataset_name)

    if proposal_files is not None:
        assert len(dataset_names) == len(proposal_files)
        # load precomputed proposals from proposal files
        dataset_dicts = [
            load_proposals_into_dataset(dataset_i_dicts, proposal_file)
            for dataset_i_dicts, proposal_file in zip(dataset_dicts, proposal_files)
        ]

    dataset_dicts = list(itertools.chain.from_iterable(dataset_dicts))

    # Only the first record is inspected to decide whether instance annotations exist.
    has_instances = "annotations" in dataset_dicts[0]
    if filter_empty and has_instances:
        dataset_dicts = filter_images_with_only_crowd_annotations(dataset_dicts, dataset_names)

    assert len(dataset_dicts), "No valid data found in {}.".format(",".join(dataset_names))
    return dataset_dicts


def _train_loader_from_config(cfg, mapper, *, dataset=None, sampler=None):
    """
    Translate ``cfg`` into keyword arguments for :func:`build_detection_train_loader`.

    Args:
        cfg (CfgNode): the config.
        mapper (callable or None): if None, ``DatasetMapper(cfg, True)`` is used.
        dataset: if None, built from ``cfg.DATASETS.TRAIN``.
        sampler: if None, built from ``cfg.DATALOADER.SAMPLER_TRAIN``.

    Returns:
        dict: keyword arguments for :func:`build_detection_train_loader`.

    Raises:
        ValueError: if ``cfg.DATALOADER.SAMPLER_TRAIN`` names an unsupported sampler.
    """
    if dataset is None:
        dataset = get_detection_dataset_dicts(
            cfg.DATASETS.TRAIN,
            filter_empty=cfg.DATALOADER.FILTER_EMPTY_ANNOTATIONS,
            proposal_files=cfg.DATASETS.PROPOSAL_FILES_TRAIN if cfg.MODEL.LOAD_PROPOSALS else None,
        )

    if mapper is None:
        mapper = DatasetMapper(cfg, True)

    if sampler is None:
        sampler_name = cfg.DATALOADER.SAMPLER_TRAIN
        logger = logging.getLogger(__name__)
        logger.info("Using training sampler {}".format(sampler_name))
        if sampler_name == "TrainingSampler":
            sampler = TrainingSampler(len(dataset))
        elif sampler_name == "ClassAwareSampler":
            sampler = ClassAwareSampler(dataset)
        else:
            # Leaving sampler=None here would make build_detection_train_loader
            # silently substitute TrainingSampler, hiding the misconfiguration.
            raise ValueError("Unknown training sampler: {}".format(sampler_name))

    return {
        "dataset": dataset,
        "sampler": sampler,
        "mapper": mapper,
        "total_batch_size": cfg.SOLVER.IMS_PER_BATCH,
        "aspect_ratio_grouping": cfg.DATALOADER.ASPECT_RATIO_GROUPING,
        "num_workers": cfg.DATALOADER.NUM_WORKERS,
        # NOTE(review): mixup is always enabled when building from a config,
        # regardless of any cfg key -- confirm this is intended.
        "use_mixup": True,
    }


# TODO can allow dataset as an iterable or IterableDataset to make this function more general
@configurable(from_config=_train_loader_from_config)
def build_detection_train_loader(
    dataset,
    *,
    mapper,
    sampler=None,
    total_batch_size,
    aspect_ratio_grouping=True,
    num_workers=0,
    use_mixup=False,
):
    """
    Build a dataloader for object detection with some default features.
    This interface is experimental.

    Args:
        dataset (list or torch.utils.data.Dataset): a list of dataset dicts,
            or a map-style pytorch dataset. They can be obtained by using
            :func:`DatasetCatalog.get` or :func:`get_detection_dataset_dicts`.
        mapper (callable): a callable which takes a sample (dict) from dataset and
            returns the format to be consumed by the model.
            When using cfg, the default choice is ``DatasetMapper(cfg, is_train=True)``.
        sampler (torch.utils.data.sampler.Sampler or None): a sampler that
            produces indices to be applied on ``dataset``.
            Default to :class:`TrainingSampler`, which coordinates a random shuffle
            sequence across all workers.
        total_batch_size (int): total batch size across all workers. Batching
            simply puts data into a list.
        aspect_ratio_grouping (bool): whether to group images with similar
            aspect ratio for efficiency. When enabled, it requires each
            element in dataset be a dict with keys "width" and "height".
        num_workers (int): number of parallel data loading workers
        use_mixup (bool): if True (and ``mapper`` is given), wrap the dataset with
            ``MapDatasetMixup`` instead of ``MapDataset``.

    Returns:
        torch.utils.data.DataLoader: a dataloader. Each output from it is a
            ``list[mapped_element]`` holding this process's share of
            ``total_batch_size`` (i.e. divided by the number of distributed
            processes), where ``mapped_element`` is produced by the ``mapper``.
    """
    if isinstance(dataset, list):
        dataset = DatasetFromList(dataset, copy=False)
    if mapper is not None:
        if use_mixup:
            # NOTE(review): MapDatasetMixup is neither imported nor defined in the
            # visible part of this module; confirm it exists elsewhere in the
            # module, otherwise use_mixup=True raises NameError.
            dataset = MapDatasetMixup(dataset, mapper)
        else:
            dataset = MapDataset(dataset, mapper)
    if sampler is None:
        sampler = TrainingSampler(len(dataset))
    assert isinstance(sampler, torch.utils.data.sampler.Sampler)
    return build_batch_data_loader(
        dataset,
        sampler,
        total_batch_size,
        aspect_ratio_grouping=aspect_ratio_grouping,
        num_workers=num_workers,
    )


def _test_loader_from_config(cfg, dataset_name, mapper=None):
    """
    Uses the given `dataset_name` argument (instead of the names in cfg), because the
    standard practice is to evaluate each test set individually (not combining them).
    """
    dataset = get_detection_dataset_dicts(
        [dataset_name],
        filter_empty=False,
        proposal_files=[
            cfg.DATASETS.PROPOSAL_FILES_TEST[list(cfg.DATASETS.TEST).index(dataset_name)]
        ]
        if cfg.MODEL.LOAD_PROPOSALS
        else None,
    )
    if mapper is None:
        mapper = DatasetMapper(cfg, False)
    return {"dataset": dataset, "mapper": mapper, "num_workers": cfg.DATALOADER.NUM_WORKERS}


@configurable(from_config=_test_loader_from_config)
def build_detection_test_loader(dataset, *, mapper, num_workers=0):
    """
    Similar to `build_detection_train_loader`, but uses a batch size of 1.
    This interface is experimental.

    Args:
        dataset (list or torch.utils.data.Dataset): a list of dataset dicts,
            or a map-style pytorch dataset. They can be obtained by using
            :func:`DatasetCatalog.get` or :func:`get_detection_dataset_dicts`.
        mapper (callable): a callable which takes a sample (dict) from dataset
           and returns the format to be consumed by the model.
           When using cfg, the default choice is ``DatasetMapper(cfg, is_train=False)``.
        num_workers (int): number of parallel data loading workers

    Returns:
        DataLoader: a torch DataLoader, that loads the given detection
        dataset, with test-time transformation and batching.

    Examples:
    ::
        data_loader = build_detection_test_loader(
            DatasetCatalog.get("my_test"),
            mapper=DatasetMapper(...))

        # or, instantiate with a CfgNode:
        data_loader = build_detection_test_loader(cfg, "my_test")
    """
    if isinstance(dataset, list):
        dataset = DatasetFromList(dataset, copy=False)
    if mapper is not None:
        dataset = MapDataset(dataset, mapper)
    sampler = InferenceSampler(len(dataset))
    # Always use 1 image per worker during inference since this is the
    # standard when reporting inference time in papers.
    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=1,
        sampler=sampler,
        drop_last=False,
        num_workers=num_workers,
        collate_fn=trivial_batch_collator,
    )
    return data_loader


class ClassAwareSampler(Sampler):
    """
    Infinite training sampler that draws image indices with replacement,
    weighting each image by the rarity of the categories it contains.

    An image's weight is ``sum(1 / freq[c] ** l)`` over the distinct category
    ids ``c`` annotated in it, where ``freq[c]`` counts the images containing
    ``c``. Images without annotations get weight 0 and are never drawn.
    All processes generate the same index stream from a shared seed; each
    process keeps every ``world_size``-th index starting at its rank.
    """

    def __init__(self, dataset_dicts, seed: Optional[int] = None):
        """
        Args:
            dataset_dicts (list[dict]): dataset dicts; each must have an
                "annotations" list whose entries are annotation dicts (or lists
                of annotation dicts) carrying a "category_id".
            seed (int or None): seed for the sampling generator. If None, one
                is obtained from ``comm.shared_random_seed()``.
        """
        self._size = len(dataset_dicts)
        assert self._size > 0
        if seed is None:
            seed = comm.shared_random_seed()
        self._seed = int(seed)

        self._rank = comm.get_rank()
        self._world_size = comm.get_world_size()
        self.weights = self._get_class_balance_factor(dataset_dicts)

    def __iter__(self):
        # Shard the shared infinite stream across processes.
        yield from itertools.islice(
            self._infinite_indices(), self._rank, None, self._world_size
        )

    def _infinite_indices(self):
        g = torch.Generator()
        g.manual_seed(self._seed)
        while True:
            ids = torch.multinomial(self.weights, self._size, generator=g, replacement=True)
            # Yield plain Python ints rather than 0-d tensors.
            yield from ids.tolist()

    @staticmethod
    def _image_category_ids(dataset_dict):
        """Return the set of category ids annotated in one image.

        Annotation entries that are lists of instances (a layout also accepted
        by ``filter_images_with_only_crowd_annotations``) are flattened.
        """
        cat_ids = set()
        for ann in dataset_dict["annotations"]:
            instances = ann if isinstance(ann, list) else (ann,)
            for instance in instances:
                cat_ids.add(instance["category_id"])
        return cat_ids

    def _get_class_balance_factor(self, dataset_dicts, l=1.0):
        """
        Compute one sampling weight per image.

        Args:
            dataset_dicts (list[dict]): dataset dicts (see ``__init__``).
            l (float): exponent applied to each category's image frequency.

        Returns:
            torch.Tensor: float tensor of length ``len(dataset_dicts)``.
        """
        # Category ids per image, computed once and reused for both passes.
        per_image_cat_ids = [self._image_category_ids(d) for d in dataset_dicts]

        # Count each category once per image it appears in (not per instance).
        category_freq = defaultdict(int)
        for cat_ids in per_image_cat_ids:
            for cat_id in cat_ids:
                category_freq[cat_id] += 1

        ret = [
            sum([1.0 / (category_freq[cat_id] ** l) for cat_id in cat_ids])
            for cat_ids in per_image_cat_ids
        ]
        return torch.tensor(ret).float()