jerome-white commited on
Commit
75b2724
1 Parent(s): db4a8ee

Lift HDI calculation to its own module

Browse files

New functionality from HDI allows the calculation of the smallest
interval that excludes a given value. Its added complexity adds code,
which makes putting it in its own module cleaner.

Files changed (2) hide show
  1. app.py +14 -27
  2. hdinterval.py +87 -0
app.py CHANGED
@@ -4,6 +4,7 @@ import itertools as it
4
  import functools as ft
5
  import collections as cl
6
  from pathlib import Path
 
7
 
8
  import pandas as pd
9
  import gradio as gr
@@ -12,27 +13,9 @@ import matplotlib.pyplot as plt
12
  from datasets import load_dataset
13
  from scipy.special import expit
14
 
15
- HDI = cl.namedtuple('HDI', 'lower, upper')
16
- TabGroup = cl.namedtuple('TabGroup', 'name, docs, dataset')
17
-
18
- #
19
- # See https://cran.r-project.org/package=HDInterval
20
- #
21
- def hdi(values, ci=0.95):
22
- values = sorted(filter(math.isfinite, values))
23
- if not values:
24
- raise ValueError('Empty data set')
25
-
26
- n = len(values)
27
- exclude = n - math.floor(n * ci)
28
 
29
- left = it.islice(values, exclude)
30
- right = it.islice(values, n - exclude, None)
31
-
32
- diffs = ((x, y, y - x) for (x, y) in zip(left, right))
33
- (*args, _) = min(diffs, key=op.itemgetter(-1))
34
-
35
- return HDI(*args)
36
 
37
  #
38
  #
@@ -60,14 +43,15 @@ def load(repo):
60
  def summarize(df, ci=0.95):
61
  def _aggregate(i, g):
62
  values = g['value']
63
- interval = hdi(values, ci)
 
64
 
65
  agg = {
66
  'model': i,
67
  'ability': values.median(),
68
- 'uncertainty': interval.upper - interval.lower,
69
  }
70
- agg.update(interval._asdict())
71
 
72
  return agg
73
 
@@ -150,17 +134,20 @@ class RankPlotter(DataPlotter):
150
  class ComparisonPlotter(DataPlotter):
151
  def __init__(self, df, model_1, model_2, ci=0.95):
152
  super().__init__(compare(df, model_1, model_2))
153
- self.interval = hdi(self.df, ci)
 
154
 
155
  def draw(self, ax):
 
 
156
  sns.ecdfplot(self.df, ax=ax)
157
 
158
  (_, color, *_) = sns.color_palette()
159
  ax.axvline(x=self.df.median(),
160
  color=color,
161
  linestyle='dashed')
162
- ax.axvspan(xmin=self.interval.lower,
163
- xmax=self.interval.upper,
164
  alpha=0.15,
165
  color=color)
166
  ax.set_xlabel('Pr(M$_{1}$ \u003E M$_{2}$)')
@@ -205,7 +192,7 @@ def layout(tab):
205
 
206
  with gr.Row():
207
  view = rank(summarize(df), False)
208
- columns = { x: f'HDI {x}' for x in HDI._fields }
209
  for i in view.columns:
210
  columns.setdefault(i, i.title())
