"""ECE metric file."""

from __future__ import annotations

from typing import TYPE_CHECKING

import datasets
import evaluate
from torch import LongTensor, Tensor
from torchmetrics.functional.classification.calibration_error import (
    binary_calibration_error,
    multiclass_calibration_error,
)

if TYPE_CHECKING:
    from collections.abc import Iterable

_CITATION = """\
@InProceedings{huggingface:ece,
title = {Expected calibration error (ECE)},
author = {Nathan Fradet},
year = {2023}
}
"""

_DESCRIPTION = """\
This metric computes the expected calibration error (ECE).
It directly calls the torchmetrics package:
https://torchmetrics.readthedocs.io/en/stable/classification/calibration_error.html
"""

_KWARGS_DESCRIPTION = """
Computes the expected calibration error (ECE) of predictions against references.
Args:
    predictions: list of predictions to score. They must have a shape (N,C,...) if
        multiclass, or (N,...) if binary.
    references: list of references for each prediction, with a shape (N,...).
Returns:
    ece: expected calibration error
Examples:
    >>> import numpy as np
    >>> ece = evaluate.load("Natooz/ece")
    >>> results = ece.compute(
    ...     predictions=np.array([[0.25, 0.20, 0.55],
    ...                           [0.55, 0.05, 0.40],
    ...                           [0.10, 0.30, 0.60],
    ...                           [0.90, 0.05, 0.05]]),
    ...     references=np.array([1, 2, 2, 0]),
    ...     num_classes=3,
    ...     n_bins=3,
    ...     norm="l1",
    ... )
    >>> print(results)
    {'ece': 0.2000}
"""


@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class ECE(evaluate.Metric):
    """
    Module for the calibration error (ECE) metrics of the torchmetrics package.

    https://torchmetrics.readthedocs.io/en/stable/classification/calibration_error.html.
    """

    def _info(self) -> evaluate.MetricInfo:
        """
        Return the module info.

        :return: module info.
        """
        return evaluate.MetricInfo(
            module_type="metric",
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            features=datasets.Features(
                {
                    "predictions": datasets.Sequence(datasets.Value("float32")),
                    "references": datasets.Value("int64"),
                }
            ),
            homepage="https://huggingface.co/spaces/Natooz/ece",
            codebase_urls=[
                "https://github.com/Lightning-AI/torchmetrics/blob/v0.11.4/src/torchmetrics/classification/calibration_error.py"
            ],
            reference_urls=[
                "https://torchmetrics.readthedocs.io/en/stable/classification/calibration_error.html"
            ],
        )

    def _compute(
        self,
        predictions: Iterable[float] | None = None,
        references: Iterable[int] | None = None,
        **kwargs,
    ) -> dict[str, float]:
        """
        Return the expected calibration error (ECE).

        See the torchmetrics documentation for more information on the method:
        https://torchmetrics.readthedocs.io/en/stable/classification/calibration_error.html
        predictions: (N,C,...) if multiclass or (N,...) if binary
        references: (N,...)

        If "num_classes" is not provided in a multiclass setting, the number of
        classes will be deduced from the second dimension of the predictions.
        """
        predictions = Tensor(predictions)
        references = LongTensor(references)

        error_msg = (
            "Expected to have predictions with shape (N,C,...) for multiclass or "
            "(N,...) for binary, and references with shape (N,...), but got "
            f"{predictions.shape} and {references.shape}"
        )
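        # Infer the binary or multiclass setting from the input shapes: multiclass
        # predictions carry an extra class dimension (N,C,...) compared to the
        # (N,...) references, while binary predictions have the same shape as the
        # references.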
        binary = True
        if predictions.dim() == references.dim() + 1:
            binary = False
            if "num_classes" not in kwargs:
                kwargs["num_classes"] = int(predictions.shape[1])
        elif predictions.dim() == references.dim() and "num_classes" in kwargs:
            raise ValueError(
                "You gave the num_classes argument, with predictions and references "
                "having the same number of dimensions. " + error_msg
            )
        elif predictions.dim() != references.dim():
            raise ValueError("Bad input shape. " + error_msg)

        if binary:
            ece = binary_calibration_error(predictions, references, **kwargs)
        else:
            ece = multiclass_calibration_error(predictions, references, **kwargs)
        return {"ece": float(ece)}
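

# A minimal sanity-check sketch (an addition, not part of the original module):
# running this file directly calls the metric's internal ``_compute`` with the
# same made-up probabilities and labels as the docstring example, bypassing
# ``evaluate.load``.
if __name__ == "__main__":
    metric = ECE()
    result = metric._compute(
        predictions=[
            [0.25, 0.20, 0.55],  # top-1: class 2, confidence 0.55
            [0.55, 0.05, 0.40],  # top-1: class 0, confidence 0.55
            [0.10, 0.30, 0.60],  # top-1: class 2, confidence 0.60
            [0.90, 0.05, 0.05],  # top-1: class 0, confidence 0.90
        ],
        references=[1, 2, 2, 0],  # only the last two top-1 predictions are correct
        n_bins=3,
        norm="l1",
    )
    # With n_bins=3, the bin [1/3, 2/3) holds three samples (accuracy 1/3, mean
    # confidence ~0.567) and the bin [2/3, 1] holds one (accuracy 1, confidence
    # 0.90), so ECE = 0.75 * |1/3 - 0.567| + 0.25 * |1 - 0.90| = 0.2.
    print(result)  # {'ece': 0.2...}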