| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| """Unit tests for functions in box_utils.py.""" |
|
|
| from absl.testing import absltest |
| from absl.testing import parameterized |
| import jax |
| import jax.numpy as jnp |
| import numpy as np |
| from scenic.model_lib.base_models import box_utils |
| from shapely import geometry |
|
|
|
|
def sample_cxcywh_bbox(key, batch_shape):
  """Draws random [cx, cy, w, h] boxes that fit inside the unit square.

  Centers and sizes are first drawn uniformly from [0, 0.8). Each size is then
  shrunk, per axis, so that the box does not cross 1 (far edge) and then does
  not cross 0 (near edge).

  Args:
    key: JAX PRNG key used for sampling.
    batch_shape: Leading shape of the output; a trailing axis of 4 is added.

  Returns:
    Array of shape (*batch_shape, 4) holding [cx, cy, w, h] per box.
  """
  limit = 0.8
  uniform = jax.random.uniform(key, shape=(*batch_shape, 4)) * limit

  def fit_extent(center, extent):
    # Far edge first: a box reaching past 1 gets size 1.6 * (1 - center).
    extent = jnp.where(center + extent / 2. >= 1.,
                       limit * 2. * (1. - center), extent)
    # Then the near edge: a box reaching below 0 gets size 1.6 * center.
    return jnp.where(center - extent / 2. <= 0., limit * 2. * center, extent)

  # The x and y axes are independent, so [cx, cy] / [w, h] are fitted together.
  centers = uniform[..., :2]
  extents = fit_extent(centers, uniform[..., 2:])
  return jnp.concatenate([centers, extents], axis=-1)
|
|
|
|
class BoxUtilsTest(parameterized.TestCase):
  """Tests all the bounding box related utilities."""

  def test_box_cxcywh_to_xyxy(self):
    """Test for correctness of the box_cxcywh_to_xyxy operation."""
    cxcywh = jnp.array([[[0.1, 0.3, 0.2, 0.4], [0.1, 0.3, 0.2, 0.4]],
                        [[0.3, 0.2, 0.1, 0.4], [0.3, 0.2, 0.1, 0.4]]],
                       dtype=jnp.float32)
    # Hand-computed as [cx - w/2, cy - h/2, cx + w/2, cy + h/2].
    expected = jnp.array([[[0.0, 0.1, 0.2, 0.5], [0.0, 0.1, 0.2, 0.5]],
                          [[0.25, 0.0, 0.35, 0.4], [0.25, 0.0, 0.35, 0.4]]],
                         dtype=jnp.float32)
    output = box_utils.box_cxcywh_to_xyxy(cxcywh)
    self.assertSequenceAlmostEqual(
        expected.flatten(), output.flatten(), places=5)

    # A trailing dimension other than 4 is expected to raise ValueError.
    # The input is built outside the assertRaises context so that only the
    # conversion call is under test; the RNG is seeded for reproducibility.
    bad_cxcywh = jnp.array(np.random.default_rng(0).uniform(size=(2, 3, 5)))
    with self.assertRaises(ValueError):
      _ = box_utils.box_cxcywh_to_xyxy(bad_cxcywh)

  @parameterized.parameters([((3, 1, 4),), ((4, 6, 4),)])
  def test_box_cxcywh_to_xyxy_shape(self, input_shape):
    """Test whether the shape is correct for box_cxcywh_to_xyxy."""
    # Seeded so that any failure is reproducible across runs.
    rng = np.random.default_rng(0)
    cxcywh = jnp.array(rng.uniform(size=input_shape))
    xyxy = box_utils.box_cxcywh_to_xyxy(cxcywh)
    self.assertEqual(xyxy.shape, cxcywh.shape)

  @parameterized.parameters([((2, 5, 4),), ((1, 3, 4),)])
  def test_box_cxcy_to_xyxy_box_xyxy_to_cxcy(self, input_shape):
    """Test both box conversion functions as they are inverses of each other."""
    # Seeded so that any failure is reproducible across runs.
    rng = np.random.default_rng(0)
    cxcywh = jnp.array(rng.uniform(size=input_shape))
    xyxy = box_utils.box_cxcywh_to_xyxy(cxcywh)
    cxcywh_loop = box_utils.box_xyxy_to_cxcywh(xyxy)
    self.assertSequenceAlmostEqual(
        cxcywh_loop.flatten(), cxcywh.flatten(), places=5)