211
  view = (view
 
4
  import functools as ft
5
  import collections as cl
6
  from pathlib import Path
7
+ from dataclasses import fields, asdict
8
 
9
  import pandas as pd
10
  import gradio as gr
 
13
  from datasets import load_dataset
14
  from scipy.special import expit
15
 
16
+ from hdinterval import HDI, HDInterval
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
+ TabGroup = cl.namedtuple('TabGroup', 'name, docs, dataset')
 
 
 
 
 
 
19
 
20
  #
21
  #
 
43
  def summarize(df, ci=0.95):
44
  def _aggregate(i, g):
45
  values = g['value']
46
+ hdi = HDInterval(values)
47
+ interval = hdi(ci)
48
 
49
  agg = {
50
  'model': i,
51
  'ability': values.median(),
52
+ 'uncertainty': interval.width(),
53
  }
54
+ agg.update(asdict(interval))
55
 
56
  return agg
57
 
 
134
  class ComparisonPlotter(DataPlotter):
135
  def __init__(self, df, model_1, model_2, ci=0.95):
136
  super().__init__(compare(df, model_1, model_2))
137
+ self.hdi = HDInterval(self.df)
138
+ self.ci = ci
139
 
140
  def draw(self, ax):
141
+ interval = self.hdi(self.ci)
142
+
143
  sns.ecdfplot(self.df, ax=ax)
144
 
145
  (_, color, *_) = sns.color_palette()
146
  ax.axvline(x=self.df.median(),
147
  color=color,
148
  linestyle='dashed')
149
+ ax.axvspan(xmin=interval.lower,
150
+ xmax=interval.upper,
151
  alpha=0.15,
152
  color=color)
153
  ax.set_xlabel('Pr(M$_{1}$ \u003E M$_{2}$)')
 
192
 
193
  with gr.Row():
194
  view = rank(summarize(df), False)
195
+ columns = { x.name: f'HDI {x.name}' for x in fields(HDI) }
196
  for i in view.columns:
197
  columns.setdefault(i, i.title())
198
  view = (view
hdinterval.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import warnings
3
+ import operator as op
4
+ import itertools as it
5
+ import functools as ft
6
+ import statistics as st
7
+ from dataclasses import dataclass
8
+
9
+ @dataclass
10
+ class HDI:
11
+ lower: float
12
+ upper: float
13
+
14
+ def __iter__(self):
15
+ yield from (self.lower, self.upper)
16
+
17
+ def __contains__(self, item):
18
+ return self.lower <= item <= self.upper
19
+
20
+ def width(self):
21
+ return self.upper - self.lower
22
+
23
+ class HDInterval:
24
+ @ft.cached_property
25
+ def values(self):
26
+ view = sorted(filter(math.isfinite, self._values))
27
+ if not view:
28
+ raise AttributeError('Empty data set')
29
+
30
+ return view
31
+
32
+ def __init__(self, values):
33
+ self._values = values
34
+
35
+ #
36
+ # See https://cran.r-project.org/package=HDInterval
37
+ #
38
+ def __call__(self, ci=0.95):
39
+ if ci == 1:
40
+ args = (self.values[x] for x in (0, -1))
41
+ else:
42
+ n = len(self.values)
43
+ exclude = n - math.floor(n * ci)
44
+
45
+ left = it.islice(self.values, exclude)
46
+ right = it.islice(self.values, n - exclude, None)
47
+
48
+ diffs = ((x, y, y - x) for (x, y) in zip(left, right))
49
+ (*args, _) = min(diffs, key=op.itemgetter(-1))
50
+
51
+ return HDI(*args)
52
+
53
+ def _at(self, target, tolerance, ci, jump):
54
+ if ci > 1:
55
+ return 1
56
+
57
+ interval = self(ci)
58
+ if any(math.isclose(x, target, abs_tol=tolerance) for x in interval):
59
+ return ci
60
+
61
+ plus_minus = op.sub if target in interval else op.add
62
+ ci = plus_minus(ci, jump)
63
+ jump /= 2
64
+
65
+ return self._at(target, tolerance, ci, jump)
66
+
67
+ def at(self, target, tolerance=1e-3):
68
+ while tolerance < 1:
69
+ try:
70
+ return self._at(target, tolerance, 1, 1)
71
+ except RecursionError:
72
+ tolerance *= 10
73
+ warnings.warn(f'Tolerance reduced: {tolerance}')
74
+
75
+ raise OverflowError()
76
+
77
+ if __name__ == '__main__':
78
+ import numpy as np
79
+
80
+ data = np.random.uniform(size=2000)
81
+ # data = list(filter(lambda x: x > 0.7, data))
82
+ # data = [0.5] * 10
83
+
84
+ interval = HDInterval(data)
85
+ point = interval.at(0.5)
86
+ hdi = interval(point)
87
+ print(point, hdi)