rlawjdghek's picture
prep (#1)
61c2d32 verified
raw
history blame
7.51 kB
# Copyright (c) Facebook, Inc. and its affiliates.
import os
import pickle
import sys
import unittest
from functools import partial
import torch
from iopath.common.file_io import LazyPath
from detectron2 import model_zoo
from detectron2.config import get_cfg, instantiate
from detectron2.data import (
DatasetCatalog,
DatasetFromList,
MapDataset,
ToIterableDataset,
build_batch_data_loader,
build_detection_test_loader,
build_detection_train_loader,
)
from detectron2.data.common import (
AspectRatioGroupedDataset,
set_default_dataset_from_list_serialize_method,
)
from detectron2.data.samplers import InferenceSampler, TrainingSampler
def _a_slow_func(x):
    """Return the placeholder path string ``"path/<x>"``.

    The tests wrap this in ``LazyPath(partial(...))`` so that the path is only
    computed on demand. The body itself is just string formatting.
    """
    return f"path/{x}"
class TestDatasetFromList(unittest.TestCase):
    """Tests for ``DatasetFromList`` element access and its serialize options."""

    # Failing for py3.6, likely due to pickle.
    # Compare the whole version tuple. Checking only ``.minor`` would also match
    # any other major version whose minor number happens to be <= 6.
    @unittest.skipIf(sys.version_info < (3, 7), "Not supported in Python 3.6")
    def test_using_lazy_path(self):
        # Every record carries a LazyPath. Reading it back through
        # DatasetFromList must still give a LazyPath that resolves to the
        # same string as calling the wrapped function directly.
        dataset = [{"file_name": LazyPath(partial(_a_slow_func, i))} for i in range(10)]
        dataset = DatasetFromList(dataset)
        for i in range(10):
            path = dataset[i]["file_name"]
            self.assertIsInstance(path, LazyPath)
            self.assertEqual(os.fspath(path), _a_slow_func(i))

    def test_alternative_serialize_method(self):
        # ``serialize`` may be a callable. Here each element is stored via
        # ``torch.tensor`` and is read back as a tensor.
        dataset = [1, 2, 3]
        dataset = DatasetFromList(dataset, serialize=torch.tensor)
        self.assertEqual(dataset[2], torch.tensor(3))

    def test_change_default_serialize_method(self):
        dataset = [1, 2, 3]
        # Inside the context manager, ``serialize=True`` uses the overridden
        # method (torch.tensor).
        with set_default_dataset_from_list_serialize_method(torch.tensor):
            dataset_1 = DatasetFromList(dataset, serialize=True)
            self.assertEqual(dataset_1[2], torch.tensor(3))
        # After leaving it, ``serialize=True`` gives back the plain element again.
        dataset_2 = DatasetFromList(dataset, serialize=True)
        self.assertEqual(dataset_2[2], 3)
class TestMapDataset(unittest.TestCase):
    """Tests for ``MapDataset`` over map-style and iterable datasets."""

    @staticmethod
    def map_func(x):
        """Double ``x``, except return ``None`` for the value 2.

        The ``None`` result exercises how MapDataset handles a mapper that rejects
        an element.
        """
        return None if x == 2 else x * 2

    def test_map_style(self):
        mapped = MapDataset(DatasetFromList([1, 2, 3]), self.map_func)
        self.assertEqual(mapped[0], 2)
        self.assertEqual(mapped[2], 6)
        # Index 1 maps to None. The test expects some other element's mapped
        # value to be returned in its place.
        self.assertIn(mapped[1], [2, 6])

    def test_iter_style(self):
        class _ThreeItems(torch.utils.data.IterableDataset):
            def __iter__(self):
                yield from (1, 2, 3)

        mapped = MapDataset(_ThreeItems(), self.map_func)
        self.assertIsInstance(mapped, torch.utils.data.IterableDataset)
        # For an iterable source, the element mapped to None is expected to be
        # skipped entirely.
        self.assertEqual(list(mapped), [2, 6])

    def test_pickleability(self):
        # A lambda mapper must survive a pickle round-trip of the dataset.
        original = MapDataset(DatasetFromList([1, 2, 3]), lambda x: x * 2)
        restored = pickle.loads(pickle.dumps(original))
        self.assertEqual(restored[0], 2)
class TestAspectRatioGrouping(unittest.TestCase):
    """Tests for ``AspectRatioGroupedDataset``."""

    def test_reiter_leak(self):
        # Alternate records with width > height and width < height.
        sizes = [(1, 0), (0, 1), (1, 0), (0, 1)]
        records = [dict(width=w, height=h) for w, h in sizes]
        batch_size = 2
        grouped = AspectRatioGroupedDataset(records, batch_size)

        for _ in range(5):
            # Start a new iteration each round and abandon it after its second
            # item, so the iterator never finishes on its own.
            for item_idx, _item in enumerate(grouped):
                if item_idx == 1:
                    break

        # Repeatedly abandoned iterations must not leave any internal bucket
        # holding batch_size or more pending items.
        for bucket in grouped._buckets:
            self.assertLess(len(bucket), batch_size)
class _MyData(torch.utils.data.IterableDataset):
    """An endless iterable dataset whose every element is the integer ``1``."""

    def __iter__(self):
        # ``iter(callable, sentinel)`` keeps calling the callable until it returns
        # the sentinel. A constant-1 callable never returns None, so the stream
        # never ends.
        yield from iter(lambda: 1, None)
class TestDataLoader(unittest.TestCase):
    """Tests for the train/test/batch data loader builders in ``detectron2.data``."""

    def _get_kwargs(self):
        """Return instantiated keyword arguments for ``build_detection_train_loader``.

        The arguments come from the ``dataloader.train`` node of the model-zoo config
        ``common/data/coco.py``, redirected to the ``coco_2017_val_100`` dataset.
        """
        cfg = model_zoo.get_config("common/data/coco.py").dataloader.train
        cfg.dataset.names = "coco_2017_val_100"
        # Drop the builder target so that only its arguments remain.
        cfg.pop("_target_")
        return {k: instantiate(v) for k, v in cfg.items()}

    def test_build_dataloader_train(self):
        kwargs = self._get_kwargs()
        dl = build_detection_train_loader(**kwargs)
        next(iter(dl))

    def test_build_iterable_dataloader_train(self):
        # Same as above, but pass an IterableDataset built from the configured
        # dataset and a TrainingSampler.
        kwargs = self._get_kwargs()
        ds = DatasetFromList(kwargs.pop("dataset"))
        ds = ToIterableDataset(ds, TrainingSampler(len(ds)))
        dl = build_detection_train_loader(dataset=ds, **kwargs)
        next(iter(dl))

    def test_build_iterable_dataloader_from_cfg(self):
        cfg = get_cfg()
        cfg.DATASETS.TRAIN = ["iter_data"]
        DatasetCatalog.register("iter_data", lambda: _MyData())
        # DatasetCatalog is a shared registry looked up by name. Unregister
        # afterwards so "iter_data" does not leak into other tests, and so a
        # second run in the same process does not register the name twice
        # (presumably rejected as a duplicate).
        self.addCleanup(DatasetCatalog.remove, "iter_data")
        dl = build_detection_train_loader(cfg, mapper=lambda x: x, aspect_ratio_grouping=False)
        next(iter(dl))
        dl = build_detection_test_loader(cfg, "iter_data", mapper=lambda x: x)
        next(iter(dl))

    def _check_is_range(self, data_loader, N):
        """Assert that ``data_loader`` yields each integer in ``range(N)`` exactly once.

        The loader is iterated in full, and its batches are flattened before comparing.
        """
        data = list(iter(data_loader))
        data = [x for batch in data for x in batch]  # flatten the batches
        # Equal length plus equal sets means no duplicates and nothing missing.
        self.assertEqual(len(data), N)
        self.assertEqual(set(data), set(range(N)))

    def test_build_batch_dataloader_inference(self):
        # Test that build_batch_data_loader can be used for inference
        N = 96
        ds = DatasetFromList(list(range(N)))
        sampler = InferenceSampler(len(ds))
        dl = build_batch_data_loader(ds, sampler, 8, num_workers=3)
        self._check_is_range(dl, N)

    def test_build_batch_dataloader_inference_incomplete_batch(self):
        # Test that build_batch_data_loader works when dataset size is not multiple of
        # batch size or num_workers
        def _test(N, batch_size, num_workers):
            ds = DatasetFromList(list(range(N)))
            sampler = InferenceSampler(len(ds))

            # Default behavior: the trailing partial batch is expected to be dropped.
            dl = build_batch_data_loader(ds, sampler, batch_size, num_workers=num_workers)
            data = list(iter(dl))
            self.assertEqual(len(data), len(dl))  # floor(N / batch_size)
            self._check_is_range(dl, N // batch_size * batch_size)

            # drop_last=False: every element, including the partial batch, is kept.
            dl = build_batch_data_loader(
                ds, sampler, batch_size, num_workers=num_workers, drop_last=False
            )
            data = list(iter(dl))
            self.assertEqual(len(data), len(dl))  # ceil(N / batch_size)
            self._check_is_range(dl, N)

        _test(48, batch_size=8, num_workers=3)
        _test(47, batch_size=8, num_workers=3)
        _test(46, batch_size=8, num_workers=3)
        _test(40, batch_size=8, num_workers=3)
        _test(39, batch_size=8, num_workers=3)

    def test_build_dataloader_inference(self):
        N = 50
        ds = DatasetFromList(list(range(N)))
        sampler = InferenceSampler(len(ds))
        # test that parallel loader works correctly
        dl = build_detection_test_loader(
            dataset=ds, sampler=sampler, mapper=lambda x: x, num_workers=3
        )
        self._check_is_range(dl, N)

        # test that batch_size works correctly
        dl = build_detection_test_loader(
            dataset=ds, sampler=sampler, mapper=lambda x: x, batch_size=4, num_workers=0
        )
        self._check_is_range(dl, N)

    def test_build_iterable_dataloader_inference(self):
        # Test that build_detection_test_loader supports iterable dataset
        N = 50
        ds = DatasetFromList(list(range(N)))
        ds = ToIterableDataset(ds, InferenceSampler(len(ds)))
        dl = build_detection_test_loader(dataset=ds, mapper=lambda x: x, num_workers=3)
        self._check_is_range(dl, N)