# Copyright (c) Facebook, Inc. and its affiliates.
import unittest

import torch

from detectron2.modeling.meta_arch import GeneralizedRCNN
from detectron2.utils.registry import _convert_target_to_string, locate


class A:
    """Outer fixture class; only exists to host the nested class ``B``."""

    class B:
        """Nested fixture class whose dotted path includes its outer class ``A``."""


class TestLocate(unittest.TestCase):
    """Round-trip checks: object -> ``_convert_target_to_string`` -> ``locate``."""

    def _check_roundtrip(self, target):
        """Assert that resolving the string form of ``target`` yields the very same object."""
        target_str = _convert_target_to_string(target)
        self.assertIs(target, locate(target_str))

    def test_basic(self):
        # A plain top-level class from the library.
        self._check_roundtrip(GeneralizedRCNN)

    def test_inside_class(self):
        # A nested class: the string form must be built from __qualname__,
        # because __name__ alone ("B") would drop the enclosing "A".
        self._check_roundtrip(A.B)

    def test_builtin(self):
        # Both a builtin function and a builtin type must round-trip.
        for builtin in (len, dict):
            self._check_roundtrip(builtin)

    def test_pytorch_optim(self):
        # Per the original note, pydoc.locate alone cannot resolve this one.
        self._check_roundtrip(torch.optim.SGD)

    def test_failure(self):
        # An unresolvable name must surface as ImportError.
        self.assertRaises(ImportError, locate, "asdf")

    def test_compress_target(self):
        # Imported lazily, matching the original test's scope.
        from detectron2.data.transforms import RandomCrop

        target_str = _convert_target_to_string(RandomCrop)
        # Expect the shortened public path, i.e. without the
        # 'augmentation_impl' module segment.
        self.assertEqual(target_str, "detectron2.data.transforms.RandomCrop")
        self.assertIs(RandomCrop, locate(target_str))