File size: 3,097 Bytes
079c32c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import numpy as np
import pytest

import ding
ding.enable_numba = False  # noqa
from ding.utils import SumSegmentTree, MinSegmentTree  # noqa


@pytest.mark.unittest
class TestSumSegmentTree:
    """Unit tests for ``SumSegmentTree``: construction, item access, range sums and prefix-sum search."""

    def test_create(self):
        """A non-power-of-two capacity is rejected; a fresh tree holds only the neutral element 0."""
        # The constructor must fail its assertion for capacity 13; the result is never
        # used, so it is deliberately not bound to a name.
        with pytest.raises(AssertionError):
            SumSegmentTree(capacity=13)

        tree = SumSegmentTree(capacity=16)
        assert tree.operation == 'sum'
        assert tree.neutral_element == 0.
        # Every slot of the backing storage starts at the neutral element.
        assert max(tree.value) == 0.
        assert min(tree.value) == 0.

    def test_set_get_item(self):
        """Written leaves read back unchanged and ``reduce`` sums any half-open range ``[start, end)``."""
        tree = SumSegmentTree(capacity=4)
        elements = [1, 5, 4, 7]
        get_result = []
        for idx, val in enumerate(elements):
            tree[idx] = val
            get_result.append(tree[idx])

        assert elements == get_result
        # With no arguments, reduce covers the whole tree.
        assert tree.reduce() == sum(elements)
        # Check every non-empty range, including those whose end equals the capacity.
        for start in range(len(elements)):
            for end in range(start + 1, len(elements) + 1):
                assert tree.reduce(start, end) == sum(elements[start:end]), (start, end)

        # An empty range (start == end) is rejected.
        with pytest.raises(AssertionError):
            tree.reduce(2, 2)

    def test_find_prefixsum_idx(self):
        """``find_prefixsum_idx`` maps a cumulative value to the leaf whose span contains it."""
        tree = SumSegmentTree(capacity=8)
        elements = [0, 0.1, 0.5, 0, 0, 0.2, 0.8, 0]
        for idx, val in enumerate(elements):
            tree[idx] = val

        # With trust_caller=False, values outside [0, total] are rejected.
        with pytest.raises(AssertionError):
            tree.find_prefixsum_idx(tree.reduce() + 1e-4, trust_caller=False)
        with pytest.raises(AssertionError):
            tree.find_prefixsum_idx(-1e-6, trust_caller=False)

        # (query value, expected leaf index). Zero-weight leaves (0, 3, 4, 7) are never
        # returned: e.g. 0 maps to leaf 1, 0.6 skips leaves 3 and 4, and the full total
        # maps to leaf 6 rather than the trailing zero leaf 7.
        cases = [
            (0, 1),
            (0.09, 1),
            (0.1, 2),
            (0.59, 2),
            (0.6, 5),
            (0.799, 5),
            (0.8, 6),
            (tree.reduce(), 6),
        ]
        for prefixsum, expected_idx in cases:
            assert tree.find_prefixsum_idx(prefixsum) == expected_idx, prefixsum


@pytest.mark.unittest
class TestMinSegmentTree:
    """Unit tests for ``MinSegmentTree``: construction, item access and range minima."""

    def test_create(self):
        """A non-power-of-two capacity is rejected; a fresh tree holds only the neutral element +inf."""
        # Like SumSegmentTree, a non-power-of-two capacity is expected to be rejected.
        with pytest.raises(AssertionError):
            MinSegmentTree(capacity=13)

        tree = MinSegmentTree(capacity=16)
        assert tree.operation == 'min'
        assert tree.neutral_element == np.inf
        # Every slot of the backing storage starts at the neutral element.
        assert max(tree.value) == np.inf
        assert min(tree.value) == np.inf

    def test_set_get_item(self):
        """Written leaves read back unchanged and ``reduce`` takes the minimum over ``[start, end)``."""
        tree = MinSegmentTree(capacity=4)
        # -10 is below every other element, so any range containing index 1 must report it.
        elements = [1, -10, 10, 7]
        get_result = []
        for idx, val in enumerate(elements):
            tree[idx] = val
            get_result.append(tree[idx])

        assert elements == get_result
        # With no arguments, reduce covers the whole tree.
        assert tree.reduce() == min(elements)
        # Check every non-empty range, including those whose end equals the capacity.
        for start in range(len(elements)):
            for end in range(start + 1, len(elements) + 1):
                assert tree.reduce(start, end) == min(elements[start:end]), (start, end)

        # An empty range (start == end) is expected to be rejected, as for SumSegmentTree.
        with pytest.raises(AssertionError):
            tree.reduce(2, 2)