File size: 2,212 Bytes
75b2724
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import math
import warnings
import operator as op
import itertools as it
import functools as ft
import statistics as st
from dataclasses import dataclass

@dataclass
class HDI:
    """Closed interval ``[lower, upper]`` produced by the HDI computation.

    Iterating an instance gives the lower bound followed by the upper
    bound, and ``x in hdi`` tests whether ``x`` falls inside the bounds
    (both ends inclusive).
    """

    lower: float
    upper: float

    def __iter__(self):
        """Yield the lower bound, then the upper bound."""
        yield self.lower
        yield self.upper

    def __contains__(self, item):
        """Return True when *item* lies between the bounds, inclusive."""
        return self.lower <= item and item <= self.upper

    def width(self):
        """Return the distance between the upper and the lower bound."""
        (low, high) = self
        return high - low

class HDInterval:
    """Highest density interval (HDI) estimator over a sample of numbers.

    The sample is stored as given; non-finite entries are dropped and the
    rest is sorted the first time ``values`` is read.  Calling the instance
    with a credible mass returns the narrowest ``HDI`` spanning that share
    of the sorted sample, and ``at`` searches for a mass whose interval has
    an endpoint at a requested target.
    """

    @ft.cached_property
    def values(self):
        """Sorted list of the finite sample values, computed once.

        Raises:
            AttributeError: If no finite value remains after filtering.
        """
        view = sorted(filter(math.isfinite, self._values))
        if not view:
            raise AttributeError('Empty data set')

        return view

    def __init__(self, values):
        """Store the raw sample; filtering and sorting happen lazily."""
        self._values = values

    #
    # See https://cran.r-project.org/package=HDInterval
    #
    def __call__(self, ci=0.95):
        """Return the narrowest interval covering a share *ci* of the sample.

        Args:
            ci: Credible mass within ``[0, 1]``; ``1`` gives the full range
                of the sample.

        Returns:
            An ``HDI`` holding the chosen lower and upper sample values.
            Ties between equally narrow candidates go to the leftmost one.

        Raises:
            AttributeError: If the sample holds no finite value.
            ValueError: If *ci* is outside ``[0, 1]`` (NaN included).
        """
        values = self.values
        # Out-of-range masses used to surface as an obscure islice/floor
        # ValueError; fail with the same exception type but a clear message.
        if not 0 <= ci <= 1:
            raise ValueError(f'ci must be within [0, 1]: {ci}')
        if ci == 1:
            return HDI(values[0], values[-1])

        n = len(values)
        # Number of candidate windows: every window keeps floor(n * ci)
        # points between its two ends.
        exclude = n - math.floor(n * ci)

        left = it.islice(values, exclude)
        right = it.islice(values, n - exclude, None)

        diffs = ((x, y, y - x) for (x, y) in zip(left, right))
        (*args, _) = min(diffs, key=op.itemgetter(-1))

        return HDI(*args)

    def _at(self, target, tolerance, ci, jump):
        """Bisect the credible mass until an interval endpoint nears *target*.

        Starting at *ci*, the mass is lowered by *jump* while *target* lies
        inside the interval and raised otherwise; *jump* halves each step.

        Returns:
            The mass whose interval has an endpoint within *tolerance* of
            *target*; ``1`` as soon as the search moves above a mass of 1;
            ``None`` when the search stalls without finding such a mass.
        """
        while True:
            if ci > 1:
                return 1

            interval = self(ci)
            if any(math.isclose(x, target, abs_tol=tolerance) for x in interval):
                return ci

            plus_minus = op.sub if target in interval else op.add
            moved = plus_minus(ci, jump)
            if moved == ci:
                # The step no longer changes the mass.  The same mass gives
                # the same interval and hence the same direction with an
                # even smaller step, so no later iteration can succeed.
                return None

            ci = moved
            jump /= 2

    def at(self, target, tolerance=1e-3):
        """Find a credible mass whose interval has an endpoint at *target*.

        The search starts with the given absolute *tolerance*.  Every time
        it fails, the tolerance is multiplied by ten and a warning is
        emitted before the next attempt.

        Args:
            target: Value an interval endpoint should match.
            tolerance: Initial absolute tolerance for the endpoint match.

        Returns:
            The credible mass found by the search.

        Raises:
            OverflowError: If no mass is found before the tolerance reaches
                1 (this includes an initial *tolerance* of 1 or more).
            AttributeError: If the sample holds no finite value.
        """
        while tolerance < 1:
            found = self._at(target, tolerance, 1, 1)
            if found is not None:
                return found

            tolerance *= 10
            warnings.warn(f'Tolerance reduced: {tolerance}')

        raise OverflowError()

if __name__ == '__main__':
    import numpy as np

    # Demo: draw 2000 uniform samples, look up the credible mass whose
    # interval has an endpoint near 0.5, then print that mass together
    # with the interval it produces.
    samples = np.random.uniform(size=2000)
    # Alternative inputs to experiment with:
    # samples = list(filter(lambda x: x > 0.7, samples))
    # samples = [0.5] * 10

    estimator = HDInterval(samples)
    mass = estimator.at(0.5)
    print(mass, estimator(mass))