|
|
|
|
def sample_cxcywha(key, batch_shape):
  """Draws random rotated boxes laid out as [cx, cy, w, h, angle].

  Each field is uniform on [low, low + span): centers on [0.35, 0.65), sizes on
  [0, 0.5), and the angle (radians) on [0, 1).

  Args:
    key: JAX PRNG key used for sampling.
    batch_shape: Leading shape of the output; a trailing axis of 5 is added.

  Returns:
    Array of shape (*batch_shape, 5).
  """
  lows = jnp.array([0.35, 0.35, 0, 0, 0])
  spans = jnp.array([0.3, 0.3, 0.5, 0.5, 1.0])
  unit = jax.random.uniform(key, shape=tuple(batch_shape) + (5,))
  return unit * spans + lows
|
|
|
|
class RBoxUtilsTest(parameterized.TestCase):
  """Tests all the rotated bounding box related utilities."""

  def test_convert_cxcywha_to_corners(self):
    """Checks corner shapes and that sampled boxes have corners in [0, 1]."""
    key = jax.random.PRNGKey(0)
    cxcywha = sample_cxcywha(key, batch_shape=(300, 200))
    self.assertEqual(cxcywha.shape, (300, 200, 5))

    corners = box_utils.cxcywha_to_corners(cxcywha)
    self.assertEqual(corners.shape, (300, 200, 4, 2))
    self.assertTrue(jnp.all(corners >= 0))
    self.assertTrue(jnp.all(corners <= 1))

  def test_convert_corners_to_cxcywha(self):
    """Checks that corners_to_cxcywha inverts cxcywha_to_corners."""
    key = jax.random.PRNGKey(0)
    cxcywha = sample_cxcywha(key, batch_shape=(3, 2))
    self.assertEqual(cxcywha.shape, (3, 2, 5))

    corners = box_utils.cxcywha_to_corners(cxcywha)
    cxcywha2 = box_utils.corners_to_cxcywha(corners)
    np.testing.assert_allclose(cxcywha2, cxcywha, atol=1e-6)

  def test_convert_cxcywha_to_corners_single_rotated(self):
    """Checks corners of a sqrt(2)-sided square rotated by 45 degrees."""
    cxcywha = jnp.array([1, 1, jnp.sqrt(2), jnp.sqrt(2), 45. * jnp.pi / 180.])
    corners = box_utils.cxcywha_to_corners(cxcywha)
    expected_corners = [[1, 0], [2, 1], [1, 2], [0, 1]]
    np.testing.assert_allclose(corners, expected_corners, atol=1e-7)

  def test_intersect_line_segments(self):
    """Compares intersect_line_segments against shapely on random segments."""
    key = jax.random.PRNGKey(0)
    key, subkey = jax.random.split(key)
    lines1 = jax.random.uniform(subkey, (100, 2, 2))
    lines2 = jax.random.uniform(key, (100, 2, 2))
    intersect_line_segments = jax.jit(
        jax.vmap(box_utils.intersect_line_segments))
    intersections = intersect_line_segments(lines1, lines2)
    self.assertEqual(intersections.shape, (100, 2))

    # Shapely reference: the crossing point when the intersection is a single
    # point, NaN otherwise. Segments are pulled to host once, not per index.
    expected_intersections = []
    for seg1, seg2 in zip(np.asarray(lines1), np.asarray(lines2)):
      it = geometry.LineString(seg1).intersection(geometry.LineString(seg2))
      if isinstance(it, geometry.Point):
        expected_intersections.append(it.coords[0])
      else:
        expected_intersections.append((np.nan, np.nan))

    # assert_allclose treats NaN == NaN as equal by default.
    np.testing.assert_allclose(intersections, expected_intersections, atol=1e-7)

  def test_intersect_rbox_edges_same_box(self):
    """Test for correctness of the intersect_rbox_edges operation."""
    rbox1 = jnp.array([0.5, 0.5, 1.0, 1.0, 0])
    rbox2 = rbox1
    corners1 = box_utils.cxcywha_to_corners(rbox1)
    corners2 = box_utils.cxcywha_to_corners(rbox2)
    it_points = box_utils.intersect_rbox_edges(corners1, corners2)
    self.assertEqual(it_points.shape, (4, 4, 2))
    # Drop the NaN (non-intersecting) entries before comparing.
    it_points = it_points[~jnp.any(jnp.isnan(it_points), -1)]
    it_points = sorted([(x, y) for x, y in np.array(it_points)])
    # Each corner is expected twice in the identical-box case.
    expected_points = sorted([(0, 0), (0, 1), (1, 0), (1, 1)] * 2)
    self.assertSequenceEqual(it_points, expected_points)

  def test_intersect_rbox_edges_rotated_box(self):
    """Tests a box inscribed in a second box rotated by 45 degrees."""
    rbox1 = jnp.array([1.0, 1.0, 1.0, 1.0, 0])
    rbox2 = jnp.array([1.0, 1.0, jnp.sqrt(2), jnp.sqrt(2), 45. * jnp.pi / 180.])
    corners1 = box_utils.cxcywha_to_corners(rbox1)
    corners2 = box_utils.cxcywha_to_corners(rbox2)
    it_points = box_utils.intersect_rbox_edges(corners1, corners2)
    # Drop NaN entries and round away float32 noise before exact comparison.
    it_points = jnp.round(
        it_points[~jnp.any(jnp.isnan(it_points), -1)], decimals=4)
    it_points = sorted([(x, y) for x, y in np.array(it_points)])
    # The inner box's corners lie on the rotated box's edges; each is expected
    # twice.
    expected_pts = sorted([(1.5, 1.5), (1.5, 0.5), (0.5, 0.5), (0.5, 1.5)] * 2)
    self.assertSequenceEqual(it_points, expected_pts)
