# File: ece / ece.py
# Author: Nathan Fradet
# Commit message: ruff formatting + changing gradio app loading
# Commit: 9c80799 (unverified)
"""ECE metric file."""
from __future__ import annotations
from typing import TYPE_CHECKING
import datasets
import evaluate
from torch import LongTensor, Tensor
from torchmetrics.functional.classification.calibration_error import (
binary_calibration_error,
multiclass_calibration_error,
)
if TYPE_CHECKING:
from collections.abc import Iterable
# BibTeX citation displayed on the module page. The field must be named
# ``author`` (singular): BibTeX silently ignores unknown fields such as ``authors``.
_CITATION = """\
@InProceedings{huggingface:ece,
title = {Expected calibration error (ECE)},
author={Nathan Fradet},
year={2023}
}
"""
# Short description displayed on the module page and prepended to the class docstring.
_DESCRIPTION = """\
This metric computes the expected calibration error (ECE).
It directly calls the torchmetrics package:
https://torchmetrics.readthedocs.io/en/stable/classification/calibration_error.html
"""
# Description of the ``compute`` arguments, displayed on the module page and
# prepended to the class docstring. The example mirrors the torchmetrics
# ``multiclass_calibration_error`` documentation example (ECE of 0.2).
_KWARGS_DESCRIPTION = """
Computes the expected calibration error (ECE) of predicted probabilities given
reference labels.
Args:
    predictions: list of predicted probabilities. They must have a shape (N,C,...)
        if multiclass, or (N,...) if binary.
    references: list of integer class labels, one for each prediction, with a
        shape (N,...).
    **kwargs: additional keyword arguments forwarded to the torchmetrics
        `multiclass_calibration_error` / `binary_calibration_error` functions,
        e.g. `n_bins` or `norm`. `num_classes` is only accepted for multiclass
        inputs; if omitted it is inferred from the C dimension of predictions.
Returns:
    ece: expected calibration error
Examples:
    >>> import numpy as np
    >>> ece = evaluate.load("Natooz/ece")
    >>> results = ece.compute(
    ...     references=np.array([0, 1, 2, 0]),
    ...     predictions=np.array([[0.25, 0.20, 0.55],
    ...                           [0.55, 0.05, 0.40],
    ...                           [0.10, 0.30, 0.60],
    ...                           [0.90, 0.05, 0.05]]),
    ...     num_classes=3,
    ...     n_bins=3,
    ...     norm="l1",
    ... )
    >>> print(round(results["ece"], 4))
    0.2
"""
@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class ECE(evaluate.Metric):
    """
    Module for the calibration error (ECE) metrics of the torchmetrics package.

    Dispatches to torchmetrics' ``binary_calibration_error`` or
    ``multiclass_calibration_error`` depending on the input shapes.
    https://torchmetrics.readthedocs.io/en/stable/classification/calibration_error.html.
    """

    def _info(self) -> evaluate.MetricInfo:
        """
        Return the module info.

        :return: module info.
        """
        return evaluate.MetricInfo(
            # This is the description that will appear on the modules page.
            module_type="metric",
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            # Accepted formats for each prediction and reference, tried in order:
            # a probability vector per sample (multiclass, shape (N,C)), or a single
            # probability per sample (binary, shape (N,)). Without the second format
            # the binary path of ``_compute`` would be unreachable.
            features=[
                datasets.Features(
                    {
                        "predictions": datasets.Sequence(datasets.Value("float32")),
                        "references": datasets.Value("int64"),
                    }
                ),
                datasets.Features(
                    {
                        "predictions": datasets.Value("float32"),
                        "references": datasets.Value("int64"),
                    }
                ),
            ],
            # Homepage of the module for documentation
            homepage="https://huggingface.co/spaces/Natooz/ece",
            # Additional links to the codebase or references
            codebase_urls=[
                "https://github.com/Lightning-AI/torchmetrics/blob/v0.11.4/src/torchmetrics/classification/calibration_error.py"
            ],
            reference_urls=[
                "https://torchmetrics.readthedocs.io/en/stable/classification/calibration_error.html"
            ],
        )

    def _compute(
        self,
        predictions: Iterable[float] | Iterable[Iterable[float]] | None = None,
        references: Iterable[int] | None = None,
        **kwargs,
    ) -> dict[str, float]:
        """
        Return the Expected Calibration Error (ECE).

        See the torchmetrics documentation for more information on the method.
        https://torchmetrics.readthedocs.io/en/stable/classification/calibration_error.html

        predictions: (N,C,...) if multiclass or (N,...) if binary.
        references: (N,...).
        The task is multiclass when predictions have exactly one more dimension
        than references, and binary when both have the same number of dimensions.
        If "num_classes" is not provided in a multiclass setting, it is inferred
        from the size of the class dimension of the predictions (``shape[1]``).
        Any other keyword argument (e.g. "n_bins", "norm") is forwarded to the
        torchmetrics function.

        :raises ValueError: if the numbers of dimensions of predictions and
            references are incompatible, or if "num_classes" is given while both
            have the same number of dimensions (binary case).
        :return: dictionary with the ECE under the "ece" key.
        """
        # Convert the input (float probabilities, integer labels)
        predictions = Tensor(predictions)
        references = LongTensor(references)

        # Determine number of classes / binary or multiclass
        error_msg = (
            "Expected to have predictions with shape (N,C,...) for multiclass or "
            "(N,...) for binary, and references with shape (N,...), but got "
            f"{predictions.shape} and {references.shape}"
        )
        binary = True
        if predictions.dim() == references.dim() + 1:  # multiclass: extra C dim
            binary = False
            if "num_classes" not in kwargs:
                kwargs["num_classes"] = int(predictions.shape[1])
        elif predictions.dim() == references.dim() and "num_classes" in kwargs:
            # binary_calibration_error does not take num_classes
            raise ValueError(
                "You gave the num_classes argument, with predictions and references "
                "having the same number of dimensions. " + error_msg
            )
        elif predictions.dim() != references.dim():
            raise ValueError("Bad input shape. " + error_msg)

        # Compute the calibration
        if binary:
            ece = binary_calibration_error(predictions, references, **kwargs)
        else:
            ece = multiclass_calibration_error(predictions, references, **kwargs)
        return {
            "ece": float(ece),
        }