jordyvl commited on
Commit
897eeff
1 Parent(s): 6a39113

Fix to reliability diagram - correct with test

Browse files
Files changed (2) hide show
  1. app.py +26 -30
  2. ece.py +23 -8
app.py CHANGED
@@ -48,34 +48,31 @@ def reliability_plot(results):
48
  # DEV: nicer would be to plot like a polygon
49
  # see: https://github.com/markus93/fit-on-the-test/blob/main/Experiments_Synthetic/binnings.py
50
 
51
- def over_under_confidence(results):
52
- colors = []
53
- for j, bin in enumerate(results["y_bar"]):
54
- perfect = results["y_bar"][j]
55
- empirical = results["p_bar"][j]
56
-
57
- bin_color = (
58
- "limegreen"
59
- if np.allclose(perfect, empirical)
60
- else "dodgerblue"
61
- if empirical < perfect
62
- else "orangered"
63
- )
64
- colors.append(bin_color)
65
- return colors
66
-
67
  fig, ax1, ax2 = default_plot()
68
 
69
  # Bin differences
70
  bins_with_left_edge = np.insert(results["y_bar"], 0, 0, axis=0)
71
- B, bins, patches = ax1.hist(
72
- results["y_bar"],
73
- weights=np.nan_to_num(results["p_bar"][:-1], copy=True, nan=0),
74
- bins=bins_with_left_edge,
 
 
 
 
 
75
  )
76
- colors = over_under_confidence(results)
77
- for b in range(len(B)):
78
- patches[b].set_facecolor(colors[b]) # color based on over/underconfidence
 
 
 
 
 
 
 
 
79
 
80
  ax1handles = [
81
  mpatches.Patch(color="orangered", label="Overconfident"),
@@ -84,12 +81,11 @@ def reliability_plot(results):
84
  ]
85
 
86
  # Bin frequencies
87
- anindices = np.where(~np.isnan(results["p_bar"][:-1]))[0]
88
- n_bins = len(results["y_bar"])
89
- bin_freqs = np.zeros(n_bins)
90
  bin_freqs[anindices] = results["bin_freq"]
91
- B, newbins, patches = ax2.hist(
92
- results["y_bar"], weights=bin_freqs, color="midnightblue", bins=bins_with_left_edge
93
  )
94
 
95
  acc_plt = ax2.axvline(x=results["accuracy"], ls="solid", lw=3, c="black", label="Accuracy")
@@ -148,8 +144,8 @@ component = gr.inputs.Dataframe(
148
  )
149
 
150
  component.value = [
151
- [[0.63, 0.2, 0.2], 0],
152
- [[0.73, 0.1, 0.2], 2],
153
  [[0, 0.95, 0.05], 1],
154
  ]
155
  sample_data = [[component] + slider_defaults]
 
48
  # DEV: nicer would be to plot like a polygon
49
  # see: https://github.com/markus93/fit-on-the-test/blob/main/Experiments_Synthetic/binnings.py
50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  fig, ax1, ax2 = default_plot()
52
 
53
  # Bin differences
54
  bins_with_left_edge = np.insert(results["y_bar"], 0, 0, axis=0)
55
+ bins_with_right_edge = np.insert(results["y_bar"], -1, 1.0, axis=0)
56
+ bins_with_leftright_edge = np.insert(bins_with_left_edge, -1, 1.0, axis=0)
57
+ weights = np.nan_to_num(results["p_bar"], copy=True, nan=0)
58
+
59
+ # NOTE: the histogram API is strange
60
+ _, _, patches = ax1.hist(
61
+ bins_with_left_edge,
62
+ weights=weights,
63
+ bins=bins_with_leftright_edge,
64
  )
65
+ for b in range(len(patches)):
66
+ perfect = bins_with_right_edge[b] # if b != n_bins else
67
+ empirical = weights[b] # patches[b]._height
68
+ bin_color = (
69
+ "limegreen"
70
+ if perfect == empirical
71
+ else "dodgerblue"
72
+ if empirical < perfect
73
+ else "orangered"
74
+ )
75
+ patches[b].set_facecolor(bin_color) # color based on over/underconfidence
76
 
77
  ax1handles = [
78
  mpatches.Patch(color="orangered", label="Overconfident"),
 
81
  ]
82
 
83
  # Bin frequencies
