jerome-white committed on
Commit
43baf6b
1 Parent(s): 0bb04ff

Rework of the comparison plot

Browse files
Files changed (1) hide show
  1. app.py +42 -39
app.py CHANGED
@@ -12,6 +12,7 @@ import seaborn as sns
12
  import matplotlib.pyplot as plt
13
  from datasets import load_dataset
14
  from scipy.special import expit
 
15
 
16
  from hdinterval import HDI, HDInterval
17
 
@@ -134,12 +135,6 @@ class RankPlotter(DataPlotter):
134
  class ComparisonPlotter(DataPlotter):
135
  _uncertain = 0.5
136
 
137
- @staticmethod
138
- def to_relative(hdi, ax):
139
- (lhs, rhs) = ax.get_xlim()
140
- length = rhs - lhs
141
- yield from (abs(lhs - x) / length for x in hdi)
142
-
143
  def __init__(self, df, model_1, model_2, ci):
144
  super().__init__(compare(df, model_1, model_2))
145
  self.interval = HDInterval(self.df)
@@ -147,41 +142,49 @@ class ComparisonPlotter(DataPlotter):
147
 
148
  def draw(self, ax):
149
  hdi = self.interval(self.ci)
 
150
 
151
- ax = sns.histplot(self.df, stat='density')
152
-
153
- top = max(x.get_height() for x in ax.patches)
154
- y = top * 1.05
155
-
156
- (_, color, *_) = sns.color_palette()
157
- (xmin, xmax) = self.to_relative(hdi, ax)
158
- linestyle = 'dashed' if self._uncertain in hdi else 'solid'
159
- ax.axhline(y=y,
160
- xmin=xmin,
161
- xmax=xmax,
162
- linestyle=linestyle,
163
- color=color)
164
- ax.set_xlabel('Pr(M$_{1}$ \u003E M$_{2}$)')
165
-
166
- x = (hdi.lower + hdi.upper) / 2
167
- ax.text(x=x,
168
- y=y,
169
- s=f'{self.ci:.0%} HDI',
170
- backgroundcolor='white',
171
- horizontalalignment='center',
172
- verticalalignment='center')
173
 
 
174
  try:
175
  ci_min = self.interval.at(self._uncertain)
176
- ax.text(x=0.01,
177
- y=0.99,
178
- s=f'0.5 \u2248\u2208 {ci_min:.0%} HDI',
 
 
179
  horizontalalignment='left',
180
  verticalalignment='top',
181
  transform=ax.transAxes)
182
  except ArithmeticError:
183
  pass
184
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
185
  #
186
  #
187
  #
@@ -258,17 +261,17 @@ def layout(tab):
258
 
259
  with gr.Row():
260
  with gr.Column():
261
- gr.Markdown('''
262
 
263
  Probability that Model 1 is preferred to Model 2. The
264
  histogram represents the distribution of the inverse
265
- logit of the difference in model abilities. The
266
- horizontal line above the histogram marks the chosen
267
- [highest density
268
- interval](https://cran.r-project.org/package=HDInterval). The
269
- line is dashed if the interval overlaps 0.5, solid
270
- otherwise. The HDI in the upper left denotes the
271
- smallest approximate HDI that is inclusive of 0.5.
272
 
273
  ''')
274
  with gr.Column():
 
12
  import matplotlib.pyplot as plt
13
  from datasets import load_dataset
14
  from scipy.special import expit
15
+ from matplotlib.ticker import FixedLocator, StrMethodFormatter
16
 
17
  from hdinterval import HDI, HDInterval
18
 
 
135
  class ComparisonPlotter(DataPlotter):
136
  _uncertain = 0.5
137
 
 
 
 
 
 
 
138
  def __init__(self, df, model_1, model_2, ci):
139
  super().__init__(compare(df, model_1, model_2))
140
  self.interval = HDInterval(self.df)
 
142
 
143
  def draw(self, ax):
144
  hdi = self.interval(self.ci)
145
+ (c_hist, c_hdi) = sns.color_palette('colorblind', n_colors=2)
146
 
147
+ ax = sns.histplot(data=self.df,
148
+ stat='density',
149
+ color=c_hist)
150
+ ax.set_xlabel('logit$^{-1}$(\u03B1$_{1}$ - \u03B1$_{2}$)')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
151
 
152
+ self.pr(ax, hdi, c_hdi)
153
  try:
154
  ci_min = self.interval.at(self._uncertain)
155
+ ax.text(x=0.025,
156
+ y=0.975,
157
+ s=f'{self._uncertain} \u2208 {ci_min:.0%} HDI',
158
+ fontsize='small',
159
+ fontstyle='italic',
160
  horizontalalignment='left',
161
  verticalalignment='top',
162
  transform=ax.transAxes)
163
  except ArithmeticError:
164
  pass
165
 
166
+ def pr(self, ax, hdi, color):
167
+ x = self.df.median()
168
+ zorder = ax.zorder - 1
169
+
170
+ (label, *_) = ax.get_xticklabels()
171
+ parts = label.get_text().split('.')
172
+ decimals = len(parts[-1]) + 1 if parts else 2
173
+ fmt = f'Pr(M$_{{{{1}}}}$ \u003E M$_{{{{2}}}}$) = {{x:.{decimals}f}}'
174
+
175
+ ax.axvline(x=x,
176
+ color=color,
177
+ linestyle='dashed')
178
+ ax.axvspan(xmin=hdi.lower,
179
+ xmax=hdi.upper,
180
+ alpha=0.15,
181
+ color=color,
182
+ zorder=zorder)
183
+
184
+ ax_ = ax.secondary_xaxis('top')
185
+ ax_.xaxis.set_major_locator(FixedLocator([x]))
186
+ ax_.xaxis.set_major_formatter(StrMethodFormatter(fmt))
187
+
188
  #
189
  #
190
  #
 
261
 
262
  with gr.Row():
263
  with gr.Column():
264
+ gr.Markdown(f'''
265
 
266
  Probability that Model 1 is preferred to Model 2. The
267
  histogram represents the distribution of the inverse
268
+ logit of the difference in model abilities. The dashed
269
+ vertical line is its median. The shaded region
270
+ demarcates the chosen [highest density
271
+ interval](https://cran.r-project.org/package=HDInterval)
272
+ (HDI). The note in the upper left denotes the smallest
273
+ HDI that is inclusive of
274
+ {ComparisonPlotter._uncertain}.
275
 
276
  ''')
277
  with gr.Column():