# Copyright (c) Facebook, Inc. and its affiliates.
# Copyright (c) Meta Platforms, Inc. All Rights Reserved

import itertools
import logging
import numpy as np
from collections import Counter
import torch.utils.data
from tabulate import tabulate
from termcolor import colored
from detectron2.utils.logger import _log_api_usage, log_first_n
from detectron2.data.catalog import DatasetCatalog, MetadataCatalog
import torch.utils.data
from detectron2.config import configurable
from detectron2.data.build import (
    build_batch_data_loader,
    trivial_batch_collator,
    load_proposals_into_dataset,
    filter_images_with_only_crowd_annotations,
    filter_images_with_few_keypoints,
    print_instances_class_histogram,
)
from detectron2.data.common import DatasetFromList, MapDataset
from detectron2.data.dataset_mapper import DatasetMapper
from detectron2.data.detection_utils import check_metadata_consistency
from detectron2.data.samplers import (
    InferenceSampler,
    RandomSubsetTrainingSampler,
    RepeatFactorTrainingSampler,
    TrainingSampler,
)

"""
This file contains the default logic to build a dataloader for training or testing.
"""

__all__ = [
    "build_detection_train_loader",
    "build_detection_test_loader",
]


def print_classification_instances_class_histogram(dataset_dicts, class_names):
    """
    Log a table with the number of samples per class.

    Unlike the detection variant, the label of every entry is read directly
    from ``entry["category_id"]`` rather than from per-instance annotations.

    Args:
        dataset_dicts (list[dict]): list of dataset dicts. Each dict must have
            a "category_id" key.
        class_names (list[str]): list of class names (zero-indexed).

    Raises:
        AssertionError: if a category id is negative or not smaller than
            ``len(class_names)``.
    """
    num_classes = len(class_names)
    hist_bins = np.arange(num_classes + 1)
    # NOTE: the ``np.int`` alias was deprecated in NumPy 1.20 and removed in
    # 1.24; the builtin ``int`` is exactly what that alias referred to.
    histogram = np.zeros((num_classes,), dtype=int)
    for entry in dataset_dicts:
        classes = np.asarray([entry["category_id"]], dtype=int)
        if len(classes):
            assert classes.min() >= 0, f"Got an invalid category_id={classes.min()}"
            assert (
                classes.max() < num_classes
            ), f"Got an invalid category_id={classes.max()} for a dataset of {num_classes} classes"
        histogram += np.histogram(classes, bins=hist_bins)[0]

    # Each class takes two table cells (name, count): at most 3 classes per row.
    N_COLS = min(6, len(class_names) * 2)

    def short_name(x):
        # make long class names shorter. useful for lvis
        if len(x) > 13:
            return x[:11] + ".."
        return x

    # Flat list alternating [name_0, count_0, name_1, count_1, ...].
    data = list(
        itertools.chain(
            *[[short_name(class_names[i]), int(v)] for i, v in enumerate(histogram)]
        )
    )
    total_num_instances = sum(data[1::2])
    # Pad with empty cells so that the "total" entry starts on a fresh row.
    data.extend([None] * (N_COLS - (len(data) % N_COLS)))
    if num_classes > 1:
        data.extend(["total", total_num_instances])
    # Re-chunk the flat list into rows of N_COLS cells.
    data = itertools.zip_longest(*[data[i::N_COLS] for i in range(N_COLS)])
    table = tabulate(
        data,
        headers=["category", "#instances"] * (N_COLS // 2),
        tablefmt="pipe",
        numalign="left",
        stralign="center",
    )
    log_first_n(
        logging.INFO,
        "Distribution of instances among all {} categories:\n".format(num_classes)
        + colored(table, "cyan"),
        key="message",
    )


def wrap_metas(dataset_dict, **kwargs):
    """
    Attach the given keyword arguments to every sample under the "meta" key.

    The samples are updated in place, and all of them reference the same
    ``kwargs`` dict object as their "meta" value.

    Args:
        dataset_dict (list[dict]): the samples to annotate.
        **kwargs: metadata to store in ``sample["meta"]``.

    Returns:
        list[dict]: a new list holding the same (now updated) sample dicts.

    Raises:
        AssertionError: if a sample already contains a "meta" key.
    """

    def _assign_attr(data_dict: dict, **kwargs):
        assert not any(
            key in data_dict for key in kwargs
        ), "Assigned attributes should not exist in the original sample."
        data_dict.update(kwargs)
        return data_dict

    return [_assign_attr(sample, meta=kwargs) for sample in dataset_dict]


def get_detection_dataset_dicts(
    names, filter_empty=True, min_keypoints=0, proposal_files=None
):
    """
    Load and prepare dataset dicts for instance detection/segmentation and semantic segmentation.

    Every returned dict additionally carries a "meta" entry of the form
    ``{"dataset_name": <name>}`` identifying the dataset it was loaded from.

    Args:
        names (str or list[str]): a dataset name or a list of dataset names
        filter_empty (bool): whether to filter out images without instance annotations
        min_keypoints (int): filter out images with fewer keypoints than
            `min_keypoints`. Set to 0 to do nothing.
        proposal_files (list[str]): if given, a list of object proposal files
            that match each dataset in `names`.

    Returns:
        list[dict]: a list of dicts following the standard dataset dict format.
    """
    if isinstance(names, str):
        names = [names]
    assert len(names), names
    dataset_dicts = [
        wrap_metas(DatasetCatalog.get(dataset_name), dataset_name=dataset_name)
        for dataset_name in names
    ]
    for dataset_name, dicts in zip(names, dataset_dicts):
        assert len(dicts), "Dataset '{}' is empty!".format(dataset_name)

    if proposal_files is not None:
        assert len(names) == len(proposal_files)
        # load precomputed proposals from proposal files
        dataset_dicts = [
            load_proposals_into_dataset(dataset_i_dicts, proposal_file)
            for dataset_i_dicts, proposal_file in zip(dataset_dicts, proposal_files)
        ]

    dataset_dicts = list(itertools.chain.from_iterable(dataset_dicts))

    # Only the first sample is inspected to decide whether the merged dataset
    # has instance annotations.
    has_instances = "annotations" in dataset_dicts[0]
    if filter_empty and has_instances:
        dataset_dicts = filter_images_with_only_crowd_annotations(dataset_dicts)
    if min_keypoints > 0 and has_instances:
        dataset_dicts = filter_images_with_few_keypoints(dataset_dicts, min_keypoints)

    if has_instances:
        try:
            class_names = MetadataCatalog.get(names[0]).thing_classes
            check_metadata_consistency("thing_classes", names)
            print_instances_class_histogram(dataset_dicts, class_names)
        except AttributeError:  # class names are not available for this dataset
            pass

    assert len(dataset_dicts), "No valid data found in {}.".format(",".join(names))
    return dataset_dicts


def _train_loader_from_config(cfg, mapper=None, *, dataset=None, sampler=None):
    """
    Translate `cfg` into the keyword arguments of :func:`build_detection_train_loader`.

    Anything passed explicitly (`mapper`, `dataset`, `sampler`) takes precedence
    over what would be built from `cfg`.

    Raises:
        ValueError: if ``cfg.DATALOADER.SAMPLER_TRAIN`` names an unknown sampler.
    """
    if dataset is None:
        dataset = get_detection_dataset_dicts(
            cfg.DATASETS.TRAIN,
            filter_empty=cfg.DATALOADER.FILTER_EMPTY_ANNOTATIONS,
            min_keypoints=cfg.MODEL.ROI_KEYPOINT_HEAD.MIN_KEYPOINTS_PER_IMAGE
            if cfg.MODEL.KEYPOINT_ON
            else 0,
            proposal_files=cfg.DATASETS.PROPOSAL_FILES_TRAIN
            if cfg.MODEL.LOAD_PROPOSALS
            else None,
        )
        _log_api_usage("dataset." + cfg.DATASETS.TRAIN[0])

    if mapper is None:
        mapper = DatasetMapper(cfg, True)

    if sampler is None:
        sampler_name = cfg.DATALOADER.SAMPLER_TRAIN
        logger = logging.getLogger(__name__)
        logger.info("Using training sampler {}".format(sampler_name))
        if sampler_name == "TrainingSampler":
            sampler = TrainingSampler(len(dataset))
        elif sampler_name == "RepeatFactorTrainingSampler":
            repeat_factors = (
                RepeatFactorTrainingSampler.repeat_factors_from_category_frequency(
                    dataset, cfg.DATALOADER.REPEAT_THRESHOLD
                )
            )
            sampler = RepeatFactorTrainingSampler(repeat_factors)
        elif sampler_name == "RandomSubsetTrainingSampler":
            sampler = RandomSubsetTrainingSampler(
                len(dataset), cfg.DATALOADER.RANDOM_SUBSET_RATIO
            )
        else:
            raise ValueError("Unknown training sampler: {}".format(sampler_name))

    return {
        "dataset": dataset,
        "sampler": sampler,
        "mapper": mapper,
        "total_batch_size": cfg.SOLVER.IMS_PER_BATCH,
        "aspect_ratio_grouping": cfg.DATALOADER.ASPECT_RATIO_GROUPING,
        "num_workers": cfg.DATALOADER.NUM_WORKERS,
    }


# TODO can allow dataset as an iterable or IterableDataset to make this function more general
@configurable(from_config=_train_loader_from_config)
def build_detection_train_loader(
    dataset,
    *,
    mapper,
    sampler=None,
    total_batch_size,
    aspect_ratio_grouping=True,
    num_workers=0,
):
    """
    Build a dataloader for object detection with some default features.
    This interface is experimental.

    Args:
        dataset (list or torch.utils.data.Dataset): a list of dataset dicts,
            or a map-style pytorch dataset. They can be obtained by using
            :func:`DatasetCatalog.get` or :func:`get_detection_dataset_dicts`.
        mapper (callable): a callable which takes a sample (dict) from dataset and
            returns the format to be consumed by the model.
            When using cfg, the default choice is ``DatasetMapper(cfg, is_train=True)``.
        sampler (torch.utils.data.sampler.Sampler or None): a sampler that
            produces indices to be applied on ``dataset``.
            Default to :class:`TrainingSampler`, which coordinates an infinite random shuffle
            sequence across all workers.
        total_batch_size (int): total batch size across all workers. Batching
            simply puts data into a list.
        aspect_ratio_grouping (bool): whether to group images with similar
            aspect ratio for efficiency. When enabled, it requires each
            element in dataset be a dict with keys "width" and "height".
        num_workers (int): number of parallel data loading workers

    Returns:
        torch.utils.data.DataLoader: a dataloader. Each output from it is a
            ``list[mapped_element]`` of length ``total_batch_size / num_workers``,
            where ``mapped_element`` is produced by the ``mapper``.
    """
    if isinstance(dataset, list):
        dataset = DatasetFromList(dataset, copy=False)
    if mapper is not None:
        dataset = MapDataset(dataset, mapper)
    if sampler is None:
        sampler = TrainingSampler(len(dataset))
    assert isinstance(sampler, torch.utils.data.sampler.Sampler)
    return build_batch_data_loader(
        dataset,
        sampler,
        total_batch_size,
        aspect_ratio_grouping=aspect_ratio_grouping,
        num_workers=num_workers,
    )


def _test_loader_from_config(cfg, dataset_name, mapper=None):
    """
    Uses the given `dataset_name` argument (instead of the names in cfg), because the
    standard practice is to evaluate each test set individually (not combining them).

    The returned arguments always set ``num_workers`` to 0 and take
    ``samples_per_gpu`` from ``cfg.SOLVER.TEST_IMS_PER_BATCH``.
    """
    if isinstance(dataset_name, str):
        dataset_name = [dataset_name]

    dataset = get_detection_dataset_dicts(
        dataset_name,
        filter_empty=False,
        # Pick the proposal file at the same position as the dataset name
        # inside cfg.DATASETS.TEST.
        proposal_files=[
            cfg.DATASETS.PROPOSAL_FILES_TEST[list(cfg.DATASETS.TEST).index(x)]
            for x in dataset_name
        ]
        if cfg.MODEL.LOAD_PROPOSALS
        else None,
    )
    if mapper is None:
        mapper = DatasetMapper(cfg, False)
    return {
        "dataset": dataset,
        "mapper": mapper,
        "num_workers": 0,
        "samples_per_gpu": cfg.SOLVER.TEST_IMS_PER_BATCH,
    }


@configurable(from_config=_test_loader_from_config)
def build_detection_test_loader(
    dataset, *, mapper, sampler=None, num_workers=0, samples_per_gpu=1
):
    """
    Similar to `build_detection_train_loader`, but batches ``samples_per_gpu``
    samples (1 by default) and uses :class:`InferenceSampler`. This sampler
    coordinates all workers to produce the exact set of all samples.
    This interface is experimental.

    Args:
        dataset (list or torch.utils.data.Dataset): a list of dataset dicts,
            or a map-style pytorch dataset. They can be obtained by using
            :func:`DatasetCatalog.get` or :func:`get_detection_dataset_dicts`.
        mapper (callable): a callable which takes a sample (dict) from dataset
            and returns the format to be consumed by the model.
            When using cfg, the default choice is ``DatasetMapper(cfg, is_train=False)``.
        sampler (torch.utils.data.sampler.Sampler or None): a sampler that
            produces indices to be applied on ``dataset``.
            Default to :class:`InferenceSampler`, which splits the dataset
            across all workers.
        num_workers (int): number of parallel data loading workers
        samples_per_gpu (int): number of samples in each batch produced by the
            loader. The last batch may be smaller, since no sample is dropped.

    Returns:
        DataLoader: a torch DataLoader, that loads the given detection
        dataset, with test-time transformation and batching.

    Examples:
    ::
        data_loader = build_detection_test_loader(
            DatasetCatalog.get("my_test"),
            mapper=DatasetMapper(...))

        # or, instantiate with a CfgNode:
        data_loader = build_detection_test_loader(cfg, "my_test")
    """
    if isinstance(dataset, list):
        dataset = DatasetFromList(dataset, copy=False)
    if mapper is not None:
        dataset = MapDataset(dataset, mapper)
    if sampler is None:
        sampler = InferenceSampler(len(dataset))
    # Batch `samples_per_gpu` samples at a time. The default of 1 image per
    # worker is the standard when reporting inference time in papers.
    batch_sampler = torch.utils.data.sampler.BatchSampler(
        sampler, samples_per_gpu, drop_last=False
    )
    data_loader = torch.utils.data.DataLoader(
        dataset,
        num_workers=num_workers,
        batch_sampler=batch_sampler,
        collate_fn=trivial_batch_collator,
    )
    return data_loader