fix input format conversion
Browse files
ece.py
CHANGED
@@ -110,13 +110,13 @@ class ECE(evaluate.Metric):
|
|
110 |
|
111 |
# Determine number of classes / binary or multiclass
|
112 |
binary = True
|
113 |
-
if
|
114 |
-
max_label = int(amax(references, list(range(references.dim()))))
|
115 |
-
if max_label > 1:
|
116 |
-
kwargs["num_classes"] = max_label
|
117 |
-
binary = False
|
118 |
-
elif kwargs["num_classes"] > 1:
|
119 |
binary = False
|
|
|
|
|
|
|
|
|
|
|
120 |
|
121 |
# Compute the calibration
|
122 |
if binary:
|
|
|
110 |
|
111 |
# Determine number of classes / binary or multiclass
|
112 |
binary = True
|
113 |
+
if predictions.dim() == references.dim() + 1:
|
|
|
|
|
|
|
|
|
|
|
114 |
binary = False
|
115 |
+
if "num_classes" not in kwargs:
|
116 |
+
kwargs["num_classes"] = int(predictions.shape[1])
|
117 |
+
else:
|
118 |
+
raise ValueError("Bad input shape. Expected to have predictions with shape (N,C,...) and references"
|
119 |
+
f"with shape (N,...), but got {predictions.shape} and {references.shape}")
|
120 |
|
121 |
# Compute the calibration
|
122 |
if binary:
|