84
+ anindices = np.where(~np.isnan(results["p_bar"]))[0]
85
+ bin_freqs = np.zeros(len(results["p_bar"]))
 
86
  bin_freqs[anindices] = results["bin_freq"]
87
+ ax2.hist(
88
+ bins_with_left_edge, weights=bin_freqs, color="midnightblue", bins=bins_with_leftright_edge
89
  )
90
 
91
  acc_plt = ax2.axvline(x=results["accuracy"], ls="solid", lw=3, c="black", label="Accuracy")
 
144
  )
145
 
146
  component.value = [
147
+ [[0.6, 0.2, 0.2], 0],
148
+ [[0.7, 0.1, 0.2], 2],
149
  [[0, 0.95, 0.05], 1],
150
  ]
151
  sample_data = [[component] + slider_defaults]
ece.py CHANGED
@@ -21,7 +21,6 @@ import numpy as np
21
  from typing import Dict, Optional
22
 
23
 
24
-
25
  # TODO: Add BibTeX citation
26
  _CITATION = """\
27
  @InProceedings{huggingface:module,
@@ -103,9 +102,9 @@ def create_bins(n_bins=10, scheme="equal-range", bin_range=None, P=None):
103
  # rightmost entry per equal size group
104
  for cur_group in range(n_bins - 1):
105
  bin_upper_edges += [max(groups[cur_group])]
106
- bin_upper_edges += [1.01] #[np.inf] # always +1 for right edges
107
  bins = np.array(bin_upper_edges)
108
- #OverflowError: cannot convert float infinity to integer
109
 
110
  return bins
111
 
@@ -200,7 +199,14 @@ def top_1_CE(Y, P, **kwargs):
200
  )
201
  CE = CE_estimate(y_correct, p_max, bins=bins, proxy=kwargs["proxy"], detail=kwargs["detail"])
202
  if kwargs["detail"]:
203
- return {"ECE": CE[0], "y_bar": CE[1], "p_bar": CE[2], "bin_freq": CE[3], "p_bar_cont": np.mean(p_max,-1), "accuracy": np.mean(y_correct)}
 
 
 
 
 
 
 
204
  return CE
205
 
206
 
@@ -306,9 +312,18 @@ def test_ECE():
306
  print(f"ECE: {res['ECE']}")
307
 
308
  res = ECE()._compute(predictions, references, detail=True)
309
- import pdb; pdb.set_trace() # breakpoint 25274412 //
310
-
311
  print(f"ECE: {res['ECE']}")
312
 
313
- if __name__ == '__main__':
314
- test_ECE()
 
 
 
 
 
 
 
 
 
 
 
 
21
  from typing import Dict, Optional
22
 
23
 
 
24
  # TODO: Add BibTeX citation
25
  _CITATION = """\
26
  @InProceedings{huggingface:module,
 
102
  # rightmost entry per equal size group
103
  for cur_group in range(n_bins - 1):
104
  bin_upper_edges += [max(groups[cur_group])]
105
+ bin_upper_edges += [1.01] # [np.inf] # always +1 for right edges
106
  bins = np.array(bin_upper_edges)
107
+ # OverflowError: cannot convert float infinity to integer
108
 
109
  return bins
110
 
 
199
  )
200
  CE = CE_estimate(y_correct, p_max, bins=bins, proxy=kwargs["proxy"], detail=kwargs["detail"])
201
  if kwargs["detail"]:
202
+ return {
203
+ "ECE": CE[0],
204
+ "y_bar": CE[1],
205
+ "p_bar": CE[2],
206
+ "bin_freq": CE[3],
207
+ "p_bar_cont": np.mean(p_max, -1),
208
+ "accuracy": np.mean(y_correct),
209
+ }
210
  return CE
211
 
212
 
 
312
  print(f"ECE: {res['ECE']}")
313
 
314
  res = ECE()._compute(predictions, references, detail=True)
 
 
315
  print(f"ECE: {res['ECE']}")
316
 
317
+
318
+ def test_deterministic():
319
+ res = ECE()._compute(
320
+ references=[0, 1, 2],
321
+ predictions=[[0.63, 0.2, 0.2], [0, 0.95, 0.05], [0.72, 0.1, 0.2]],
322
+ detail=True,
323
+ )
324
+ print(f"ECE: {res['ECE']}\n {res}")
325
+
326
+
327
+ if __name__ == "__main__":
328
+ test_deterministic()
329
+ test_ECE()