hasibzunair's picture
added files
4a285f6
raw history blame
No virus
800 Bytes
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import unittest
from torch.utils.data.sampler import SequentialSampler
from detectron2.data.samplers import GroupedBatchSampler
class TestGroupedBatchSampler(unittest.TestCase):
    """Tests for ``GroupedBatchSampler`` wrapped around a 100-element ``SequentialSampler``."""

    num_samples = 100
    batch_size = 2

    def _batches(self, group_ids):
        """Return every batch yielded by a GroupedBatchSampler over ``range(num_samples)``.

        The sampler is exhausted into a list so callers can check both the
        contents of each batch and how many batches were produced. A bare
        ``for`` loop over the sampler would pass vacuously if nothing were
        yielded.
        """
        sampler = SequentialSampler(list(range(self.num_samples)))
        samples = GroupedBatchSampler(sampler, group_ids, self.batch_size)
        return list(samples)

    def _assert_covers_all_indices(self, batches):
        """Check that the batches together contain each sample index exactly once."""
        flat = sorted(idx for batch in batches for idx in batch)
        self.assertEqual(flat, list(range(self.num_samples)))

    def test_missing_group_id(self):
        """With one group id for every sample, only full batches are yielded.

        NOTE(review): despite the name, no group id is missing here; all
        samples are assigned group 1.
        """
        group_ids = [1] * self.num_samples
        batches = self._batches(group_ids)

        # 100 samples split into batches of 2 -> exactly 50 full batches.
        self.assertEqual(len(batches), self.num_samples // self.batch_size)
        for mini_batch in batches:
            self.assertEqual(len(mini_batch), self.batch_size)
        self._assert_covers_all_indices(batches)

    def test_groups(self):
        """Each batch contains samples from a single group only.

        Group ids alternate 1, 0, 1, 0, ..., so even indices form one group
        and odd indices form the other (50 samples each).
        """
        group_ids = [1, 0] * (self.num_samples // 2)
        batches = self._batches(group_ids)

        # Each group has 50 samples -> 25 batches per group, 50 in total.
        self.assertEqual(len(batches), self.num_samples // self.batch_size)
        for mini_batch in batches:
            self.assertEqual(len(mini_batch), self.batch_size)
            # Compare group ids directly instead of relying on index parity.
            self.assertEqual(len({group_ids[idx] for idx in mini_batch}), 1)
        self._assert_covers_all_indices(batches)