Spaces:
Build error
Build error
File size: 7,635 Bytes
1865436 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 |
# Copyright (c) Facebook, Inc. and its affiliates.
import unittest
import torch
from torch import Tensor
from detectron2.export.torchscript import patch_instances
from detectron2.structures import Boxes, Instances
from detectron2.utils.env import TORCH_VERSION
from detectron2.utils.testing import convert_scripted_instances
class TestInstances(unittest.TestCase):
    """Tests for ``Instances`` indexing and for scripting it via ``patch_instances``."""

    def test_int_indexing(self):
        """Integer indices (negative ones too) select the matching row of every field.

        Indices outside ``[-len, len)`` must raise ``IndexError``.
        """
        attr1 = torch.tensor([[0.0, 0.0, 1.0], [0.0, 0.0, 0.5], [0.0, 0.0, 1.0], [0.0, 0.5, 0.5]])
        attr2 = torch.tensor([0.1, 0.2, 0.3, 0.4])
        instances = Instances((100, 100))
        instances.attr1 = attr1
        instances.attr2 = attr2
        for i in range(-len(instances), len(instances)):
            inst = instances[i]
            # Element-wise compare then reduce, so the check does not depend on
            # whether the indexed field keeps a leading dimension of size 1.
            self.assertTrue(bool((inst.attr1 == attr1[i]).all()))
            self.assertTrue(bool((inst.attr2 == attr2[i]).all()))

        with self.assertRaises(IndexError):
            instances[len(instances)]
        with self.assertRaises(IndexError):
            instances[-len(instances) - 1]

    @unittest.skipIf(TORCH_VERSION < (1, 7), "Insufficient pytorch version")
    def test_script_new_fields(self):
        """Scripting succeeds only while ``patch_instances`` declares the fields in use."""

        def get_mask(x: Instances) -> torch.Tensor:
            return x.mask

        class f(torch.nn.Module):
            def forward(self, x: Instances):
                proposal_boxes = x.proposal_boxes  # noqa: F841
                objectness_logits = x.objectness_logits  # noqa: F841
                return x

        class g(torch.nn.Module):
            def forward(self, x: Instances):
                return get_mask(x)

        class g2(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.g = g()

            def forward(self, x: Instances):
                proposal_boxes = x.proposal_boxes  # noqa: F841
                return x, self.g(x)

        fields = {"proposal_boxes": Boxes, "objectness_logits": Tensor}
        with patch_instances(fields):
            torch.jit.script(f())

        # can't script anymore after exiting the context
        with self.assertRaises(Exception):
            # will create a ConcreteType for g
            torch.jit.script(g2())

        new_fields = {"mask": Tensor}
        with patch_instances(new_fields):
            # will compile g with a different Instances; this should pass
            torch.jit.script(g())
            # g2 also reads proposal_boxes, which is not declared here.
            with self.assertRaises(Exception):
                torch.jit.script(g2())

        new_fields = {"mask": Tensor, "proposal_boxes": Boxes}
        with patch_instances(new_fields) as NewInstances:
            # get_mask will be compiled with a different Instances; this should pass
            scripted_g2 = torch.jit.script(g2())
            x = NewInstances((3, 4))
            x.mask = torch.rand(3)
            x.proposal_boxes = Boxes(torch.rand(3, 4))
            scripted_g2(x)  # it should accept the new Instances object and run successfully

    @unittest.skipIf(TORCH_VERSION < (1, 7), "Insufficient pytorch version")
    def test_script_access_fields(self):
        """A module that reads and combines declared fields can be scripted."""

        class f(torch.nn.Module):
            def forward(self, x: Instances):
                proposal_boxes = x.proposal_boxes
                objectness_logits = x.objectness_logits
                return proposal_boxes.tensor + objectness_logits

        fields = {"proposal_boxes": Boxes, "objectness_logits": Tensor}
        with patch_instances(fields):
            torch.jit.script(f())

    @unittest.skipIf(TORCH_VERSION < (1, 7), "Insufficient pytorch version")
    def test_script_len(self):
        """Scripted ``len()`` follows the set fields and raises when none is set."""

        class f(torch.nn.Module):
            def forward(self, x: Instances):
                return len(x)

        class g(torch.nn.Module):
            def forward(self, x: Instances):
                return len(x)

        image_shape = (15, 15)

        fields = {"proposal_boxes": Boxes}
        with patch_instances(fields) as new_instance:
            script_module = torch.jit.script(f())
            x = new_instance(image_shape)
            # No field has been set yet, so there is nothing to take a length of.
            with self.assertRaises(Exception):
                script_module(x)
            box_tensors = torch.tensor([[5, 5, 10, 10], [1, 1, 2, 3]])
            x.proposal_boxes = Boxes(box_tensors)
            length = script_module(x)
            self.assertEqual(length, 2)

        fields = {"objectness_logits": Tensor}
        with patch_instances(fields) as new_instance:
            script_module = torch.jit.script(g())
            x = new_instance(image_shape)
            objectness_logits = torch.tensor([1.0]).reshape(1, 1)
            x.objectness_logits = objectness_logits
            length = script_module(x)
            self.assertEqual(length, 1)

    @unittest.skipIf(TORCH_VERSION < (1, 7), "Insufficient pytorch version")
    def test_script_has(self):
        """Scripted ``has()`` is False before a declared field is set and True after."""

        class f(torch.nn.Module):
            def forward(self, x: Instances):
                return x.has("proposal_boxes")

        image_shape = (15, 15)
        fields = {"proposal_boxes": Boxes}
        with patch_instances(fields) as new_instance:
            script_module = torch.jit.script(f())
            x = new_instance(image_shape)
            self.assertFalse(script_module(x))

            box_tensors = torch.tensor([[5, 5, 10, 10], [1, 1, 2, 3]])
            x.proposal_boxes = Boxes(box_tensors)
            self.assertTrue(script_module(x))

    @unittest.skipIf(TORCH_VERSION < (1, 8), "Insufficient pytorch version")
    def test_script_to(self):
        """Scripted ``to(device)`` runs both with no fields set and with fields set."""

        class f(torch.nn.Module):
            def forward(self, x: Instances):
                return x.to(torch.device("cpu"))

        image_shape = (15, 15)
        fields = {"proposal_boxes": Boxes, "a": Tensor}
        with patch_instances(fields) as new_instance:
            script_module = torch.jit.script(f())
            x = new_instance(image_shape)
            script_module(x)

            box_tensors = torch.tensor([[5, 5, 10, 10], [1, 1, 2, 3]])
            x.proposal_boxes = Boxes(box_tensors)
            x.a = box_tensors
            script_module(x)

    @unittest.skipIf(TORCH_VERSION < (1, 7), "Insufficient pytorch version")
    def test_script_getitem(self):
        """Boolean-mask indexing gives the same fields in scripted and eager mode."""

        class f(torch.nn.Module):
            def forward(self, x: Instances, idx):
                return x[idx]

        image_shape = (15, 15)
        fields = {"proposal_boxes": Boxes, "a": Tensor}
        inst = Instances(image_shape)
        inst.proposal_boxes = Boxes(torch.rand(4, 4))
        inst.a = torch.rand(4, 10)
        idx = torch.tensor([True, False, True, False])
        with patch_instances(fields) as new_instance:
            script_module = torch.jit.script(f())

            out = f()(inst, idx)
            out_scripted = script_module(new_instance.from_instances(inst), idx)
            self.assertTrue(
                torch.equal(out.proposal_boxes.tensor, out_scripted.proposal_boxes.tensor)
            )
            self.assertTrue(torch.equal(out.a, out_scripted.a))

    @unittest.skipIf(TORCH_VERSION < (1, 7), "Insufficient pytorch version")
    def test_from_to_instances(self):
        """Converting to the patched Instances class and back preserves field values."""
        orig = Instances((30, 30))
        orig.proposal_boxes = Boxes(torch.rand(3, 4))

        fields = {"proposal_boxes": Boxes, "a": Tensor}
        with patch_instances(fields) as NewInstances:
            # convert to NewInstances and back
            new1 = NewInstances.from_instances(orig)
            new2 = convert_scripted_instances(new1)
        self.assertTrue(torch.equal(orig.proposal_boxes.tensor, new1.proposal_boxes.tensor))
        self.assertTrue(torch.equal(orig.proposal_boxes.tensor, new2.proposal_boxes.tensor))
if __name__ == "__main__":
    # Run this module's tests when it is executed directly.
    unittest.main()
|