Spaces:
Runtime error
Runtime error
# Copyright (c) Facebook, Inc. and its affiliates. | |
import unittest | |
import torch | |
from detectron2.modeling.meta_arch import GeneralizedRCNN | |
from detectron2.utils.registry import _convert_target_to_string, locate | |
class A:
    """Test fixture that hosts a nested class for the ``__qualname__`` check."""

    class B:
        """Nested fixture; its qualified name is ``"A.B"``, not just ``"B"``."""
class TestLocate(unittest.TestCase):
    """Round-trip tests for ``_convert_target_to_string`` and ``locate``.

    Each check converts an object to its string name and resolves that name
    back with ``locate``, expecting the identical object in return.
    """

    def _test_obj(self, obj):
        """Assert ``locate(_convert_target_to_string(obj))`` is ``obj`` itself."""
        self.assertIs(obj, locate(_convert_target_to_string(obj)))

    def test_basic(self):
        """A regular class defined inside detectron2 round-trips."""
        self._test_obj(GeneralizedRCNN)

    def test_inside_class(self):
        """A class nested in another class round-trips."""
        # The name must be built from __qualname__ ("A.B"); __name__ ("B") is not enough.
        self._test_obj(A.B)

    def test_builtin(self):
        """Builtin functions and builtin types round-trip."""
        for builtin in (len, dict):
            self._test_obj(builtin)

    def test_pytorch_optim(self):
        """``torch.optim.SGD`` round-trips."""
        # pydoc.locate does not work for this target, but locate must.
        self._test_obj(torch.optim.SGD)

    def test_failure(self):
        """Resolving an unknown name raises ImportError."""
        self.assertRaises(ImportError, locate, "asdf")

    def test_compress_target(self):
        """A re-exported class is named by the package that re-exports it."""
        from detectron2.data.transforms import RandomCrop

        expected = "detectron2.data.transforms.RandomCrop"
        name = _convert_target_to_string(RandomCrop)
        # The defining submodule ('augmentation_impl') must not appear in the name.
        self.assertEqual(name, expected)
        self.assertIs(RandomCrop, locate(name))