# Copyright 2019 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Object detection calibration metrics. """ from __future__ import absolute_import from __future__ import division from __future__ import print_function import tensorflow as tf from tensorflow.python.ops import metrics_impl def _safe_div(numerator, denominator): """Divides two tensors element-wise, returning 0 if the denominator is <= 0. Args: numerator: A real `Tensor`. denominator: A real `Tensor`, with dtype matching `numerator`. Returns: 0 if `denominator` <= 0, else `numerator` / `denominator` """ t = tf.truediv(numerator, denominator) zero = tf.zeros_like(t, dtype=denominator.dtype) condition = tf.greater(denominator, zero) zero = tf.cast(zero, t.dtype) return tf.where(condition, t, zero) def _ece_from_bins(bin_counts, bin_true_sum, bin_preds_sum, name): """Calculates Expected Calibration Error from accumulated statistics.""" bin_accuracies = _safe_div(bin_true_sum, bin_counts) bin_confidences = _safe_div(bin_preds_sum, bin_counts) abs_bin_errors = tf.abs(bin_accuracies - bin_confidences) bin_weights = _safe_div(bin_counts, tf.reduce_sum(bin_counts)) return tf.reduce_sum(abs_bin_errors * bin_weights, name=name) def expected_calibration_error(y_true, y_pred, nbins=20): """Calculates Expected Calibration Error (ECE). ECE is a scalar summary statistic of calibration error. It is the sample-weighted average of the difference between the predicted and true probabilities of a positive detection across uniformly-spaced model confidences [0, 1]. See referenced paper for a thorough explanation. Reference: Guo, et. al, "On Calibration of Modern Neural Networks" Page 2, Expected Calibration Error (ECE). https://arxiv.org/pdf/1706.04599.pdf This function creates three local variables, `bin_counts`, `bin_true_sum`, and `bin_preds_sum` that are used to compute ECE. For estimation of the metric over a stream of data, the function creates an `update_op` operation that updates these variables and returns the ECE. Args: y_true: 1-D tf.int64 Tensor of binarized ground truth, corresponding to each prediction in y_pred. y_pred: 1-D tf.float32 tensor of model confidence scores in range [0.0, 1.0]. nbins: int specifying the number of uniformly-spaced bins into which y_pred will be bucketed. Returns: value_op: A value metric op that returns ece. update_op: An operation that increments the `bin_counts`, `bin_true_sum`, and `bin_preds_sum` variables appropriately and whose value matches `ece`. Raises: InvalidArgumentError: if y_pred is not in [0.0, 1.0]. """ bin_counts = metrics_impl.metric_variable( [nbins], tf.float32, name='bin_counts') bin_true_sum = metrics_impl.metric_variable( [nbins], tf.float32, name='true_sum') bin_preds_sum = metrics_impl.metric_variable( [nbins], tf.float32, name='preds_sum') with tf.control_dependencies([ tf.assert_greater_equal(y_pred, 0.0), tf.assert_less_equal(y_pred, 1.0), ]): bin_ids = tf.histogram_fixed_width_bins(y_pred, [0.0, 1.0], nbins=nbins) with tf.control_dependencies([bin_ids]): update_bin_counts_op = tf.assign_add( bin_counts, tf.to_float(tf.bincount(bin_ids, minlength=nbins))) update_bin_true_sum_op = tf.assign_add( bin_true_sum, tf.to_float(tf.bincount(bin_ids, weights=y_true, minlength=nbins))) update_bin_preds_sum_op = tf.assign_add( bin_preds_sum, tf.to_float(tf.bincount(bin_ids, weights=y_pred, minlength=nbins))) ece_update_op = _ece_from_bins( update_bin_counts_op, update_bin_true_sum_op, update_bin_preds_sum_op, name='update_op') ece = _ece_from_bins(bin_counts, bin_true_sum, bin_preds_sum, name='value') return ece, ece_update_op