Spaces:
Sleeping
Sleeping
File size: 2,212 Bytes
75b2724 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 |
import math
import warnings
import operator as op
import itertools as it
import functools as ft
import statistics as st
from dataclasses import dataclass
@dataclass
class HDI:
    """Closed interval ``[lower, upper]`` as produced by ``HDInterval``.

    Instances unpack as ``lower, upper = hdi`` and support ``x in hdi``.
    """

    lower: float  # left endpoint
    upper: float  # right endpoint

    def __iter__(self):
        """Yield the endpoints, lower bound first."""
        for bound in (self.lower, self.upper):
            yield bound

    def __contains__(self, item):
        """Return whether *item* lies inside the closed interval."""
        return self.lower <= item and item <= self.upper

    def width(self):
        """Return the interval length ``upper - lower``."""
        lo, hi = self.lower, self.upper
        return hi - lo
class HDInterval:
    """Highest-density interval (HDI) estimator for a sample of values.

    Follows the sample-based algorithm of the R package HDInterval
    (https://cran.r-project.org/package=HDInterval): among all windows of
    consecutive sorted values that hold a given share of the sample, the
    narrowest one is returned.
    """

    def __init__(self, values):
        # Raw sample; it is filtered and sorted lazily by ``values``.
        self._values = values

    @ft.cached_property
    def values(self):
        """Sorted list of the finite entries of the sample, computed once.

        Raises:
            AttributeError: if the sample contains no finite values.
        """
        view = sorted(filter(math.isfinite, self._values))
        if not view:
            # NOTE(review): an AttributeError raised from a property looks
            # like a missing attribute (``hasattr`` returns False); kept
            # because existing callers may already catch this type.
            raise AttributeError('Empty data set')
        return view

    def __call__(self, ci=0.95):
        """Return the narrowest interval holding a fraction *ci* of the sample.

        Args:
            ci: Credible mass, within ``[0, 1]``.

        Returns:
            HDI whose endpoints are values taken from the sample.

        Raises:
            ValueError: if *ci* is outside ``[0, 1]`` or is NaN.
            AttributeError: if the sample contains no finite values.
        """
        if not 0 <= ci <= 1:
            raise ValueError(f'ci must be within [0, 1], got {ci!r}')
        values = self.values
        n = len(values)
        # Candidate window i spans values[i] .. values[i + n - exclude]
        # (floor(n * ci) + 1 points), for i in range(exclude).
        exclude = n - math.floor(n * ci)
        if exclude == 0:
            # Full mass (ci == 1, or n * ci rounded up to n): whole range.
            return HDI(values[0], values[-1])
        lows = values[:exclude]
        highs = values[n - exclude:]
        # min() keeps the first of several equally narrow windows.
        lower, upper = min(zip(lows, highs), key=lambda pair: pair[1] - pair[0])
        return HDI(lower, upper)

    def _at(self, target, tolerance, ci, jump):
        """Bisect on the credible mass until an HDI endpoint is near *target*.

        Starting from mass *ci* with step *jump*, the mass is decreased when
        *target* lies inside the current interval and increased otherwise;
        the step is halved after every move.

        Returns:
            The mass whose interval has an endpoint within *tolerance* of
            *target*; ``1`` once the mass exceeds 1 (target outside the
            sample's full interval); or ``None`` when the step has become too
            small to change the mass, so the search cannot progress further
            at this tolerance.
        """
        # A loop rather than recursion, so non-convergence is detected
        # directly instead of by exhausting the interpreter's recursion limit.
        while True:
            if ci > 1:
                return 1
            interval = self(ci)
            if any(math.isclose(x, target, abs_tol=tolerance) for x in interval):
                return ci
            plus_minus = op.sub if target in interval else op.add
            next_ci = plus_minus(ci, jump)
            if next_ci == ci:
                # The step is below float resolution: every further
                # iteration would re-test this same interval forever.
                return None
            ci = next_ci
            jump /= 2

    def at(self, target, tolerance=1e-3):
        """Find the credible mass whose HDI has an endpoint at *target*.

        When the search does not converge, it is retried with a tolerance ten
        times larger (emitting a warning) for as long as the tolerance is
        below 1.

        Returns:
            Credible mass in ``[0, 1]``; ``1`` if *target* lies outside the
            interval spanned by the whole sample.

        Raises:
            OverflowError: if no tolerance below 1 leads to convergence.
            AttributeError: if the sample contains no finite values.
        """
        while tolerance < 1:
            ci = self._at(target, tolerance, 1, 1)
            if ci is not None:
                return ci
            tolerance *= 10
            if tolerance < 1:
                warnings.warn(f'Tolerance increased: {tolerance}', stacklevel=2)
        raise OverflowError(f'HDI search did not converge for target {target!r}')
if __name__ == '__main__':
    # Demo: find the credible mass whose HDI has an endpoint at 0.5 for a
    # uniform random sample, then print that mass and its interval.
    import numpy as np

    sample = np.random.uniform(size=2000)
    # Alternative inputs for manual experimentation:
    #   sample = list(filter(lambda x: x > 0.7, sample))
    #   sample = [0.5] * 10
    estimator = HDInterval(sample)
    mass = estimator.at(0.5)
    print(mass, estimator(mass))
|