import itertools
import logging
from collections import Counter

import numpy as np
import torch.utils.data
from tabulate import tabulate
from termcolor import colored

from detectron2.utils.logger import _log_api_usage, log_first_n
from detectron2.data.catalog import DatasetCatalog, MetadataCatalog
from detectron2.config import configurable
from detectron2.data.build import (
    build_batch_data_loader,
    trivial_batch_collator,
    load_proposals_into_dataset,
    filter_images_with_only_crowd_annotations,
    filter_images_with_few_keypoints,
    print_instances_class_histogram,
)
from detectron2.data.common import DatasetFromList, MapDataset
from detectron2.data.dataset_mapper import DatasetMapper
from detectron2.data.detection_utils import check_metadata_consistency
from detectron2.data.samplers import (
    InferenceSampler,
    RandomSubsetTrainingSampler,
    RepeatFactorTrainingSampler,
    TrainingSampler,
)
""" |
|
This file contains the default logic to build a dataloader for training or testing. |
|
""" |
|
|
|
__all__ = [ |
|
"build_detection_train_loader", |
|
"build_detection_test_loader", |
|
] |
|
|
|
|
|


def print_classification_instances_class_histogram(dataset_dicts, class_names):
    """
    Log a histogram of the per-image classification labels found in the dataset.

    Args:
        dataset_dicts (list[dict]): list of dataset dicts, each carrying a "category_id" key.
        class_names (list[str]): list of class names (zero-indexed).
    """
    num_classes = len(class_names)
    hist_bins = np.arange(num_classes + 1)
    histogram = np.zeros((num_classes,), dtype=int)
    for entry in dataset_dicts:
        classes = np.asarray([entry["category_id"]], dtype=int)
        if len(classes):
            assert classes.min() >= 0, f"Got an invalid category_id={classes.min()}"
            assert (
                classes.max() < num_classes
            ), f"Got an invalid category_id={classes.max()} for a dataset of {num_classes} classes"
        histogram += np.histogram(classes, bins=hist_bins)[0]

    N_COLS = min(6, len(class_names) * 2)

    def short_name(x):
        # Truncate long class names so the table stays readable.
        if len(x) > 13:
            return x[:11] + ".."
        return x

    data = list(
        itertools.chain(
            *[[short_name(class_names[i]), int(v)] for i, v in enumerate(histogram)]
        )
    )
    total_num_instances = sum(data[1::2])
    data.extend([None] * (N_COLS - (len(data) % N_COLS)))
    if num_classes > 1:
        data.extend(["total", total_num_instances])
    data = itertools.zip_longest(*[data[i::N_COLS] for i in range(N_COLS)])
    table = tabulate(
        data,
        headers=["category", "#instances"] * (N_COLS // 2),
        tablefmt="pipe",
        numalign="left",
        stralign="center",
    )
    log_first_n(
        logging.INFO,
        "Distribution of instances among all {} categories:\n".format(num_classes)
        + colored(table, "cyan"),
        key="message",
    )


def wrap_metas(dataset_dicts, **kwargs):
    """
    Attach a ``meta`` dict (e.g. the originating dataset name) to every sample.

    Args:
        dataset_dicts (list[dict]): dataset dicts, e.g. from :func:`DatasetCatalog.get`.
        **kwargs: metadata stored under the ``meta`` key of each sample.
    """

    def _assign_attr(data_dict: dict, **kwargs):
        assert not any(
            key in data_dict for key in kwargs
        ), "Assigned attributes should not exist in the original sample."
        data_dict.update(kwargs)
        return data_dict

    return [_assign_attr(sample, meta=kwargs) for sample in dataset_dicts]
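
# Illustrative example of what ``wrap_metas`` returns (the sample dict below is
# hypothetical):
#   wrap_metas([{"file_name": "a.jpg"}], dataset_name="my_dataset")
#   -> [{"file_name": "a.jpg", "meta": {"dataset_name": "my_dataset"}}]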


def get_detection_dataset_dicts(
    names, filter_empty=True, min_keypoints=0, proposal_files=None
):
    """
    Load and prepare dataset dicts for instance detection/segmentation and semantic segmentation.

    Args:
        names (str or list[str]): a dataset name or a list of dataset names
        filter_empty (bool): whether to filter out images without instance annotations
        min_keypoints (int): filter out images with fewer keypoints than
            `min_keypoints`. Set to 0 to do nothing.
        proposal_files (list[str]): if given, a list of object proposal files
            that match each dataset in `names`.

    Returns:
        list[dict]: a list of dicts following the standard dataset dict format.
    """
    if isinstance(names, str):
        names = [names]
    assert len(names), names
    dataset_dicts = [
        wrap_metas(DatasetCatalog.get(dataset_name), dataset_name=dataset_name)
        for dataset_name in names
    ]
    for dataset_name, dicts in zip(names, dataset_dicts):
        assert len(dicts), "Dataset '{}' is empty!".format(dataset_name)

    if proposal_files is not None:
        assert len(names) == len(proposal_files)
        # Load precomputed object proposals for each dataset before merging them.
        dataset_dicts = [
            load_proposals_into_dataset(dataset_i_dicts, proposal_file)
            for dataset_i_dicts, proposal_file in zip(dataset_dicts, proposal_files)
        ]

    dataset_dicts = list(itertools.chain.from_iterable(dataset_dicts))

    has_instances = "annotations" in dataset_dicts[0]
    if filter_empty and has_instances:
        dataset_dicts = filter_images_with_only_crowd_annotations(dataset_dicts)
    if min_keypoints > 0 and has_instances:
        dataset_dicts = filter_images_with_few_keypoints(dataset_dicts, min_keypoints)

    if has_instances:
        try:
            class_names = MetadataCatalog.get(names[0]).thing_classes
            check_metadata_consistency("thing_classes", names)
            print_instances_class_histogram(dataset_dicts, class_names)
        except AttributeError:  # class names are not available for this dataset
            pass

    assert len(dataset_dicts), "No valid data found in {}.".format(",".join(names))
    return dataset_dicts
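
# Sketch of a direct call (the dataset name is hypothetical; any name registered in
# DatasetCatalog works the same way):
#   dicts = get_detection_dataset_dicts(["my_train_split"], filter_empty=True)
#   Every returned dict also carries {"meta": {"dataset_name": "my_train_split"}},
#   added by wrap_metas above.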


def _train_loader_from_config(cfg, mapper=None, *, dataset=None, sampler=None):
    if dataset is None:
        dataset = get_detection_dataset_dicts(
            cfg.DATASETS.TRAIN,
            filter_empty=cfg.DATALOADER.FILTER_EMPTY_ANNOTATIONS,
            min_keypoints=cfg.MODEL.ROI_KEYPOINT_HEAD.MIN_KEYPOINTS_PER_IMAGE
            if cfg.MODEL.KEYPOINT_ON
            else 0,
            proposal_files=cfg.DATASETS.PROPOSAL_FILES_TRAIN
            if cfg.MODEL.LOAD_PROPOSALS
            else None,
        )
        _log_api_usage("dataset." + cfg.DATASETS.TRAIN[0])

    if mapper is None:
        mapper = DatasetMapper(cfg, True)

    # Multi-task samples carry a "task" key; for those, keep the original order and
    # do not shuffle inside the training sampler.
    shuffle = "task" not in dataset[0]

    if sampler is None:
        sampler_name = cfg.DATALOADER.SAMPLER_TRAIN
        logger = logging.getLogger(__name__)
        logger.info("Using training sampler {}".format(sampler_name))
        if sampler_name == "TrainingSampler":
            sampler = TrainingSampler(len(dataset), shuffle=shuffle)
        elif sampler_name == "RepeatFactorTrainingSampler":
            repeat_factors = (
                RepeatFactorTrainingSampler.repeat_factors_from_category_frequency(
                    dataset, cfg.DATALOADER.REPEAT_THRESHOLD
                )
            )
            sampler = RepeatFactorTrainingSampler(repeat_factors)
        elif sampler_name == "RandomSubsetTrainingSampler":
            sampler = RandomSubsetTrainingSampler(
                len(dataset), cfg.DATALOADER.RANDOM_SUBSET_RATIO
            )
        else:
            raise ValueError("Unknown training sampler: {}".format(sampler_name))

    return {
        "dataset": dataset,
        "sampler": sampler,
        "mapper": mapper,
        "total_batch_size": cfg.SOLVER.IMS_PER_BATCH,
        "aspect_ratio_grouping": cfg.DATALOADER.ASPECT_RATIO_GROUPING,
        "num_workers": cfg.DATALOADER.NUM_WORKERS,
    }
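
# The decorator below wires ``_train_loader_from_config`` into the public builder:
# calling ``build_detection_train_loader(cfg)`` (or ``build_detection_train_loader(cfg,
# mapper=...)``) converts the cfg into the explicit arguments of the signature.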


@configurable(from_config=_train_loader_from_config)
def build_detection_train_loader(
    dataset,
    *,
    mapper,
    sampler=None,
    total_batch_size,
    aspect_ratio_grouping=True,
    num_workers=0,
):
    """
    Build a dataloader for object detection with some default features.
    This interface is experimental.

    Args:
        dataset (list or torch.utils.data.Dataset): a list of dataset dicts,
            or a map-style pytorch dataset. They can be obtained by using
            :func:`DatasetCatalog.get` or :func:`get_detection_dataset_dicts`.
        mapper (callable): a callable which takes a sample (dict) from dataset and
            returns the format to be consumed by the model.
            When using cfg, the default choice is ``DatasetMapper(cfg, is_train=True)``.
        sampler (torch.utils.data.sampler.Sampler or None): a sampler that produces
            indices to be applied on ``dataset``. Defaults to :class:`TrainingSampler`,
            which coordinates an infinite random shuffle sequence across all workers.
        total_batch_size (int): total batch size across all workers. Batching
            simply puts data into a list.
        aspect_ratio_grouping (bool): whether to group images with similar
            aspect ratio for efficiency. When enabled, it requires each
            element in dataset be a dict with keys "width" and "height".
        num_workers (int): number of parallel data loading workers

    Returns:
        torch.utils.data.DataLoader:
            a dataloader. Each output from it is a ``list[mapped_element]`` of length
            ``total_batch_size / num_gpus``, where ``mapped_element`` is produced
            by the ``mapper``.
    """
    if isinstance(dataset, list):
        dataset = DatasetFromList(dataset, copy=False)
    if mapper is not None:
        dataset = MapDataset(dataset, mapper)
    if sampler is None:
        sampler = TrainingSampler(len(dataset))
    assert isinstance(sampler, torch.utils.data.sampler.Sampler)
    return build_batch_data_loader(
        dataset,
        sampler,
        total_batch_size,
        aspect_ratio_grouping=aspect_ratio_grouping,
        num_workers=num_workers,
    )
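
# Minimal usage sketch ("my_train_split" is a hypothetical registered dataset name):
#   dicts = get_detection_dataset_dicts(["my_train_split"])
#   loader = build_detection_train_loader(
#       dicts, mapper=DatasetMapper(cfg, True), total_batch_size=16
#   )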


def _test_loader_from_config(cfg, dataset_name, mapper=None):
    """
    Uses the given `dataset_name` argument (instead of the names in cfg), because the
    standard practice is to evaluate each test set individually (not combining them).
    """
    if isinstance(dataset_name, str):
        dataset_name = [dataset_name]

    dataset = get_detection_dataset_dicts(
        dataset_name,
        filter_empty=False,
        proposal_files=[
            cfg.DATASETS.PROPOSAL_FILES_TEST[list(cfg.DATASETS.TEST).index(x)]
            for x in dataset_name
        ]
        if cfg.MODEL.LOAD_PROPOSALS
        else None,
    )
    if mapper is None:
        mapper = DatasetMapper(cfg, False)
    return {
        "dataset": dataset,
        "mapper": mapper,
        "num_workers": cfg.DATALOADER.NUM_WORKERS,
        "samples_per_gpu": cfg.SOLVER.TEST_IMS_PER_BATCH,
    }


@configurable(from_config=_test_loader_from_config)
def build_detection_test_loader(
    dataset, *, mapper, sampler=None, num_workers=0, samples_per_gpu=1
):
    """
    Similar to `build_detection_train_loader`, but uses a per-GPU batch size of
    ``samples_per_gpu`` (1 by default) and :class:`InferenceSampler`. This sampler
    coordinates all workers to produce the exact set of all samples.
    This interface is experimental.

    Args:
        dataset (list or torch.utils.data.Dataset): a list of dataset dicts,
            or a map-style pytorch dataset. They can be obtained by using
            :func:`DatasetCatalog.get` or :func:`get_detection_dataset_dicts`.
        mapper (callable): a callable which takes a sample (dict) from dataset
            and returns the format to be consumed by the model.
            When using cfg, the default choice is ``DatasetMapper(cfg, is_train=False)``.
        sampler (torch.utils.data.sampler.Sampler or None): a sampler that produces
            indices to be applied on ``dataset``. Defaults to :class:`InferenceSampler`,
            which splits the dataset across all workers.
        num_workers (int): number of parallel data loading workers
        samples_per_gpu (int): number of samples per batch on each GPU.
            When using cfg, this is ``cfg.SOLVER.TEST_IMS_PER_BATCH``.

    Returns:
        DataLoader: a torch DataLoader, that loads the given detection
        dataset, with test-time transformation and batching.

    Examples:
    ::
        data_loader = build_detection_test_loader(
            DatasetCatalog.get("my_test"),
            mapper=DatasetMapper(...))

        # or, instantiate with a CfgNode:
        data_loader = build_detection_test_loader(cfg, "my_test")
    """
    if isinstance(dataset, list):
        dataset = DatasetFromList(dataset, copy=False)
    if mapper is not None:
        dataset = MapDataset(dataset, mapper)
    if sampler is None:
        sampler = InferenceSampler(len(dataset))

    batch_sampler = torch.utils.data.sampler.BatchSampler(
        sampler, samples_per_gpu, drop_last=False
    )
    data_loader = torch.utils.data.DataLoader(
        dataset,
        num_workers=num_workers,
        batch_sampler=batch_sampler,
        collate_fn=trivial_batch_collator,
    )
    return data_loader
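
# Sketch of evaluation with a larger per-GPU test batch (the dataset name is hypothetical):
#   loader = build_detection_test_loader(
#       DatasetCatalog.get("my_test"), mapper=DatasetMapper(cfg, False), samples_per_gpu=4
#   )
#   Each iteration then yields a list of up to 4 mapped samples.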


def dataset_sample_per_class(cfg):
    """
    Build the training dataset dicts and, if ``cfg.DATASETS.SAMPLE_PER_CLASS > 0``,
    keep at most that many samples per category, drawn with a generator seeded by
    ``cfg.DATASETS.SAMPLE_SEED``.
    """
    dataset_dicts = get_detection_dataset_dicts(
        cfg.DATASETS.TRAIN,
        filter_empty=cfg.DATALOADER.FILTER_EMPTY_ANNOTATIONS,
        min_keypoints=cfg.MODEL.ROI_KEYPOINT_HEAD.MIN_KEYPOINTS_PER_IMAGE
        if cfg.MODEL.KEYPOINT_ON
        else 0,
        proposal_files=cfg.DATASETS.PROPOSAL_FILES_TRAIN
        if cfg.MODEL.LOAD_PROPOSALS
        else None,
    )
    if cfg.DATASETS.SAMPLE_PER_CLASS > 0:
        # Group samples by category, then draw a fixed-size subset from each group.
        category_list = [data["category_id"] for data in dataset_dicts]
        category_count = Counter(category_list)
        category_group = {
            cat: [data for data in dataset_dicts if data["category_id"] == cat]
            for cat in category_count.keys()
        }

        rng = np.random.default_rng(cfg.DATASETS.SAMPLE_SEED)
        # Note: ``rng.choice`` samples with replacement by default.
        selected = {
            cat: groups
            if len(groups) < cfg.DATASETS.SAMPLE_PER_CLASS
            else rng.choice(groups, size=cfg.DATASETS.SAMPLE_PER_CLASS).tolist()
            for cat, groups in category_group.items()
        }
        tmp = []
        for k, v in selected.items():
            tmp.extend(v)
        dataset_dicts = tmp

    dataset = dataset_dicts

    print_classification_instances_class_histogram(
        dataset, MetadataCatalog.get(cfg.DATASETS.TRAIN[0]).stuff_classes
    )
    return dataset
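
# Both dataset_sample_per_class (above) and dataset_sample_per_task_class (below) cap
# each category at cfg.DATASETS.SAMPLE_PER_CLASS entries. Illustrative numbers: with a
# cap of 10, a category holding 4 samples keeps all 4, while one holding 500 is reduced
# to 10 drawn with the generator seeded by cfg.DATASETS.SAMPLE_SEED, so repeated runs
# select the same subset.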


def dataset_sample_per_task_class(cfg):
    """
    Same per-class subsampling as :func:`dataset_sample_per_class`, but applied
    separately to each task ("sem_seg", "ins_seg", "pan_seg") of a multi-task dataset.
    """
    dataset_dicts = get_detection_dataset_dicts(
        cfg.DATASETS.TRAIN,
        filter_empty=cfg.DATALOADER.FILTER_EMPTY_ANNOTATIONS,
        min_keypoints=cfg.MODEL.ROI_KEYPOINT_HEAD.MIN_KEYPOINTS_PER_IMAGE
        if cfg.MODEL.KEYPOINT_ON
        else 0,
        proposal_files=cfg.DATASETS.PROPOSAL_FILES_TRAIN
        if cfg.MODEL.LOAD_PROPOSALS
        else None,
    )
    if cfg.DATASETS.SAMPLE_PER_CLASS > 0:
        tmp = []
        for task in ("sem_seg", "ins_seg", "pan_seg"):
            # Group this task's samples by category.
            category_list = [
                data["category_id"] for data in dataset_dicts if data["task"] == task
            ]
            category_count = Counter(category_list)
            category_group = {
                cat: [
                    data
                    for data in dataset_dicts
                    if data["category_id"] == cat and data["task"] == task
                ]
                for cat in category_count.keys()
            }
            # Re-seed per task so every task is subsampled from the same generator state.
            rng = np.random.default_rng(cfg.DATASETS.SAMPLE_SEED)
            selected = {
                cat: groups
                if len(groups) < cfg.DATASETS.SAMPLE_PER_CLASS
                else rng.choice(groups, size=cfg.DATASETS.SAMPLE_PER_CLASS).tolist()
                for cat, groups in category_group.items()
            }
            for k, v in selected.items():
                tmp.extend(v)

        dataset_dicts = tmp
        logger = logging.getLogger(__name__)
        logger.info(tmp)
    dataset = dataset_dicts

    print_classification_instances_class_histogram(
        dataset, MetadataCatalog.get(cfg.DATASETS.TRAIN[0]).stuff_classes
    )
    return dataset