File size: 3,097 Bytes
079c32c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 |
import numpy as np
import pytest
import ding
ding.enable_numba = False # noqa
from ding.utils import SumSegmentTree, MinSegmentTree # noqa
@pytest.mark.unittest
class TestSumSegmentTree:
    """Unit tests for ``SumSegmentTree``, the range-sum segment tree from ``ding.utils``."""

    def test_create(self):
        """Construction: capacity validation and the neutral (all-zero) initial state."""
        # Capacity 13 must be rejected. Every capacity this file accepts
        # (4, 8, 16) is a power of two -- presumably a requirement of the tree.
        with pytest.raises(AssertionError):
            SumSegmentTree(capacity=13)
        tree = SumSegmentTree(capacity=16)
        assert tree.operation == 'sum'
        assert tree.neutral_element == 0.
        # A freshly built tree holds nothing but the neutral element.
        assert all(v == 0. for v in tree.value)

    def test_set_get_item(self):
        """Item assignment/lookup and ``reduce`` over every valid half-open range."""
        tree = SumSegmentTree(capacity=4)
        elements = [1, 5, 4, 7]
        for idx, val in enumerate(elements):
            tree[idx] = val
        # Read back only after all writes, so a later write that clobbered an
        # earlier leaf would also be caught.
        assert [tree[idx] for idx in range(len(elements))] == elements
        # reduce() with no arguments covers the whole tree.
        assert tree.reduce() == sum(elements)
        # reduce(start, end) is half-open like a slice. Check every non-empty
        # range, including those that end exactly at the capacity.
        for start in range(len(elements)):
            for end in range(start + 1, len(elements) + 1):
                assert tree.reduce(start, end) == sum(elements[start:end])
        # An empty range is rejected.
        with pytest.raises(AssertionError):
            tree.reduce(2, 2)

    def test_find_prefixsum_idx(self):
        """``find_prefixsum_idx`` maps a prefix sum to the leaf whose mass contains it."""
        tree = SumSegmentTree(capacity=8)
        # Leaves 0, 3, 4 and 7 carry zero mass; none of the queries below may land on them.
        elements = [0, 0.1, 0.5, 0, 0, 0.2, 0.8, 0]
        for idx, val in enumerate(elements):
            tree[idx] = val
        # With trust_caller=False, queries outside [0, total] are rejected.
        with pytest.raises(AssertionError):
            tree.find_prefixsum_idx(tree.reduce() + 1e-4, trust_caller=False)
        with pytest.raises(AssertionError):
            tree.find_prefixsum_idx(-1e-6, trust_caller=False)
        # Each query lands on the leaf whose cumulative interval contains it:
        # leaf 1 -> [0, 0.1), leaf 2 -> [0.1, 0.6), leaf 5 -> [0.6, 0.8),
        # leaf 6 -> [0.8, total].
        cases = [
            (0, 1),
            (0.09, 1),
            (0.1, 2),
            (0.59, 2),
            (0.6, 5),
            (0.799, 5),
            (0.8, 6),
        ]
        for prefixsum, expected_idx in cases:
            assert tree.find_prefixsum_idx(prefixsum) == expected_idx
        # Querying exactly the total lands on the last leaf with mass (6),
        # not on the trailing zero-mass leaf 7.
        assert tree.find_prefixsum_idx(tree.reduce()) == 6


@pytest.mark.unittest
class TestMinSegmentTree:
    """Unit tests for ``MinSegmentTree``, the range-minimum segment tree from ``ding.utils``."""

    def test_create(self):
        """Construction: capacity validation and the neutral (all-``inf``) initial state."""
        # Same capacity check as the sum tree: 13 is expected to be rejected
        # (assumes both trees share the same capacity validation).
        with pytest.raises(AssertionError):
            MinSegmentTree(capacity=13)
        tree = MinSegmentTree(capacity=16)
        assert tree.operation == 'min'
        assert tree.neutral_element == np.inf
        # A freshly built tree holds nothing but the neutral element.
        assert all(v == np.inf for v in tree.value)

    def test_set_get_item(self):
        """Item assignment/lookup and ``reduce`` over every valid half-open range."""
        tree = MinSegmentTree(capacity=4)
        # Includes a negative value so the minimum is not simply the first element.
        elements = [1, -10, 10, 7]
        for idx, val in enumerate(elements):
            tree[idx] = val
        # Read back only after all writes, so a later write that clobbered an
        # earlier leaf would also be caught.
        assert [tree[idx] for idx in range(len(elements))] == elements
        # reduce() with no arguments covers the whole tree.
        assert tree.reduce() == min(elements)
        # reduce(start, end) is half-open like a slice. Check every non-empty
        # range, including those that end exactly at the capacity.
        for start in range(len(elements)):
            for end in range(start + 1, len(elements) + 1):
                assert tree.reduce(start, end) == min(elements[start:end])
        # An empty range is expected to be rejected, as it is for the sum tree.
        with pytest.raises(AssertionError):
            tree.reduce(2, 2)
|