| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| """Unit tests for functions in model_utils.py.""" |
| import itertools |
|
|
| from absl.testing import absltest |
| from absl.testing import parameterized |
| from flax.training import common_utils |
| import jax |
| import jax.numpy as jnp |
| import numpy as np |
| from scenic.model_lib.base_models import model_utils |
|
|
|
|
class SimpleGatherTest(parameterized.TestCase):
  """Test simple_gather()."""

  def test_simple_gather_ndarray(self):
    """Test against manually specified target when idx is a nd-array."""
    inputs = jnp.array(np.random.normal(size=(2, 3, 5)), dtype=jnp.float32)
    indices = jnp.array([[1, 0, 2], [2, 1, 0]], dtype=jnp.int32)

    gathered = model_utils.simple_gather(inputs, indices)

    # Build the expected output row by row: each row of `indices` permutes
    # the corresponding row of `inputs` along axis 1.
    expected = jnp.stack([
        jnp.stack([inputs[0, 1], inputs[0, 0], inputs[0, 2]]),
        jnp.stack([inputs[1, 2], inputs[1, 1], inputs[1, 0]]),
    ])
    self.assertSequenceAlmostEqual(gathered.flatten(), expected.flatten())
|
|
|
|
class LossTest(parameterized.TestCase):
  """Test various loss functions in model_utils."""

  def test_weighted_l1_loss(self):
    """Test weighted_l1_loss against a manually specified target."""
    preds = jnp.array([[0.1, 0.3], [-1.0, 0.2]], dtype=jnp.float32)
    targets = jnp.array([[0.5, -1.3], [0.9, 1.2]], dtype=jnp.float32)

    # Without reduction, the loss is the element-wise absolute difference.
    elementwise = model_utils.weighted_l1_loss(preds, targets, reduction=None)
    elementwise_expected = jnp.array(
        [[0.4, 1.6], [1.9, 1.0]], dtype=jnp.float32)
    self.assertSequenceAlmostEqual(
        elementwise.flatten(), elementwise_expected.flatten(), places=5)

    # 'mean' reduction averages over all four entries (sum is 4.9).
    mean_loss = model_utils.weighted_l1_loss(
        preds, targets, reduction='mean').item()
    self.assertAlmostEqual(mean_loss, 4.9 / 4, places=5)

  def test_weighted_box_l1_loss(self):
    """Test weighted_box_l1_loss against manually specified targets."""
    pred_boxes = jnp.array([[0.1, 0.3, 0.9, 0.8]], dtype=jnp.float32)
    target_boxes = jnp.array([[0.5, 0.1, 0.9, 0.7]], dtype=jnp.float32)

    # Default (tight) mode: plain per-coordinate absolute difference.
    per_coord = model_utils.weighted_box_l1_loss(pred_boxes, target_boxes)
    per_coord_expected = jnp.array([[0.4, 0.2, 0, 0.1]], dtype=jnp.float32)
    self.assertSequenceAlmostEqual(
        per_coord.flatten(), per_coord_expected.flatten(), places=5)

    # 'mean' reduction must equal the mean of the unreduced loss.
    mean_loss = model_utils.weighted_box_l1_loss(
        pred_boxes, target_boxes, reduction='mean').item()
    self.assertAlmostEqual(
        mean_loss, jnp.mean(per_coord_expected).item(), places=5)

    # With tight=False the expected loss differs only in the second
    # coordinate, which becomes zero for this input.
    loose = model_utils.weighted_box_l1_loss(
        pred_boxes, target_boxes, tight=False)
    loose_expected = jnp.array([[0.4, 0.0, 0.0, 0.1]], dtype=jnp.float32)
    self.assertSequenceAlmostEqual(
        loose.flatten(), loose_expected.flatten(), places=5)

  def test_weighted_sigmoid_cross_entropy(self):
    """Tests weighted_sigmoid_cross_entropy."""
    logits = jnp.array([[1, 2, 3], [4, 5, 6]], dtype=jnp.float32)
    labels = jnp.array([[0, 1, 1], [1, 0, 1]], dtype=jnp.float32)
    sigmoid = jax.nn.sigmoid
    log = jnp.log

    # Case 1: no weights; normalized by the number of examples.
    loss = model_utils.weighted_sigmoid_cross_entropy(logits, labels)
    expected = jnp.array([[
        -log(1 - sigmoid(1.)), -log(sigmoid(2.)), -log(sigmoid(3.))
    ], [-log(sigmoid(4.)), -log(1 - sigmoid(5.)), -log(sigmoid(6.))]
                         ]) / np.prod(labels.shape[:-1])
    self.assertSequenceAlmostEqual(
        loss.flatten(), expected.sum().flatten(), places=3)

    # Case 2: per-example weights; the second example is masked out and the
    # normalizer becomes the weight sum.
    example_weights = jnp.array([1., 0.])
    loss = model_utils.weighted_sigmoid_cross_entropy(
        logits, labels, weights=example_weights)
    expected = jnp.array([[
        -log(1 - sigmoid(1.)), -log(sigmoid(2.)), -log(sigmoid(3.))
    ], [0., 0., 0.]]) / example_weights.sum() + 1e-9
    self.assertSequenceAlmostEqual(
        loss.flatten(), expected.sum().flatten(), places=3)

    # Case 3: per-label weights broadcast across examples.
    label_weights = jnp.array([1., 2., 3.])
    loss = model_utils.weighted_sigmoid_cross_entropy(
        logits, labels, label_weights=label_weights)
    expected = jnp.array([[
        -log(1 - sigmoid(1.)), -2 * log(sigmoid(2.)), -3 * log(sigmoid(3.))
    ], [-log(sigmoid(4.)), -2 * log(1 - sigmoid(5.)), -3 * log(sigmoid(6.))]
                         ]) / np.prod(labels.shape[:-1])
    self.assertSequenceAlmostEqual(
        loss.flatten(), expected.sum().flatten(), places=3)

    # Case 4: example weights and label weights combined.
    loss = model_utils.weighted_sigmoid_cross_entropy(
        logits, labels, weights=example_weights, label_weights=label_weights)
    expected = jnp.array([[
        -log(1 - sigmoid(1.)), -2 * log(sigmoid(2.)), -3 * log(sigmoid(3.))
    ], [0., 0., 0.]]) / example_weights.sum() + 1e-9
    self.assertSequenceAlmostEqual(
        loss.flatten(), expected.sum().flatten(), places=3)

    # Case 5: a full (example, label) weight matrix; the second row is
    # still masked out by the example weights.
    label_weights = jnp.array([[1., 2., 3.], [4., 5., 6.]])
    loss = model_utils.weighted_sigmoid_cross_entropy(
        logits, labels, weights=example_weights, label_weights=label_weights)
    expected = jnp.array([[
        -log(1 - sigmoid(1.)), -2 * log(sigmoid(2.)), -3 * log(sigmoid(3.))
    ], [0., 0., 0.]]) / example_weights.sum() + 1e-9
    self.assertSequenceAlmostEqual(
        loss.flatten(), expected.sum().flatten(), places=3)

    # Mismatched label_weights shape must be rejected.
    with self.assertRaises(ValueError):
      label_weights = jnp.array([1., 2., 3., 4.])
      loss = model_utils.weighted_sigmoid_cross_entropy(
          logits, labels, label_weights=label_weights)

  def test_focal_sigmoid_cross_entropy(self):
    """Tests focal_sigmoid_cross_entropy."""
    logits = jnp.array([[1, 2, 3], [4, 5, 6]], dtype=jnp.float32)
    labels = jnp.array([[0, 1, 1], [1, 0, 1]], dtype=jnp.float32)
    sigmoid = jax.nn.sigmoid
    log = jnp.log

    alpha = 0.25
    gamma = 2.
    loss = model_utils.focal_sigmoid_cross_entropy(
        logits, labels, alpha=alpha, gamma=gamma)

    # Plain per-element sigmoid cross entropy ...
    base_ce = jnp.array(
        [[-log(1 - sigmoid(1.)), -log(sigmoid(2.)), -log(sigmoid(3.))],
         [-log(sigmoid(4.)), -log(1 - sigmoid(5.)), -log(sigmoid(6.))]])
    # ... scaled by the focal modulating factor alpha_t * (1 - p_t)**gamma.
    modulator = jnp.array([[
        (1 - alpha) * sigmoid(1.)**gamma, alpha * sigmoid(-2.)**gamma,
        alpha * sigmoid(-3.)**gamma
    ], [
        alpha * sigmoid(-4.)**gamma, (1 - alpha) * sigmoid(5.)**gamma,
        alpha * sigmoid(-6.)**gamma
    ]])
    self.assertSequenceAlmostEqual(
        loss.flatten(), (base_ce * modulator).flatten(), places=3)

  def test_dice_loss(self):
    """Tests the correctness of the segmentation dice loss."""
    batch, num_objects, h, w = 1, 2, 128, 128
    stride = 2
    targets = np.zeros((batch, num_objects, h, w), dtype=np.float32)
    targets[0, 0, :64, :64] = 1.0
    targets[0, 1, 64:, 64:] = 1.0
    input_shape = batch, num_objects, h // stride, w // stride

    # Perfect predictions (at stride-2 resolution): zero loss.
    inputs = np.zeros(input_shape, dtype=np.float32)
    inputs[0, 0, :64 // stride, :64 // stride] = 1.0
    inputs[0, 1, 64 // stride:, 64 // stride:] = 1.0
    inputs = (inputs - 0.5) * 1e6  # Map {0, 1} to large-magnitude logits.
    loss = model_utils.dice_loss(
        jnp.array(inputs), jnp.array(targets), interpolation='nearest')
    np.testing.assert_array_almost_equal(loss, [[0.0, 0.0]], decimal=3)

    # First mask shifted by 32 rows: half-overlap with its target -> 0.5.
    inputs = np.zeros(input_shape, dtype=np.float32)
    inputs[0, 0, 32 // stride:(32 + 64) // stride, :64 // stride] = 1.0
    inputs[0, 1, 64 // stride:, 64 // stride:] = 1.0
    inputs = (inputs - 0.5) * 1e6
    loss = model_utils.dice_loss(
        jnp.array(inputs), jnp.array(targets), interpolation='nearest')
    np.testing.assert_array_almost_equal(loss, [[0.5, 0.0]], decimal=3)

    # First mask predicts the second object's region: no overlap -> 1.0.
    inputs = np.zeros(input_shape, dtype=np.float32)
    inputs[0, 0, 64 // stride:, 64 // stride:] = 1.0
    inputs[0, 1, 64 // stride:, 64 // stride:] = 1.0
    inputs = (inputs - 0.5) * 1e6
    loss = model_utils.dice_loss(
        jnp.array(inputs), jnp.array(targets), interpolation='nearest')
    np.testing.assert_array_almost_equal(loss, [[1.0, 0.0]], decimal=3)

    # all_pairs=True: 3 predictions x 2 targets -> a (1, 3, 2) loss matrix.
    inputs = np.zeros((batch, 3, h // stride, w // stride), dtype=np.float32)
    inputs[0, 0, :64 // stride, :64 // stride] = 1.0
    inputs[0, 1, 32 // stride:(32 + 64) // stride, :64 // stride] = 1.0
    inputs[0, 2, 64 // stride:, 64 // stride:] = 1.0
    inputs = (inputs - 0.5) * 1e6
    loss = model_utils.dice_loss(
        jnp.array(inputs), jnp.array(targets), interpolation='nearest',
        all_pairs=True)
    self.assertTupleEqual(loss.shape, (1, 3, 2))
    np.testing.assert_array_almost_equal(loss, [[[0.0, 1.0],
                                                 [0.5, 1.0],
                                                 [1.0, 0.0]]], decimal=3)

  def test_weighted_square_error(self):
    """Tests implementation of squared error."""
    predictions = jnp.array([
        [
            [1.0, 3.0, 5.0, 6.0],
            [3.0, 5.0, 11.0, 10.0],
            [9.0, 10.0, 11.0, 12.0],
            [14.0, 13.0, 14.0, 17.0],
        ],
        [
            [17.0, 18.0, 21.0, 22.0],
            [20.0, 19.0, 24.0, 25.0],
            [27.0, 29.0, 30.0, 32.0],
            [27.0, 28.0, 33.0, 32.0],
        ],
    ])
    targets = jnp.arange(1, 33).reshape(2, 4, 4)

    # Default: mean squared error per example, averaged over the batch.
    loss = model_utils.weighted_mean_squared_error(predictions, targets)
    per_example_expected = jnp.mean(jnp.array([38.0, 70.0]))
    self.assertAlmostEqual(loss, per_example_expected, places=5)

    # Reducing over both trailing axes (in any order / sign convention)
    # is equivalent to the default.
    loss = model_utils.weighted_mean_squared_error(predictions, targets,
                                                   axis=(1, 2))
    self.assertAlmostEqual(loss, per_example_expected, places=5)

    loss = model_utils.weighted_mean_squared_error(predictions, targets,
                                                   axis=(-1, -2))
    self.assertAlmostEqual(loss, per_example_expected, places=5)

    loss = model_utils.weighted_mean_squared_error(predictions, targets,
                                                   axis=(2, 1))
    self.assertAlmostEqual(loss, per_example_expected, places=5)

    # Reducing over the last axis only.
    loss = model_utils.weighted_mean_squared_error(predictions, targets,
                                                   axis=-1)
    last_axis_expected = jnp.mean(jnp.array([[9, 25, 0, 4],
                                             [8, 12, 38, 12]]))
    self.assertAlmostEqual(loss, last_axis_expected, places=5)

    loss = model_utils.weighted_mean_squared_error(predictions, targets,
                                                   axis=2)
    self.assertAlmostEqual(loss, last_axis_expected, places=5)

    # Reducing over the middle axis.
    loss = model_utils.weighted_mean_squared_error(predictions, targets,
                                                   axis=1)
    middle_axis_expected = jnp.mean(jnp.array([[5, 3, 21, 9],
                                               [9, 22, 18, 21]]))
    self.assertAlmostEqual(loss, middle_axis_expected, places=5)

    # Weights mask out individual rows of the per-row losses.
    weights = jnp.array([[1, 1, 1, 0], [0, 1, 1, 0]])
    loss = model_utils.weighted_mean_squared_error(predictions, targets,
                                                   weights, axis=-1)
    masked_expected = jnp.mean(jnp.array([9, 25, 12, 38, 0]))
    self.assertAlmostEqual(loss, masked_expected, places=5)

    # Per-example weights mask out the whole second example.
    weights = jnp.array([1, 0])
    loss = model_utils.weighted_mean_squared_error(predictions, targets,
                                                   weights, axis=-1)
    masked_expected = jnp.mean(jnp.array([9, 25, 0, 4]))
    self.assertAlmostEqual(loss, masked_expected, places=5)
|
|
|
|
class MetricTest(parameterized.TestCase):
  """Tests the metric computation related utilities."""

  def is_valid(self, t, value_name):
    """Helper function to assert that tensor `t` does not have `nan`, `inf`."""
    self.assertFalse(
        jnp.isnan(t).any(), msg=f"Found nan's in {t} for {value_name}")
    self.assertFalse(
        jnp.isinf(t).any(), msg=f"Found inf's in {t} for {value_name}")

  def test_weighted_topk_correctly_classified(self):
    """Tests the topk accuracy computation."""
    batch_size = 512
    num_of_classes = 100
    logits = jnp.array(
        np.random.normal(size=(batch_size, num_of_classes)), dtype=jnp.float32)
    labels = jnp.array(np.random.randint(num_of_classes, size=(batch_size,)))
    one_hot_targets = common_utils.onehot(labels, logits.shape[-1])

    # k=1 must agree with the plain top-1 accuracy helper.
    classification_accuracy = model_utils.weighted_correctly_classified(
        logits, one_hot_targets)
    top_one_accuracy = model_utils.weighted_topk_correctly_classified(
        logits, one_hot_targets, k=1)
    self.assertSequenceAlmostEqual(
        classification_accuracy.flatten(), top_one_accuracy.flatten())

    # k = num_classes trivially contains the label -> accuracy 1.
    top_n_accuracy = model_utils.weighted_topk_correctly_classified(
        logits, one_hot_targets, k=num_of_classes)
    self.assertEqual(jnp.mean(top_n_accuracy), 1)

    # Cross-check k=5 against a NumPy reference implementation.
    top_5_accuracy = model_utils.weighted_topk_correctly_classified(
        logits, one_hot_targets, k=5)
    top5_pred = np.argsort(
        np.reshape(logits, [-1, num_of_classes]), axis=1)[:, -5:]
    top5_pred = np.reshape(top5_pred, [-1, 5])
    y_true = np.reshape(np.array(labels), [-1])
    np_top_accuracy = np.array(
        [y_true[i] in top5_pred[i, :] for i in range(len(y_true))])
    self.assertSequenceAlmostEqual(top_5_accuracy.flatten(),
                                   np_top_accuracy.flatten())

  def test_weighted_recall(self):
    """Tests the topk recall computation."""
    logits = np.array([[[2, 3, 4],
                        [4, 3, 2],
                        [4, 2, 3],
                        [3, 2, 4],
                        [4, 2, 3],
                        ]])
    labels = np.array([[[1, 1, 0],
                        [1, 1, 0],
                        [1, 0, 0],
                        [1, 0, 0],
                        [0, 0, 0]
                        ]])

    # Replicate the same example across the batch dimension.
    batch_size = 8
    logits = jnp.tile(logits, [batch_size, 1, 1])
    labels = jnp.tile(labels, [batch_size, 1, 1])

    recall = model_utils.weighted_recall(logits, labels)
    recall_expected = np.array([[1/2, 1., 1., 0., 0.]] * batch_size)
    self.assertSequenceAlmostEqual(
        recall.flatten(), recall_expected.flatten())

  @parameterized.parameters(itertools.product([1., 0.], [1., 0.]))
  def test_weighted_top_one_correctly_classified(self, label_multiplier,
                                                 weight_multiplier):
    """Tests the top1 correct computation."""
    batch_size = 512
    num_of_classes = 100
    logits = jnp.array(np.random.normal(
        size=(batch_size, 50, num_of_classes)), dtype=jnp.float32)
    labels = jnp.array(np.random.randint(
        0, 2, size=(batch_size, 50, num_of_classes)))
    # Multipliers of 0. exercise the all-zero labels / all-zero weights
    # edge cases.
    labels *= label_multiplier
    weights = jnp.ones(shape=(batch_size,), dtype=jnp.float32)
    weights *= weight_multiplier

    is_correct_array = model_utils.weighted_top_one_correctly_classified(
        logits, labels, weights=weights)
    num_correct = jnp.sum(is_correct_array)

    # The dedicated top-1 helper must agree with the generic top-k at k=1.
    is_correct_array_ref = model_utils.weighted_topk_correctly_classified(
        logits, labels, weights, k=1)
    np.testing.assert_array_almost_equal(
        is_correct_array, is_correct_array_ref)
    np.testing.assert_equal(np.sum(is_correct_array),
                            np.sum(is_correct_array_ref))

    self.is_valid(num_correct, 'Number of correctly classified')

  @parameterized.parameters(itertools.product([1., 0.], [1., 0.]))
  def test_weighted_unnormalized_sigmoid_cross_entropy(self, label_multiplier,
                                                       weight_multiplier):
    """Tests the unnormalized sigmoid cross entropy computation."""
    batch_size = 512
    num_of_classes = 100
    logits = jnp.array(
        np.random.normal(size=(batch_size, num_of_classes)), dtype=jnp.float32)
    labels = jnp.array(np.random.randint(0, 2,
                                         size=(batch_size, num_of_classes)))
    labels *= label_multiplier
    weights = jnp.ones(shape=(batch_size,), dtype=jnp.float32)
    weights *= weight_multiplier

    # The loss must stay finite even when labels/weights are all zero.
    loss_array = model_utils.weighted_unnormalized_sigmoid_cross_entropy(
        logits, labels, weights=weights)
    loss_sum = jnp.sum(loss_array)
    self.is_valid(loss_sum, 'Loss value')

  @parameterized.parameters(itertools.product([1., 0.], [1., 0.]))
  def test_weighted_unnormalized_softmax_cross_entropy(self, label_multiplier,
                                                       weight_multiplier):
    """Tests the unnormalized softmax cross entropy computation."""
    batch_size = 512
    num_of_classes = 100
    logits = jnp.array(
        np.random.normal(size=(batch_size, num_of_classes)), dtype=jnp.float32)
    labels = jnp.array(
        np.random.randint(0, 2, size=(batch_size, num_of_classes)))
    labels *= label_multiplier
    weights = jnp.ones(shape=(batch_size,), dtype=jnp.float32)
    weights *= weight_multiplier

    # The loss must stay finite even when labels/weights are all zero.
    loss_array = model_utils.weighted_unnormalized_softmax_cross_entropy(
        logits, labels, weights=weights)
    loss_sum = jnp.sum(loss_array)
    self.is_valid(loss_sum, 'Loss value')
|
|
|
|
if __name__ == '__main__':
  absltest.main()
|
|