Natooz commited on
Commit
3a718e9
1 Parent(s): 65b297d

fix input format conversion

Browse files
Files changed (1) hide show
  1. ece.py +6 -6
ece.py CHANGED
@@ -110,13 +110,13 @@ class ECE(evaluate.Metric):
110
 
111
  # Determine number of classes / binary or multiclass
112
  binary = True
113
- if "num_classes" not in kwargs:
114
- max_label = int(amax(references, list(range(references.dim()))))
115
- if max_label > 1:
116
- kwargs["num_classes"] = max_label
117
- binary = False
118
- elif kwargs["num_classes"] > 1:
119
  binary = False
 
 
 
 
 
120
 
121
  # Compute the calibration
122
  if binary:
 
110
 
111
  # Determine number of classes / binary or multiclass
112
  binary = True
113
+ if predictions.dim() == references.dim() + 1:
 
 
 
 
 
114
  binary = False
115
+ if "num_classes" not in kwargs:
116
+ kwargs["num_classes"] = int(predictions.shape[1])
117
+ else:
118
+ raise ValueError("Bad input shape. Expected to have predictions with shape (N,C,...) and references"
119
+ f"with shape (N,...), but got {predictions.shape} and {references.shape}")
120
 
121
  # Compute the calibration
122
  if binary: