# llm-bradley-terry / hdinterval.py
# Commit 75b2724 (jerome-white): "Lift HDI calculation to its own module"
import math
import warnings
import operator as op
import itertools as it
import functools as ft
import statistics as st
from dataclasses import dataclass
@dataclass
class HDI:
    """A closed interval ``[lower, upper]`` produced by :class:`HDInterval`.

    Iterating yields ``lower`` then ``upper``, so an instance unpacks as
    ``lo, hi = hdi``.
    """

    lower: float
    upper: float

    def __iter__(self):
        yield from (self.lower, self.upper)

    def __contains__(self, item):
        # Closed interval: both endpoints count as inside.
        return self.lower <= item <= self.upper

    def width(self):
        """Return ``upper - lower``."""
        return self.upper - self.lower


class HDInterval:
    """Highest density interval (HDI) of a sample.

    Port of ``hdi.default`` from the R package HDInterval
    (https://cran.r-project.org/package=HDInterval): for a level ``ci`` it
    picks, among windows of consecutive sorted sample values, the narrowest
    one whose endpoints are ``floor(n * ci)`` positions apart.

    Args:
        values: iterable of numbers. Non-finite entries (NaN, +/-inf) are
            dropped. The iterable is consumed lazily, on first use.
    """

    def __init__(self, values):
        # Stored as given; filtering and sorting are deferred to ``values``.
        self._values = values

    @ft.cached_property
    def values(self):
        """Finite input values sorted ascending (computed once, then cached).

        Raises:
            AttributeError: if the input contains no finite values. This
                type is kept for backward compatibility; note that an
                AttributeError escaping a property can look like a missing
                attribute.
        """
        view = sorted(filter(math.isfinite, self._values))
        if not view:
            raise AttributeError('Empty data set')
        return view

    def __call__(self, ci=0.95):
        """Return the HDI at credibility level ``ci``.

        Args:
            ci: level in ``[0, 1]``. ``1`` gives the full data range; ``0``
                gives a zero-width interval at the first zero-width
                candidate (the smallest value).

        Returns:
            HDI: the narrowest candidate interval; ties go to the lowest
            one, like R's ``which.min``.

        Raises:
            ValueError: if ``ci`` is outside ``[0, 1]`` or is NaN.
            AttributeError: if there are no finite values (see ``values``).
        """
        # Written so that NaN also fails the check.
        if not 0 <= ci <= 1:
            raise ValueError(f'ci must be in [0, 1], got {ci!r}')

        values = self.values
        n = len(values)
        # Number of candidate (lower, upper) pairs -- R's ``exclude``.
        exclude = n - math.floor(n * ci)
        if exclude == 0:
            # ci == 1 leaves no candidate pairs; R falls back to range(x).
            return HDI(values[0], values[-1])

        # Candidate i spans values[i] .. values[i + n - exclude].
        left = it.islice(values, exclude)
        right = it.islice(values, n - exclude, None)
        diffs = ((x, y, y - x) for (x, y) in zip(left, right))
        lower, upper, _ = min(diffs, key=op.itemgetter(-1))
        return HDI(lower, upper)

    def _at(self, target, tolerance, ci, jump):
        """Bisect on ``ci`` until an HDI endpoint is near ``target``.

        At each step ``ci`` moves down by ``jump`` if ``target`` lies inside
        the current interval, and up otherwise; ``jump`` then halves.
        The search is iterative, so it does not depend on the recursion limit.

        Returns:
            The first level whose interval has an endpoint within
            ``tolerance`` of ``target``; ``1`` once the level exceeds 1
            (``target`` lies outside the data range); ``None`` when the
            search stalls because ``jump`` can no longer change ``ci`` at
            float precision.
        """
        while True:
            if ci > 1:
                return 1
            interval = self(ci)
            if any(math.isclose(x, target, abs_tol=tolerance) for x in interval):
                return ci
            plus_minus = op.sub if target in interval else op.add
            step = plus_minus(ci, jump)
            if step == ci:
                # A smaller jump cannot move ci either, and self(ci) is
                # unchanged, so no later step can succeed.
                return None
            ci = step
            jump /= 2

    def at(self, target, tolerance=1e-3):
        """Return the level ``ci`` whose HDI has an endpoint near ``target``.

        If the search fails, it is retried with a tolerance ten times larger,
        with a warning each time, until it succeeds or the tolerance reaches 1.

        Returns:
            ``ci`` such that an endpoint of ``self(ci)`` is within the
            (possibly loosened) tolerance of ``target``. Returns ``1`` when
            ``target`` is not inside the data range.

        Raises:
            OverflowError: if no level is found for any tolerance below 1.
            AttributeError: if there are no finite values.
        """
        while tolerance < 1:
            ci = self._at(target, tolerance, 1, 1)
            if ci is not None:
                return ci
            tolerance *= 10
            # Message kept verbatim; the tolerance has actually been loosened.
            warnings.warn(f'Tolerance reduced: {tolerance}')
        raise OverflowError(f'No HDI endpoint found near {target!r}')
if __name__ == '__main__':
    # Demo: find the level whose HDI has an endpoint near 0.5 for a uniform
    # sample, then print that level and the corresponding interval.
    # Other inputs worth trying in place of the uniform sample:
    #   sample = list(filter(lambda x: x > 0.7, sample))
    #   sample = [0.5] * 10
    import numpy as np

    sample = np.random.uniform(size=2000)
    hdinterval = HDInterval(sample)
    level = hdinterval.at(0.5)
    print(level, hdinterval(level))