|
|
|
|
class IoUTest(parameterized.TestCase):
  """Test box_iou and generalized_box_iou functions."""

  def test_box_iou_values(self):
    """Tests if 0 <= IoU <= 1, union >= 0 and -1 <= gIoU <= 1."""
    key = jax.random.PRNGKey(0)
    key, subkey = jax.random.split(key)
    pred_bbox = sample_cxcywh_bbox(key, batch_shape=(4, 100))
    tgt_bbox = sample_cxcywh_bbox(subkey, batch_shape=(4, 63))

    pred_bbox = box_utils.box_cxcywh_to_xyxy(pred_bbox)
    tgt_bbox = box_utils.box_cxcywh_to_xyxy(tgt_bbox)

    iou, union = box_utils.box_iou(pred_bbox, tgt_bbox, all_pairs=True)
    self.assertTrue(jnp.all(iou >= 0))
    self.assertTrue(jnp.all(iou <= 1.))
    self.assertTrue(jnp.all(union >= 0.))

    giou = box_utils.generalized_box_iou(pred_bbox, tgt_bbox, all_pairs=True)
    self.assertTrue(jnp.all(giou >= -1.))
    self.assertTrue(jnp.all(giou <= 1.))

  def test_box_iou(self):
    """Test box_iou using hand designed targets."""
    in1 = jnp.array([
        [[0.1, 0.2, 0.3, 0.4], [0.1, 0.2, 0.5, 1.0], [0.1, 0.2, 0.5, 0.8]],
        [[0.6, 0.2, 1.0, 1.0], [0.6, 0.2, 1.0, 0.8], [0.0, 0.0, 0.0, 0.0]],
        [[0.0, 0.0, 0.0, 0.0], [0.2, 0.1, 0.2, 0.1], [0.1, 0.1, 0.2, 0.2]],
    ],
                    dtype=jnp.float32)
    # in2[1, 1] has x0 > x1 and y0 > y1 (an invalid box); row 2 covers
    # zero-area boxes and an identical pair.
    in2 = jnp.array([
        [[0.4, 0.4, 0.5, 0.8], [0.4, 0.4, 0.5, 0.8], [0.4, 0.4, 0.7, 0.8]],
        [[0.7, 0.4, 0.8, 0.6], [0.8, 0.6, 0.7, 0.4], [0.1, 0.1, 0.2, 0.2]],
        [[0.0, 0.0, 0.0, 0.0], [0.2, 0.1, 0.2, 0.1], [0.1, 0.1, 0.2, 0.2]],
    ],
                    dtype=jnp.float32)

    # With all_pairs=False, target[b, i] is the IoU of in1[b, i], in2[b, i].
    target = jnp.array(
        [[0.0, 0.125, 0.125], [0.0625, 0.0, 0.0], [0.0, 0.0, 1.0]],
        dtype=jnp.float32)

    output, _ = box_utils.box_iou(in1, in2, all_pairs=False)

    self.assertSequenceAlmostEqual(output.flatten(), target.flatten(), places=3)

  @classmethod
  def _get_method_fn(cls, method):
    """Returns method_fn function corresponding to method str."""
    if method == 'iou':
      # box_iou returns (iou, union); only the IoU is compared here.
      method_fn = lambda x, y, **kwargs: box_utils.box_iou(x, y, **kwargs)[0]
    elif method == 'giou':
      method_fn = box_utils.generalized_box_iou
    else:
      raise ValueError(f'Unknown method {method}')
    return method_fn

  @parameterized.parameters('iou', 'giou')
  def test_all_pairs_true_false(self, method):
    """Use *box_iou(..., all_pairs=False) to test the all_pairs=True case."""
    method_fn = self._get_method_fn(method)

    in1 = jnp.array(
        [
            [[0.1, 0.2, 0.3, 0.4], [0.1, 0.2, 0.5, 1.0]],
            [[0.6, 0.2, 1.0, 1.0], [0.6, 0.2, 1.0, 0.8]],
        ],
        dtype=jnp.float32)
    in2 = jnp.array([
        [[0.4, 0.4, 0.5, 0.8], [0.4, 0.4, 0.5, 0.8]],
        [[0.7, 0.4, 0.8, 0.6], [0.1, 0.5, 0.7, 0.7]],
    ],
                    dtype=jnp.float32)

    # in1_1 is in1 with the two boxes of each batch swapped, so its
    # element-wise results are the off-diagonal entries of the all-pairs
    # matrix.
    in1_1 = jnp.array(
        [
            [[0.1, 0.2, 0.5, 1.0], [0.1, 0.2, 0.3, 0.4]],
            [[0.6, 0.2, 1.0, 0.8], [0.6, 0.2, 1.0, 1.0]],
        ],
        dtype=jnp.float32)

    out = method_fn(in1, in2, all_pairs=False)
    out_1 = method_fn(in1_1, in2, all_pairs=False)

    out_all = method_fn(in1, in2, all_pairs=True)

    # Expected out_all[b, i, j] = method(in1[b, i], in2[b, j]):
    # diagonal from `out`; off-diagonal from the swapped `out_1`.
    out_all_ = jnp.array([[[out[0, 0], out_1[0, 1]], [out_1[0, 0], out[0, 1]]],
                          [[out[1, 0], out_1[1, 1]], [out_1[1, 0], out[1, 1]]]],
                         dtype=jnp.float32)

    self.assertSequenceAlmostEqual(out_all.flatten(), out_all_.flatten())

  def test_generalized_box_iou(self):
    """Same as test_box_iou but for generalized_box_iou()."""
    in1 = jnp.array([
        [[0.1, 0.2, 0.3, 0.4], [0.1, 0.2, 0.5, 1.0], [0.1, 0.2, 0.5, 0.8]],
        [[0.6, 0.2, 1.0, 1.0], [0.6, 0.2, 1.0, 0.8], [0.0, 0.0, 0.0, 0.0]],
        [[0.0, 0.0, 0.0, 0.0], [0.2, 0.1, 0.2, 0.1], [0.1, 0.1, 0.2, 0.2]],
    ],
                    dtype=jnp.float32)
    in2 = jnp.array([
        [[0.4, 0.4, 0.5, 0.8], [0.4, 0.4, 0.5, 0.8], [0.4, 0.4, 0.7, 0.8]],
        [[0.7, 0.4, 0.8, 0.6], [0.4, 0.4, 0.8, 0.6], [0.1, 0.1, 0.2, 0.2]],
        [[0.0, 0.0, 0.0, 0.0], [0.2, 0.1, 0.2, 0.1], [0.1, 0.1, 0.2, 0.2]],
    ],
                    dtype=jnp.float32)

    # Hand-computed per pair: target_iou is the IoU; target_extra is the
    # enclosing-box term -(|C| - |union|) / |C| (0 when one box contains the
    # other, or for zero-area rows).
    target_iou = jnp.array(
        [[0.0, 0.125, 0.125], [0.0625, 1. / 7., 0.0], [0.0, 0.0, 1.0]],
        dtype=jnp.float32)
    target_extra = jnp.array(
        [[-2. / 3., 0.0, -1. / 9.], [0.0, -2. / 9., -3. / 4.], [0.0, 0.0, 0.0]],
        dtype=jnp.float32)
    target = target_iou + target_extra

    output = box_utils.generalized_box_iou(in1, in2, all_pairs=False)

    self.assertSequenceAlmostEqual(output.flatten(), target.flatten(), places=3)

  @parameterized.parameters('iou', 'giou')
  def test_backward(self, method):
    """Test whether *box_iou methods have a well-defined (finite) grad."""
    method_fn = self._get_method_fn(method)

    def loss_fn(x, y, all_pairs):
      return method_fn(x, y, all_pairs=all_pairs).sum()

    # Differentiates w.r.t. the first argument (in1) only.
    grad_fn = jax.grad(loss_fn)

    # All boxes here have positive area, so the gradient should be finite.
    in1 = jnp.array(
        [
            [[0.1, 0.2, 0.3, 0.4], [0.1, 0.2, 0.5, 1.0]],
            [[0.6, 0.2, 1.0, 1.0], [0.6, 0.2, 1.0, 0.8]],
        ],
        dtype=jnp.float32)
    in2 = jnp.array([
        [[0.4, 0.4, 0.5, 0.8], [0.4, 0.4, 0.5, 0.8]],
        [[0.7, 0.4, 0.8, 0.6], [0.1, 0.5, 0.7, 0.7]],
    ],
                    dtype=jnp.float32)

    # A shape check alone would also pass for NaN/inf gradients, so
    # finiteness is asserted as well.
    grad_in1 = grad_fn(in1, in2, all_pairs=True)
    self.assertSequenceEqual(grad_in1.shape, in1.shape)
    self.assertTrue(jnp.all(jnp.isfinite(grad_in1)))

    grad_in1 = grad_fn(in1, in2, all_pairs=False)
    self.assertSequenceEqual(grad_in1.shape, in1.shape)
    self.assertTrue(jnp.all(jnp.isfinite(grad_in1)))
|
|
|
|
# Entry point when the file is executed directly: hand control to absl's
# test runner.
if __name__ == '__main__':
  absltest.main()
|
|