# Copyright (c) Facebook, Inc. and its affiliates.
import itertools
import math
import operator
import unittest
import torch
from torch.utils import data
from torch.utils.data.sampler import SequentialSampler

from detectron2.data.build import worker_init_reset_seed
from detectron2.data.common import DatasetFromList, ToIterableDataset
from detectron2.data.samplers import (
    GroupedBatchSampler,
    InferenceSampler,
    RepeatFactorTrainingSampler,
    TrainingSampler,
)
from detectron2.utils.env import seed_all_rng


class TestGroupedBatchSampler(unittest.TestCase):
    """Checks on the mini-batches yielded by ``GroupedBatchSampler``."""

    def test_missing_group_id(self):
        """Batches have two elements when every sample carries group id 1.

        No sample uses group id 0, so that id is "missing" from the input.
        """
        num_samples = 100
        batch_size = 2
        index_sampler = SequentialSampler(list(range(num_samples)))
        batches = GroupedBatchSampler(index_sampler, [1] * num_samples, batch_size)

        for batch in batches:
            self.assertEqual(len(batch), batch_size)

    def test_groups(self):
        """The two indices of each batch must sum to an even number.

        Group ids alternate 1, 0, 1, 0, ... over the indices 0..99, so an
        index's parity coincides with its group id; an even sum therefore
        means both indices come from the same group.
        """
        index_sampler = SequentialSampler(list(range(100)))
        alternating_ids = [1, 0] * 50
        batches = GroupedBatchSampler(index_sampler, alternating_ids, 2)

        for batch in batches:
            index_sum = batch[0] + batch[1]
            self.assertEqual(index_sum % 2, 0)


class TestSamplerDeterministic(unittest.TestCase):
    """Checks that ``TrainingSampler`` output is reproducible."""

    def test_to_iterable(self):
        """Data loaders over ``ToIterableDataset`` reproduce the sampler order.

        The first 100 indices drawn directly from the sampler are compared
        with what a single-process loader and a two-worker loader produce.
        """
        num_items = 100
        sampler = TrainingSampler(num_items, seed=10)
        expected = list(itertools.islice(sampler, num_items))
        # The first 100 draws must cover every index 0..99.
        self.assertEqual(set(expected), set(range(num_items)))

        list_dataset = DatasetFromList(list(range(num_items)))
        iterable_dataset = ToIterableDataset(list_dataset, sampler)
        # Loaders use the default batch size of 1; unwrap the single element.
        first_of_batch = operator.itemgetter(0)

        loader = data.DataLoader(iterable_dataset, num_workers=0, collate_fn=first_of_batch)
        single_process_output = list(itertools.islice(loader, num_items))
        self.assertEqual(single_process_output, expected)

        loader = data.DataLoader(
            iterable_dataset,
            num_workers=2,
            collate_fn=first_of_batch,
            # Resetting the seed in each worker should not affect the
            # behavior of TrainingSampler.
            worker_init_fn=worker_init_reset_seed,
        )
        multi_worker_output = list(itertools.islice(loader, num_items))
        # Multiple workers should not lead to duplicate or different data.
        self.assertEqual(multi_worker_output, expected)

    def test_training_sampler_seed(self):
        """Two samplers built after ``seed_all_rng(42)`` yield the same indices.

        Re-seeding the global RNGs after the second sampler has been
        constructed is expected to have no effect on what it yields.
        """
        num_draws = 65

        seed_all_rng(42)
        first_sampler = TrainingSampler(30)
        first_draws = list(itertools.islice(first_sampler, num_draws))

        seed_all_rng(42)
        second_sampler = TrainingSampler(30)
        seed_all_rng(999)  # should be ineffective
        second_draws = list(itertools.islice(second_sampler, num_draws))

        self.assertEqual(first_draws, second_draws)


class TestRepeatFactorTrainingSampler(unittest.TestCase):
    """Checks on the repeat factors computed by ``RepeatFactorTrainingSampler``."""

    def test_repeat_factors_from_category_frequency(self):
        """Three images with threshold 0.5 give factors ``[sqrt(3/2), 1, 1]``.

        NOTE(review): sqrt(3/2) presumably comes from category 1 appearing in
        one of three images, i.e. sqrt(0.5 / (1/3)) -- confirm against the
        sampler implementation.
        """

        def image_with(*category_ids):
            # Build a minimal dataset dict holding only the given categories.
            return {"annotations": [{"category_id": cat} for cat in category_ids]}

        repeat_thresh = 0.5
        dataset_dicts = [image_with(0, 1), image_with(0), image_with()]

        rep_factors = RepeatFactorTrainingSampler.repeat_factors_from_category_frequency(
            dataset_dicts, repeat_thresh
        )

        expected = torch.tensor([math.sqrt(3 / 2), 1.0, 1.0])
        self.assertTrue(torch.allclose(rep_factors, expected))


class TestInferenceSampler(unittest.TestCase):
    """Checks on how ``InferenceSampler`` splits indices across ranks."""

    def test_local_indices(self):
        """``_get_local_indices`` returns the expected range for every rank.

        Each case is ``(dataset size, world size, expected ranges by rank)``.
        """
        cases = [
            (0, 5, [range(0)] * 5),
            (16, 2, [range(8), range(8, 16)]),
            (2, 3, [range(1), range(1, 2), range(0)]),
            (42, 4, [range(11), range(11, 22), range(22, 32), range(32, 42)]),
        ]

        for size, world_size, expected_result in cases:
            with self.subTest(f"size={size}, world_size={world_size}"):
                per_rank = []
                for rank in range(world_size):
                    per_rank.append(InferenceSampler._get_local_indices(size, world_size, rank))
                self.assertEqual(per_rank, expected_result)