"""Tests for calibration_metrics.""" |
|
|
|
from __future__ import absolute_import |
|
from __future__ import division |
|
from __future__ import print_function |
|
|
|
import numpy as np |
|
import tensorflow as tf |
|
from object_detection.metrics import calibration_metrics |
|
|
|
|
|
class CalibrationLibTest(tf.test.TestCase): |
|
|
|
@staticmethod |
|
def _get_calibration_placeholders(): |
|
"""Returns TF placeholders for y_true and y_pred.""" |
|
return (tf.placeholder(tf.int64, shape=(None)), |
|
tf.placeholder(tf.float32, shape=(None))) |
|
  def test_expected_calibration_error_all_bins_filled(self):
    """Test expected calibration error when all bins contain predictions."""
    y_true, y_pred = self._get_calibration_placeholders()
    expected_ece_op, update_op = calibration_metrics.expected_calibration_error(
        y_true, y_pred, nbins=2)
    with self.test_session() as sess:
      metrics_vars = tf.get_collection(tf.GraphKeys.METRIC_VARIABLES)
      sess.run(tf.variables_initializer(var_list=metrics_vars))
      sess.run(
          update_op,
          feed_dict={
              y_pred: np.array([0., 0.2, 0.4, 0.5, 1.0]),
              y_true: np.array([0, 0, 1, 0, 1])
          })
|
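      # Hand-computed ECE, assuming two evenly spaced bins over [0, 1]:
      # bin [0, 0.5) holds predictions (0., 0.2, 0.4) with accuracy 1/3 and
      # mean confidence 0.2, contributing (3/5) * |1/3 - 0.2| = 0.08; bin
      # [0.5, 1.0] holds (0.5, 1.0) with accuracy 1/2 and mean confidence
      # 0.75, contributing (2/5) * |0.5 - 0.75| = 0.1.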
      actual_ece = 0.08 + 0.1
      expected_ece = sess.run(expected_ece_op)
      self.assertAlmostEqual(actual_ece, expected_ece)
|
  def test_expected_calibration_error_all_bins_not_filled(self):
    """Test expected calibration error when one bin receives no predictions."""
    y_true, y_pred = self._get_calibration_placeholders()
    expected_ece_op, update_op = calibration_metrics.expected_calibration_error(
        y_true, y_pred, nbins=2)
    with self.test_session() as sess:
      metrics_vars = tf.get_collection(tf.GraphKeys.METRIC_VARIABLES)
      sess.run(tf.variables_initializer(var_list=metrics_vars))
      sess.run(
          update_op,
          feed_dict={
              y_pred: np.array([0., 0.2, 0.4]),
              y_true: np.array([0, 0, 1])
          })
|
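      # With two evenly spaced bins, only [0, 0.5) receives predictions
      # (0., 0.2, 0.4), with accuracy 1/3 and mean confidence 0.2, so the
      # ECE reduces to |0.2 - 1/3|; the empty bin [0.5, 1.0] adds nothing.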
      actual_ece = np.abs(0.2 - (1 / 3.))
      expected_ece = sess.run(expected_ece_op)
      self.assertAlmostEqual(actual_ece, expected_ece)
|
  def test_expected_calibration_error_with_multiple_data_streams(self):
    """Test expected calibration error when multiple batches are provided."""
    y_true, y_pred = self._get_calibration_placeholders()
    expected_ece_op, update_op = calibration_metrics.expected_calibration_error(
        y_true, y_pred, nbins=2)
    with self.test_session() as sess:
      metrics_vars = tf.get_collection(tf.GraphKeys.METRIC_VARIABLES)
      sess.run(tf.variables_initializer(var_list=metrics_vars))
      sess.run(
          update_op,
          feed_dict={
              y_pred: np.array([0., 0.2]),
              y_true: np.array([0, 0])
          })
      sess.run(
          update_op,
          feed_dict={
              y_pred: np.array([0.4, 0.5]),
              y_true: np.array([1, 0])
          })
      sess.run(
          update_op,
          feed_dict={
              y_pred: np.array([1.0]),
              y_true: np.array([1])
          })
|
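      # The three batches together match the all-bins-filled case, so the
      # accumulated ECE is again (3/5) * |1/3 - 0.2| + (2/5) * |0.5 - 0.75|
      # = 0.08 + 0.1.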
      actual_ece = 0.08 + 0.1
      expected_ece = sess.run(expected_ece_op)
      self.assertAlmostEqual(actual_ece, expected_ece)
|
|
if __name__ == '__main__':
  tf.test.main()