# Copyright (c) Facebook, Inc. and its affiliates.
import unittest
import torch
from torch import Tensor

from detectron2.export.torchscript import patch_instances
from detectron2.structures import Boxes, Instances
from detectron2.utils.testing import convert_scripted_instances


class TestInstances(unittest.TestCase):
    """Tests for ``Instances`` in eager mode and under TorchScript via ``patch_instances``."""

    def test_int_indexing(self):
        """Integer indexing selects the matching row of every field; out-of-range raises."""
        attr1 = torch.tensor([[0.0, 0.0, 1.0], [0.0, 0.0, 0.5], [0.0, 0.0, 1.0], [0.0, 0.5, 0.5]])
        attr2 = torch.tensor([0.1, 0.2, 0.3, 0.4])
        instances = Instances((100, 100))
        instances.attr1 = attr1
        instances.attr2 = attr2
        # Every valid index, negative and positive, must select the matching row.
        # Element-wise `==` + `.all()` is used (rather than torch.equal) so the check
        # tolerates a leading singleton dimension on the indexed field, if any.
        for i in range(-len(instances), len(instances)):
            with self.subTest(i=i):
                inst = instances[i]
                self.assertTrue((inst.attr1 == attr1[i]).all())
                self.assertTrue((inst.attr2 == attr2[i]).all())

        # One past either end must raise IndexError.
        self.assertRaises(IndexError, lambda: instances[len(instances)])
        self.assertRaises(IndexError, lambda: instances[-len(instances) - 1])

    def test_script_new_fields(self):
        """Scripting uses the field set of the active ``patch_instances`` context."""

        def get_mask(x: Instances) -> torch.Tensor:
            return x.mask

        class f(torch.nn.Module):
            def forward(self, x: Instances):
                proposal_boxes = x.proposal_boxes  # noqa F841
                objectness_logits = x.objectness_logits  # noqa F841
                return x

        class g(torch.nn.Module):
            def forward(self, x: Instances):
                return get_mask(x)

        class g2(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.g = g()

            def forward(self, x: Instances):
                proposal_boxes = x.proposal_boxes  # noqa F841
                return x, self.g(x)

        fields = {"proposal_boxes": Boxes, "objectness_logits": Tensor}
        with patch_instances(fields):
            torch.jit.script(f())

        # can't script anymore after exiting the context
        with self.assertRaises(Exception):
            # will create a ConcreteType for g
            torch.jit.script(g2())

        new_fields = {"mask": Tensor}
        with patch_instances(new_fields):
            # will compile g with a different Instances; this should pass
            torch.jit.script(g())
            # g2 reads `proposal_boxes`, which this field set does not declare
            with self.assertRaises(Exception):
                torch.jit.script(g2())

        new_fields = {"mask": Tensor, "proposal_boxes": Boxes}
        with patch_instances(new_fields) as NewInstances:
            # get_mask will be compiled with a different Instances; this should pass
            scripted_g2 = torch.jit.script(g2())
            x = NewInstances((3, 4))
            x.mask = torch.rand(3)
            x.proposal_boxes = Boxes(torch.rand(3, 4))
            scripted_g2(x)  # it should accept the new Instances object and run successfully

    def test_script_access_fields(self):
        """Declared fields can be read and combined inside a scripted module."""

        class f(torch.nn.Module):
            def forward(self, x: Instances):
                proposal_boxes = x.proposal_boxes
                objectness_logits = x.objectness_logits
                return proposal_boxes.tensor + objectness_logits

        fields = {"proposal_boxes": Boxes, "objectness_logits": Tensor}
        with patch_instances(fields):
            torch.jit.script(f())

    def test_script_len(self):
        """Scripted ``len()`` raises while no field is set and otherwise counts rows."""

        # f and g are identical; a distinct class is presumably used per patch_instances
        # context so the second script does not reuse types compiled for the first.
        # NOTE(review): confirm this is the reason for the duplication.
        class f(torch.nn.Module):
            def forward(self, x: Instances):
                return len(x)

        class g(torch.nn.Module):
            def forward(self, x: Instances):
                return len(x)

        image_shape = (15, 15)

        fields = {"proposal_boxes": Boxes}
        with patch_instances(fields) as new_instance:
            script_module = torch.jit.script(f())
            x = new_instance(image_shape)
            # No field set yet -> length is undefined and the call must fail.
            with self.assertRaises(Exception):
                script_module(x)
            box_tensors = torch.tensor([[5, 5, 10, 10], [1, 1, 2, 3]])
            x.proposal_boxes = Boxes(box_tensors)
            length = script_module(x)
            self.assertEqual(length, 2)

        fields = {"objectness_logits": Tensor}
        with patch_instances(fields) as new_instance:
            script_module = torch.jit.script(g())
            x = new_instance(image_shape)
            objectness_logits = torch.tensor([1.0]).reshape(1, 1)
            x.objectness_logits = objectness_logits
            length = script_module(x)
            self.assertEqual(length, 1)

    def test_script_has(self):
        """Scripted ``has()`` reports whether a declared field has been set."""

        class f(torch.nn.Module):
            def forward(self, x: Instances):
                return x.has("proposal_boxes")

        image_shape = (15, 15)
        fields = {"proposal_boxes": Boxes}
        with patch_instances(fields) as new_instance:
            script_module = torch.jit.script(f())
            x = new_instance(image_shape)
            self.assertFalse(script_module(x))

            box_tensors = torch.tensor([[5, 5, 10, 10], [1, 1, 2, 3]])
            x.proposal_boxes = Boxes(box_tensors)
            self.assertTrue(script_module(x))

    def test_script_to(self):
        """Scripted ``to(device)`` works with and without fields set and keeps values."""

        class f(torch.nn.Module):
            def forward(self, x: Instances):
                return x.to(torch.device("cpu"))

        image_shape = (15, 15)
        fields = {"proposal_boxes": Boxes, "a": Tensor}
        with patch_instances(fields) as new_instance:
            script_module = torch.jit.script(f())
            x = new_instance(image_shape)
            # Moving an instance with no fields set must succeed and leave them unset.
            out = script_module(x)
            self.assertFalse(out.has("proposal_boxes"))
            self.assertFalse(out.has("a"))

            box_tensors = torch.tensor([[5, 5, 10, 10], [1, 1, 2, 3]])
            x.proposal_boxes = Boxes(box_tensors)
            x.a = box_tensors
            out = script_module(x)
            # Fields that were set must survive the move with unchanged values.
            self.assertTrue(torch.equal(out.proposal_boxes.tensor, x.proposal_boxes.tensor))
            self.assertTrue(torch.equal(out.a, box_tensors))

    def test_script_getitem(self):
        """Boolean-mask indexing gives the same result in eager mode and under TorchScript."""

        class f(torch.nn.Module):
            def forward(self, x: Instances, idx: Tensor):
                return x[idx]

        image_shape = (15, 15)
        fields = {"proposal_boxes": Boxes, "a": Tensor}
        inst = Instances(image_shape)
        inst.proposal_boxes = Boxes(torch.rand(4, 4))
        inst.a = torch.rand(4, 10)
        idx = torch.tensor([True, False, True, False])
        with patch_instances(fields) as new_instance:
            script_module = torch.jit.script(f())

            out = f()(inst, idx)
            out_scripted = script_module(new_instance.from_instances(inst), idx)
            self.assertTrue(
                torch.equal(out.proposal_boxes.tensor, out_scripted.proposal_boxes.tensor)
            )
            self.assertTrue(torch.equal(out.a, out_scripted.a))

    def test_from_to_instances(self):
        """Round-tripping through the patched Instances class preserves field values."""
        orig = Instances((30, 30))
        orig.proposal_boxes = Boxes(torch.rand(3, 4))

        fields = {"proposal_boxes": Boxes, "a": Tensor}
        with patch_instances(fields) as NewInstances:
            # convert to NewInstances and back
            new1 = NewInstances.from_instances(orig)
            new2 = convert_scripted_instances(new1)
        self.assertTrue(torch.equal(orig.proposal_boxes.tensor, new1.proposal_boxes.tensor))
        self.assertTrue(torch.equal(orig.proposal_boxes.tensor, new2.proposal_boxes.tensor))

    def test_script_init_args(self):
        """A scripted Instances constructor accepts fields as keyword arguments."""

        def f(x: Tensor):
            image_shape = (15, 15)
            # __init__ can take arguments
            inst = Instances(image_shape, a=x, proposal_boxes=Boxes(x))
            inst2 = Instances(image_shape, a=x)
            return inst.a, inst2.a

        fields = {"proposal_boxes": Boxes, "a": Tensor}
        with patch_instances(fields):
            script_f = torch.jit.script(f)
            x = torch.randn(3, 4)
            outputs = script_f(x)
            self.assertTrue(torch.equal(outputs[0], x))
            self.assertTrue(torch.equal(outputs[1], x))

    def test_script_cat(self):
        """Scripted ``Instances.cat`` concatenates set fields and leaves unset ones unset."""

        def f(x: Tensor):
            image_shape = (15, 15)
            # __init__ can take arguments
            inst = Instances(image_shape, a=x)
            inst2 = Instances(image_shape, a=x)

            inst3 = Instances(image_shape, proposal_boxes=Boxes(x))
            return inst.cat([inst, inst2]), inst3.cat([inst3, inst3])

        fields = {"proposal_boxes": Boxes, "a": Tensor}
        with patch_instances(fields):
            script_f = torch.jit.script(f)
            x = torch.randn(3, 4)
            output, output2 = script_f(x)
            self.assertTrue(torch.equal(output.a, torch.cat([x, x])))
            self.assertFalse(output.has("proposal_boxes"))
            self.assertTrue(torch.equal(output2.proposal_boxes.tensor, torch.cat([x, x])))


if __name__ == "__main__":
    unittest.main()