File size: 5,694 Bytes
6f6830f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
# Copyright (c) Facebook, Inc. and its affiliates.
import logging
import unittest
import torch

from detectron2.modeling.poolers import ROIPooler
from detectron2.structures import Boxes, RotatedBoxes
from detectron2.utils.testing import random_boxes

# Module-level logger named after this module. It is not referenced anywhere
# in the visible part of this file.
logger = logging.getLogger(__name__)


class TestROIPooler(unittest.TestCase):
    """Tests for ``ROIPooler``.

    Covers agreement between the "ROIAlignV2" and "ROIAlignRotated" pooler
    types on equivalent boxes, TorchScript scripting, tracing with a varying
    number of boxes, and the case of an empty batch.
    """

    # Pooler configuration shared by the tests below.
    _POOLER_RESOLUTION = 14
    _CANONICAL_LEVEL = 4
    _CANONICAL_SCALE_FACTOR = 2 ** _CANONICAL_LEVEL
    _SAMPLING_RATIO = 0

    def setUp(self):
        # Every test draws its inputs from torch's global RNG (``torch.rand``
        # here and, presumably, inside ``random_boxes``). Seed it so that a
        # failure, in particular of the tolerance-based comparison, can be
        # reproduced.
        torch.manual_seed(0)

    @staticmethod
    def _random_feature(N, C, H, W, std=11, mean=0):
        """Return an (N, C, H, W) tensor drawn uniformly from [mean - std, mean + std)."""
        return (torch.rand(N, C, H, W) - 0.5) * 2 * std + mean

    def _make_pooler(self, pooler_type, scales=None):
        """Build a ``ROIPooler`` of the given type with the shared test settings.

        Args:
            pooler_type: value forwarded as ``ROIPooler(pooler_type=...)``.
            scales: per-level scales; defaults to a single level at
                ``1 / _CANONICAL_SCALE_FACTOR``.
        """
        if scales is None:
            scales = (1.0 / self._CANONICAL_SCALE_FACTOR,)
        return ROIPooler(
            output_size=self._POOLER_RESOLUTION,
            scales=scales,
            sampling_ratio=self._SAMPLING_RATIO,
            pooler_type=pooler_type,
        )

    def _test_roialignv2_roialignrotated_match(self, device):
        """Check that "ROIAlignV2" and "ROIAlignRotated" agree on the same boxes.

        The rotated boxes are the same rectangles as the axis-aligned ones,
        expressed in 5-column form with the last column left at zero.
        """
        N, C, H, W = 2, 4, 10, 8
        N_rois = 10
        features = [self._random_feature(N, C, H, W).to(device)]

        rois = []
        rois_rotated = []
        for _ in range(N):
            boxes = random_boxes(N_rois, W * self._CANONICAL_SCALE_FACTOR)
            # Convert the 4 corner columns into (center x, center y, width,
            # height); the fifth column stays at the zero it was created with.
            rotated_boxes = torch.zeros(N_rois, 5)
            rotated_boxes[:, 0] = (boxes[:, 0] + boxes[:, 2]) / 2.0
            rotated_boxes[:, 1] = (boxes[:, 1] + boxes[:, 3]) / 2.0
            rotated_boxes[:, 2] = boxes[:, 2] - boxes[:, 0]
            rotated_boxes[:, 3] = boxes[:, 3] - boxes[:, 1]
            rois.append(Boxes(boxes).to(device))
            rois_rotated.append(RotatedBoxes(rotated_boxes).to(device))

        roialignv2_out = self._make_pooler("ROIAlignV2")(features, rois)
        roialignrotated_out = self._make_pooler("ROIAlignRotated")(features, rois_rotated)

        self.assertTrue(torch.allclose(roialignv2_out, roialignrotated_out, atol=1e-4))

    def test_roialignv2_roialignrotated_match_cpu(self):
        self._test_roialignv2_roialignrotated_match(device="cpu")

    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
    def test_roialignv2_roialignrotated_match_cuda(self):
        self._test_roialignv2_roialignrotated_match(device="cuda")

    def _test_scriptability(self, device):
        """Check that a scripted "ROIAlignV2" pooler reproduces the eager output exactly."""
        N, C, H, W = 2, 4, 10, 8
        N_rois = 10
        features = [self._random_feature(N, C, H, W).to(device)]

        rois = [
            Boxes(random_boxes(N_rois, W * self._CANONICAL_SCALE_FACTOR)).to(device)
            for _ in range(N)
        ]

        roialignv2_pooler = self._make_pooler("ROIAlignV2")

        roialignv2_out = roialignv2_pooler(features, rois)
        scripted_roialignv2_out = torch.jit.script(roialignv2_pooler)(features, rois)
        self.assertTrue(torch.equal(roialignv2_out, scripted_roialignv2_out))

    def test_scriptability_cpu(self):
        self._test_scriptability(device="cpu")

    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
    def test_scriptability_gpu(self):
        self._test_scriptability(device="cuda")

    def test_no_images(self):
        """An empty batch with an empty box list yields an empty pooled output."""
        N, C, H, W = 0, 32, 32, 32
        feature = torch.rand(N, C, H, W) - 0.5
        features = [feature]
        pooler = self._make_pooler("ROIAlignV2", scales=(1.0,))
        # Call the module itself rather than ``.forward`` so the regular
        # ``nn.Module`` call path is the one being exercised.
        output = pooler(features, [])
        self.assertEqual(
            output.shape, (0, C, self._POOLER_RESOLUTION, self._POOLER_RESOLUTION)
        )

    def test_roi_pooler_tracing(self):
        """Trace a two-level "ROIAlign" pooler and run it with different box counts."""

        class Model(torch.nn.Module):
            # Thin wrapper so that the traced function takes plain tensors.
            def __init__(self, roi):
                super().__init__()
                self.roi = roi

            def forward(self, x, boxes):
                return self.roi(x, [Boxes(boxes)])

        pooler_scales = (
            1.0 / self._CANONICAL_SCALE_FACTOR,
            0.5 / self._CANONICAL_SCALE_FACTOR,
        )

        N, C, H, W = 1, 4, 10, 8
        N_rois = 10
        feature = self._random_feature(N, C, H, W)
        # The same tensor is supplied for both levels.
        feature = [feature, feature]

        rois = random_boxes(N_rois, W * self._CANONICAL_SCALE_FACTOR)
        # Add one larger box so that this level has only one box.
        # This may trigger the bug https://github.com/pytorch/pytorch/issues/49852
        # that we shall workaround.
        # The extra box is created with the dtype of ``rois`` so that the
        # concatenation does not depend on implicit dtype promotion.
        rois = torch.cat([rois, torch.tensor([[0, 0, 448, 448]], dtype=rois.dtype)])

        model = Model(self._make_pooler("ROIAlign", scales=pooler_scales))

        res = self._POOLER_RESOLUTION
        with torch.no_grad():
            func = torch.jit.trace(model, (feature, rois))
            # Same inputs as used for tracing: N_rois random boxes plus the large one.
            o = func(feature, rois)
            self.assertEqual(o.shape, (N_rois + 1, C, res, res))
            # Fewer boxes than at trace time.
            o = func(feature, rois[:5])
            self.assertEqual(o.shape, (5, C, res, res))
            # More boxes than at trace time.
            o = func(feature, random_boxes(20, W * self._CANONICAL_SCALE_FACTOR))
            self.assertEqual(o.shape, (20, C, res, res))


if __name__ == "__main__":
    # Run this module's tests when it is executed directly as a script.
    unittest.main()