jordyvl committed on
Commit
9d4511f
1 Parent(s): 2afab11

plt.hist might not be the right plotting device; overrides existing bins

Files changed (3)
  1. README.md +1 -1
  2. local_app.py +44 -20
  3. tests.py +11 -6
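
The commit message points at a matplotlib pitfall: `plt.hist` derives its own bin edges from the raw data (10 equal-width bins by default) unless `bins=` is passed explicitly, so handing it already-binned values such as bin centers with per-bin weights silently re-bins them. A minimal, self-contained sketch of that behaviour; the arrays and variable names are illustrative and not taken from this repository:

```python
import numpy as np
import matplotlib.pyplot as plt

# Five pre-computed bins, represented by their centers and per-bin heights.
bin_centers = np.array([0.1, 0.3, 0.5, 0.7, 0.9])
bin_heights = np.array([0.2, 0.4, 0.5, 0.8, 0.9])

fig, (ax_default, ax_fixed) = plt.subplots(1, 2)

# Default call: matplotlib spreads the 5 centers over 10 new bins covering
# [0.1, 0.9], so the bars no longer line up with the intended binning.
ax_default.hist(bin_centers, weights=bin_heights)

# Passing the intended edges keeps the original binning intact.
edges = np.linspace(0.0, 1.0, 6)
ax_fixed.hist(bin_centers, bins=edges, weights=bin_heights)
```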
README.md CHANGED
@@ -65,7 +65,7 @@ The module returns dictionary with a key value pair, e.g., {"ECE": 0.64}.
  <!---
  *Give code examples of the metric being used. Try to include examples that clear up any potential ambiguity left from the metric description above. If possible, provide a range of examples that show both typical and atypical results, as well as examples where a variety of input parameters are passed.*
  -->
- ```
+ ```python
  N = 10 # N evaluation instances {(x_i,y_i)}_{i=1}^N
  K = 5 # K class problem
 
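The README example above only fixes N and K. As a hedged sketch of how such inputs could look in practice (illustrative only, not the actual README continuation; the variable names mirror tests.py), predictions are rows of K class probabilities and references are integer class labels:

```python
import numpy as np

N = 10  # N evaluation instances {(x_i,y_i)}_{i=1}^N
K = 5   # K class problem

rng = np.random.default_rng(0)
logits = rng.normal(size=(N, K))
predictions = np.exp(logits) / np.exp(logits).sum(axis=-1, keepdims=True)  # each row sums to 1
references = rng.integers(0, K, size=N)  # ground-truth class indices
# Passed to the metric, these yield a dictionary such as {"ECE": 0.64}.
```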
local_app.py CHANGED
@@ -11,11 +11,12 @@ from ece import ECE # loads local instead
 
 
  import matplotlib.pyplot as plt
+ import matplotlib.patches as mpatches
 
  """
  import seaborn as sns
  sns.set_style('white')
- sns.set_context("paper", font_scale=1) # 2
+ sns.set_context("paper", font_scale=1)
  """
  # plt.rcParams['figure.figsize'] = [10, 7]
  plt.rcParams["figure.dpi"] = 300
@@ -61,6 +62,7 @@ metric = ECE()
  Switch inputs and compute_fn
  """
 
+
  def default_plot():
      fig = plt.figure()
      ax1 = plt.subplot2grid((3, 1), (0, 0), rowspan=2)
@@ -85,7 +87,25 @@ def default_plot():
      plt.tight_layout()
      return fig
 
+
+ def over_under_confidence(results):
+     colors = []
+     for j, bin in enumerate(results["y_bar"]):
+         perfect = results["y_bar"][j]
+         empirical = results["p_bar"][j]
+         bin_color = (
+             "limegreen"
+             if perfect == empirical
+             else "dodgerblue"
+             if empirical < perfect
+             else "orangered"
+         )
+         colors.append(bin_color)
+     return colors
+
+
  def reliability_plot(results):
+     # DEV: might still need to write tests in case of equal mass binning
      fig = plt.figure()
      ax1 = plt.subplot2grid((3, 1), (0, 0), rowspan=2)
      ax2 = plt.subplot2grid((3, 1), (2, 0))
@@ -96,36 +116,36 @@ def reliability_plot(results):
          results["y_bar"][-1],
      ] # np.linspace(0, 1, n_bins)
      # if upper edge then minus binsize; same for center [but half]
+     # rwidth is dependent on the binning
+     B, bins, patches = ax1.hist(
+         results["y_bar"], weights=results["p_bar"][:-1] # rwidth=len(results["p_bar"]/len(results["p_bar"]-1 )) #, range=(0,1),
+     ) # , rwidth=1, align="right") #
+     colors = over_under_confidence(results)
+     for b in range(len(B)):
+         patches[b].set_facecolor(colors[b]) # color based on over/underconfidence
 
      ranged = np.linspace(bin_range[0], bin_range[1], n_bins)
      ax1.plot(
          ranged,
          ranged,
-         color="darkgreen",
+         color="limegreen",
          ls="dotted",
          label="Perfect",
      )
-     # ax1.plot(results["y_bar"], results["y_bar"], color="darkblue", label="Perfect")
+     ax1handles = [
+         mpatches.Patch(color="orangered", label="Overconfident"),
+         mpatches.Patch(color="limegreen", label="Perfect", linestyle="dotted"),
+         mpatches.Patch(color="dodgerblue", label="Underconfident"),
+     ]
 
      anindices = np.where(~np.isnan(results["p_bar"][:-1]))[0]
      bin_freqs = np.zeros(n_bins)
      bin_freqs[anindices] = results["bin_freq"]
-     ax2.hist(results["y_bar"], results["y_bar"], weights=bin_freqs)
+     ax2.hist(results["y_bar"], weights=bin_freqs, color="midnightblue") # bins=results["y_bar"],
 
-     # widths = np.diff(results["y_bar"])
-     for j, bin in enumerate(results["y_bar"]):
-         perfect = results["y_bar"][j]
-         empirical = results["p_bar"][j]
+     # DEV: nicer would be to plot like a polygon
+     # see: https://github.com/markus93/fit-on-the-test/blob/main/Experiments_Synthetic/binnings.py
 
-         if np.isnan(empirical):
-             continue
-
-         # width=-ranged[j],
-         ax1.bar([perfect], height=[empirical], align="edge", color="lightblue")
-         """
-         if perfect == empirical:
-             continue
-         """
      acc_plt = ax2.axvline(x=results["accuracy"], ls="solid", lw=3, c="black", label="Accuracy")
      conf_plt = ax2.axvline(
          x=results["p_bar_cont"], ls="dotted", lw=3, c="#444", label="Avg. confidence"
@@ -134,14 +154,18 @@ def reliability_plot(results):
 
      # Bin differences
      ax1.set_ylabel("Conditional Expectation")
-     ax1.set_ylim([-0.05, 1.05]) # respective to bin range
-     ax1.legend(loc="lower right")
+     ax1.set_ylim([0, 1.05]) # respective to bin range
+     ax1.legend(loc="lower right", handles=ax1handles)
      ax1.set_title("Reliability Diagram")
+     # ax1.set_xticks([0]+results["y_bar"])
+     ax1.set_xlim([-0.05, 1.05]) # respective to bin range
 
      # Bin frequencies
      ax2.set_xlabel("Confidence")
      ax2.set_ylabel("Count")
      ax2.legend(loc="upper left") # , ncol=2
+     # ax2.set_xticks([0, ]+results["y_bar"])
+     ax2.set_xlim([-0.05, 1.05]) # respective to bin range
      plt.tight_layout()
      return fig
 
@@ -173,7 +197,7 @@ def compute_and_plot(data, n_bins, bin_range, scheme, proxy, p):
 
 
  outputs = [gr.outputs.Textbox(label="ECE"), gr.Plot(label="Reliability diagram")]
- #outputs[1].value = default_plot().__dict__
+ # outputs[1].value = default_plot().__dict__
 
  iface = gr.Interface(
      fn=compute_and_plot,
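
One way to address the bin-override concern in `reliability_plot` would be to hand `ax1.hist` the metric's own bin edges, mirroring the commented-out `bins=results["y_bar"]` hint on the `ax2.hist` line. A sketch under the assumption that `results["y_bar"]` holds per-bin confidence proxies and `results["p_bar"][:-1]` the matching conditional accuracies; this is not the committed code:

```python
import numpy as np

def hist_with_fixed_bins(ax, results, n_bins, bin_range=(0.0, 1.0)):
    # Explicit edges: one more edge than bins, so matplotlib cannot re-bin the data.
    edges = np.linspace(bin_range[0], bin_range[1], n_bins + 1)
    return ax.hist(results["y_bar"], bins=edges, weights=results["p_bar"][:-1])
```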
tests.py CHANGED
@@ -1,12 +1,17 @@
- import numpy as np
-
  test_cases = [
      {"predictions": [[0, 1], [1, 0]], "references": [1, 0], "result": {"ECE": 0}},
      {"predictions": [[0, 1], [1, 0]], "references": [0, 1], "result": {"ECE": 1}},
      {
-         "predictions": [[0, 0.1, 0.9], [0.2, 0.8, 0]],
-         "references": [2, 0], # kwargs?
-         "result": {"ECE": >0<1},
+         "predictions": [[0.6, 0.2, 0.2], [0, 0.95, 0.05], [0.75, 0.05, 0.2]],
+         "references": [0, 1, 2],
+         "result": {"ECE": ((abs((0==0)-0.7) + abs((1==1)-1) + abs((2==0)-0.8))/3)},
+         # all predictions in separate bins
+     },
+     {
+         "predictions": [[0.6, 0.2, 0.2], [0, 0.95, 0.05], [0.7, 0.1, 0.2]],
+         "references": [0, 1, 2],
+         "result": {"ECE": abs((0==0)-0.7 + (2==0)-0.7)/3 + abs((1==1)-1)/3},
+         # some predictions in same bin
      },
- ]
 
+ # DEV: make more advanced tests including differing kwargs
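
The expected values in the new cases follow the standard ECE definition, a bin-frequency-weighted average of |accuracy - confidence|. They line up if each bin is represented by its upper edge (the "upper edge" proxy also referenced in local_app.py) over equal-width bins; a hedged worked check of the first new case under that assumption:

```python
# argmax classes are 0, 1, 0 against references 0, 1, 2 -> correct, correct, wrong.
acc = [1.0, 1.0, 0.0]
# Max probabilities 0.6, 0.95, 0.75 land in separate bins whose upper edges are:
conf = [0.7, 1.0, 0.8]
ece = sum(abs(a - c) for a, c in zip(acc, conf)) / 3
print(ece)  # 0.3667, the value the test's "result" expression evaluates to
```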