"""Confusion Matrix metric.""" | |
import datasets | |
import evaluate | |
from sklearn.metrics import confusion_matrix | |


_DESCRIPTION = """
Compute the confusion matrix to evaluate the accuracy of a classification.
By definition a confusion matrix :math:`C` is such that :math:`C_{i, j}`
is equal to the number of observations known to be in group :math:`i` and
predicted to be in group :math:`j`.
Thus in binary classification, the count of true negatives is
:math:`C_{0,0}`, false negatives is :math:`C_{1,0}`, true positives is
:math:`C_{1,1}` and false positives is :math:`C_{0,1}`.
Read more in the scikit-learn User Guide on model evaluation.
"""

_KWARGS_DESCRIPTION = """
Args:
    references : array-like of shape (n_samples,)
        Ground truth (correct) target values (forwarded to
        ``sklearn.metrics.confusion_matrix`` as ``y_true``).
    predictions : array-like of shape (n_samples,)
        Estimated targets as returned by a classifier (forwarded as ``y_pred``).
    labels : array-like of shape (n_classes,), default=None
        List of labels to index the matrix. This may be used to reorder
        or select a subset of labels.
        If ``None`` is given, the labels that appear at least once in the
        references or predictions are used in sorted order.
    sample_weight : array-like of shape (n_samples,), default=None
        Sample weights.
    normalize : {'true', 'pred', 'all'}, default=None
        Normalizes the confusion matrix over the true (rows), predicted
        (columns) conditions or over the whole population. If ``None``, the
        confusion matrix is not normalized.
Returns:
    C : ndarray of shape (n_classes, n_classes)
        Confusion matrix whose i-th row and j-th column entry indicates the
        number of samples whose true label is the i-th class and whose
        predicted label is the j-th class.
See Also:
    ConfusionMatrixDisplay.from_estimator : Plot the confusion matrix
        given an estimator, the data, and the label.
    ConfusionMatrixDisplay.from_predictions : Plot the confusion matrix
        given the true and predicted labels.
    ConfusionMatrixDisplay : Confusion Matrix visualization.
References:
    .. [1] `Wikipedia entry for the Confusion matrix
           <https://en.wikipedia.org/wiki/Confusion_matrix>`_
           (Wikipedia and other references may use a different
           convention for axes).
Examples:
    The examples below call scikit-learn's ``confusion_matrix`` directly; through
    this metric, ``references`` and ``predictions`` are forwarded to it as
    ``y_true`` and ``y_pred``.

    >>> from sklearn.metrics import confusion_matrix
    >>> y_true = [2, 0, 2, 2, 0, 1]
    >>> y_pred = [0, 0, 2, 2, 0, 2]
    >>> confusion_matrix(y_true, y_pred)
    array([[2, 0, 0],
           [0, 0, 1],
           [1, 0, 2]])

    >>> y_true = ["cat", "ant", "cat", "cat", "ant", "bird"]
    >>> y_pred = ["ant", "ant", "cat", "cat", "ant", "cat"]
    >>> confusion_matrix(y_true, y_pred, labels=["ant", "bird", "cat"])
    array([[2, 0, 0],
           [0, 0, 1],
           [1, 0, 2]])

    In the binary case, we can extract the counts of true negatives, false
    positives, false negatives and true positives as follows:

    >>> tn, fp, fn, tp = confusion_matrix([0, 1, 0, 1], [1, 1, 1, 0]).ravel()
    >>> (tn, fp, fn, tp)
    (0, 2, 1, 1)
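
    Passing the ``normalize`` option described above rescales these counts; for
    example, normalizing the binary case over the true rows gives row-wise
    proportions:

    >>> confusion_matrix([0, 1, 0, 1], [1, 1, 1, 0], normalize='true')
    array([[0. , 1. ],
           [0.5, 0.5]])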
""" | |

_CITATION = """
@article{scikit-learn,
  title={Scikit-learn: Machine Learning in {P}ython},
  author={Pedregosa, F. and Varoquaux, G. and Gramfort, A. and Michel, V.
          and Thirion, B. and Grisel, O. and Blondel, M. and Prettenhofer, P.
          and Weiss, R. and Dubourg, V. and Vanderplas, J. and Passos, A. and
          Cournapeau, D. and Brucher, M. and Perrot, M. and Duchesnay, E.},
  journal={Journal of Machine Learning Research},
  volume={12},
  pages={2825--2830},
  year={2011}
}
"""


class ConfusionMatrix(evaluate.Metric):
    def _info(self):
        return evaluate.MetricInfo(
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            features=datasets.Features(
                # The "multilabel" configuration expects a sequence of integer
                # labels per example; the default configuration expects a
                # single integer label per example.
                {
                    "predictions": datasets.Sequence(datasets.Value("int32")),
                    "references": datasets.Sequence(datasets.Value("int32")),
                }
                if self.config_name == "multilabel"
                else {
                    "predictions": datasets.Value("int32"),
                    "references": datasets.Value("int32"),
                }
            ),
            reference_urls=[
                "https://scikit-learn.org/stable/modules/generated/sklearn.metrics.confusion_matrix.html"
            ],
        )

    def _compute(
        self,
        predictions,
        references,
        *,
        labels=None,
        sample_weight=None,
        normalize=None,
    ):
        # Delegate to scikit-learn: references are the ground-truth labels
        # (y_true) and predictions are the model outputs (y_pred).
        return {
            "confusion_matrix": confusion_matrix(
                y_true=references,
                y_pred=predictions,
                labels=labels,
                sample_weight=sample_weight,
                normalize=normalize,
            )
        }
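

# A minimal usage sketch, not part of the metric itself. It assumes this script
# is saved locally as ``confusion_matrix.py`` (the filename is an assumption)
# and loads it through ``evaluate.load``, which also accepts a path to a local
# module script. Running the file directly should print the same 3x3 matrix as
# the first doctest above.
if __name__ == "__main__":
    metric = evaluate.load("./confusion_matrix.py")  # hypothetical local path
    result = metric.compute(
        references=[2, 0, 2, 2, 0, 1],
        predictions=[0, 0, 2, 2, 0, 2],
    )
    print(result["confusion_matrix"])  # expected rows: [2 0 0], [0 0 1], [1 0 2]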