# Source: gomoku / DI-engine / ding/utils/segment_tree.py
# Commit 079c32c ("init space"), zjowowen
from functools import partial, lru_cache
from typing import Callable, Optional
import numpy as np
import ding
from .default_helper import one_time_warning
@lru_cache()
def njit():
    """
    Overview:
        Return the decorator used to compile the segment tree kernels. It is ``numba.njit`` when numba is \
        importable, enabled via ``ding.enable_numba`` and at least version 0.53; otherwise it is \
        ``functools.partial``, so ``partial(func)`` merely wraps the plain Python function and calls go \
        straight through to it. The result is cached by ``lru_cache``, so the checks and warnings run once.
    Returns:
        - decorator (:obj:`Callable`): ``numba.njit`` or ``functools.partial``.
    """
    try:
        if ding.enable_numba:
            import numba
            from numba import njit as _njit
            # Compare (major, minor) as a tuple: checking only the minor number would wrongly treat
            # e.g. numba 1.0 as older than 0.53.
            parts = numba.__version__.split(".")
            try:
                version = (int(parts[0]), int(parts[1]))
            except (IndexError, ValueError):
                # Unparseable version string: fall back to pure Python instead of crashing.
                version = (0, 0)
            if version < (0, 53):
                _njit = partial  # noqa
                one_time_warning(
                    "Due to your numba version < 0.53.0, DI-engine disables it. And you can install "
                    "numba>=0.53.0 if you want to speed up something"
                )
        else:
            _njit = partial
    except ImportError:
        one_time_warning("If you want to use numba to speed up segment tree, please install numba first")
        _njit = partial
    return _njit
class SegmentTree:
    """
    Overview:
        Segment tree data structure, implemented by the tree-like array. Only the leaf nodes hold real values;
        every non-leaf node stores ``operation`` applied to its left and right children.
    Interfaces:
        ``__init__``, ``reduce``, ``__setitem__``, ``__getitem__``
    """

    # Default neutral element for each supported operation name.
    _DEFAULT_NEUTRAL_ELEMENTS = {'sum': 0., 'min': np.inf, 'max': -np.inf}

    def __init__(self, capacity: int, operation: str, neutral_element: Optional[float] = None) -> None:
        """
        Overview:
            Initialize the segment tree. Tree's root node is at index 1.
        Arguments:
            - capacity (:obj:`int`): Capacity of the tree (the number of the leaf nodes), should be a power of 2.
            - operation (:obj:`str`): Name of the operation used to build the tree: ``'sum'``, ``'min'`` or \
                ``'max'``.
            - neutral_element (:obj:`float` or :obj:`None`): The value of the neutral element, which is used to \
                init all nodes value in the tree. If ``None``, it is derived from ``operation`` \
                (0 for sum, +inf for min, -inf for max).
        Raises:
            - AssertionError: If ``capacity`` is not a positive power of 2.
            - ValueError: If ``operation`` is not one of the supported names.
        """
        assert capacity > 0 and capacity & (capacity - 1) == 0, capacity
        # The module-level helpers dispatch on the operation *name*; an unrecognised value would be silently
        # ignored by them, producing a tree whose internal nodes never update. Reject it up front.
        if not isinstance(operation, str) or operation not in self._DEFAULT_NEUTRAL_ELEMENTS:
            raise ValueError("operation argument should be one of 'sum', 'min', 'max', got {!r}.".format(operation))
        self.capacity = capacity
        self.operation = operation
        # Set neutral value (initial value) for all elements.
        if neutral_element is None:
            neutral_element = self._DEFAULT_NEUTRAL_ELEMENTS[operation]
        self.neutral_element = neutral_element
        # Index 1 is the root; indices in [capacity, 2 * capacity - 1] are the leaf nodes; index 0 is unused.
        # For each parent node with index i, left child is value[2*i] and right child is value[2*i+1].
        self.value = np.full([capacity * 2], neutral_element)
        self._compile()

    def reduce(self, start: int = 0, end: Optional[int] = None) -> float:
        """
        Overview:
            Reduce the leaves in range ``[start, end)`` with the tree's operation.
        Arguments:
            - start (:obj:`int`): Start index (relative index, the first leaf node is 0), default set to 0.
            - end (:obj:`int` or :obj:`None`): End index (relative, exclusive), default set to ``self.capacity``.
        Returns:
            - reduce_result (:obj:`float`): The reduce result value, which is dependent on data type and operation.
        Raises:
            - AssertionError: If the range is empty or not within ``[0, capacity]``.
        """
        # TODO(nyz) check if directly reduce from the array(value) can be faster
        if end is None:
            end = self.capacity
        # Check both bounds here: a negative ``start`` would silently address internal nodes instead of leaves,
        # and an ``end`` past ``capacity`` would index beyond the array (numba does not bounds-check by default).
        assert 0 <= start < end <= self.capacity, (start, end)
        # Change to absolute leaf index by adding capacity.
        return _reduce(self.value, start + self.capacity, end + self.capacity, self.neutral_element, self.operation)

    def __setitem__(self, idx: int, val: float) -> None:
        """
        Overview:
            Set ``leaf[idx] = val``; then update all its ancestor nodes.
        Arguments:
            - idx (:obj:`int`): Leaf node index (relative index); ``capacity`` is added to get the absolute index.
            - val (:obj:`float`): The value that will be assigned to ``leaf[idx]``.
        """
        assert (0 <= idx < self.capacity), idx
        # ``idx`` should add ``capacity`` to change to absolute index.
        _setitem(self.value, idx + self.capacity, val, self.operation)

    def __getitem__(self, idx: int) -> float:
        """
        Overview:
            Get ``leaf[idx]``.
        Arguments:
            - idx (:obj:`int`): Leaf node index (relative index); ``capacity`` is added to get the absolute index.
        Returns:
            - val (:obj:`float`): The value of ``leaf[idx]``.
        """
        assert (0 <= idx < self.capacity), idx
        return self.value[idx + self.capacity]

    def _compile(self) -> None:
        """
        Overview:
            Call each helper once on tiny float64 / float32 / int64 arrays so that, when numba is enabled, their
            JIT compilation for these dtypes happens at construction time rather than on first real use.
        """
        f64 = np.array([0, 1], dtype=np.float64)
        f32 = np.array([0, 1], dtype=np.float32)
        i64 = np.array([0, 1], dtype=np.int64)
        for d in [f64, f32, i64]:
            _setitem(d, 0, 3.0, 'sum')
            _reduce(d, 0, 1, 0.0, 'min')
            _find_prefixsum_idx(d, 1, 0.5, 0.0)
class SumSegmentTree(SegmentTree):
    """
    Overview:
        Segment tree in which every internal node holds the sum of its two children. It is a ``SegmentTree`` \
        constructed with ``operation='sum'`` and additionally supports lookup by prefix sum.
    Interfaces:
        ``__init__``, ``find_prefixsum_idx``
    """

    def __init__(self, capacity: int) -> None:
        """
        Overview:
            Build a sum segment tree by delegating to ``SegmentTree`` with ``operation='sum'``.
        Arguments:
            - capacity (:obj:`int`): Number of leaf nodes; ``SegmentTree`` requires a positive power of 2.
        """
        super().__init__(capacity, operation='sum')

    def find_prefixsum_idx(self, prefixsum: float, trust_caller: bool = True) -> int:
        """
        Overview:
            Locate the leaf index ``i`` such that the sum of leaves ``[0, i)`` is ``<= prefixsum`` while the \
            sum of leaves ``[0, i]`` is ``> prefixsum``.
        Arguments:
            - prefixsum (:obj:`float`): The target prefix sum.
            - trust_caller (:obj:`bool`): When False, first assert that ``0 <= prefixsum`` and that it does not \
                exceed the total of all leaves (``self.reduce()``) by more than ``1e-5``. Defaults to True, \
                which skips this check.
        Returns:
            - idx (:obj:`int`): Relative leaf index satisfying the condition above.
        """
        if not trust_caller:
            assert 0 <= prefixsum <= self.reduce() + 1e-5, prefixsum
        tree, capacity, neutral = self.value, self.capacity, self.neutral_element
        return _find_prefixsum_idx(tree, capacity, prefixsum, neutral)
class MinSegmentTree(SegmentTree):
    """
    Overview:
        Min segment tree, which is inherited from ``SegmentTree``. Init by passing ``operation='min'``, so every \
        internal node holds the minimum of its two children.
    Interfaces:
        ``__init__``
    """

    def __init__(self, capacity: int) -> None:
        """
        Overview:
            Initialize min segment tree by passing ``operation='min'``.
        Arguments:
            - capacity (:obj:`int`): Capacity of the tree (the number of the leaf nodes), should be a power of 2.
        """
        super().__init__(capacity, operation='min')
@njit()
def _setitem(tree: np.ndarray, idx: int, val: float, operation: str) -> None:
    """
    Overview:
        Set ``tree[idx] = val``, then recompute every ancestor of ``idx`` up to the root (index 1).
    Arguments:
        - tree (:obj:`np.ndarray`): The tree array.
        - idx (:obj:`int`): Absolute index of the leaf node in ``tree`` (relative leaf index + capacity).
        - val (:obj:`float`): The value that will be assigned to ``tree[idx]``.
        - operation (:obj:`str`): The operation used to build the tree: ``'sum'``, ``'min'`` or ``'max'``.
    """
    tree[idx] = val
    # Update from specified node to the root node
    while idx > 1:
        idx = idx >> 1  # To parent node idx
        left, right = tree[2 * idx], tree[2 * idx + 1]
        # Scalar min/max avoid building a temporary list at every level.
        if operation == 'sum':
            tree[idx] = left + right
        elif operation == 'min':
            tree[idx] = min(left, right)
        elif operation == 'max':
            tree[idx] = max(left, right)
@njit()
def _reduce(tree: np.ndarray, start: int, end: int, neutral_element: float, operation: str) -> float:
    """
    Overview:
        Reduce the nodes in the absolute index range ``[start, end)`` with ``operation``, walking bottom-up.
    Arguments:
        - tree (:obj:`np.ndarray`): The tree array.
        - start (:obj:`int`): Start index, absolute in ``tree`` (relative leaf index + capacity).
        - end (:obj:`int`): End index (exclusive), absolute in ``tree`` (relative leaf index + capacity).
        - neutral_element (:obj:`float`): The neutral element of ``operation``; also the result when the \
            range is empty.
        - operation (:obj:`str`): The operation used to build the tree: ``'sum'``, ``'min'`` or ``'max'``.
    Returns:
        - result (:obj:`float`): The reduce result.
    """
    # Nodes in [start, end) will be aggregated
    result = neutral_element
    while start < end:
        if start & 1:
            # ``start`` is a right child, so its parent would also cover a node left of the range:
            # fold ``tree[start]`` in directly and step past it.
            if operation == 'sum':
                result = result + tree[start]
            elif operation == 'min':
                result = min(result, tree[start])
            elif operation == 'max':
                result = max(result, tree[start])
            start += 1
        if end & 1:
            # ``end`` is odd, so ``tree[end - 1]`` is a left child whose parent would extend past the range:
            # step back onto it and fold it in directly.
            end -= 1
            if operation == 'sum':
                result = result + tree[end]
            elif operation == 'min':
                result = min(result, tree[end])
            elif operation == 'max':
                result = max(result, tree[end])
        # Both start and end transform to respective parent node
        start = start >> 1
        end = end >> 1
    return result
@njit()
def _find_prefixsum_idx(tree: np.ndarray, capacity: int, prefixsum: float, neutral_element: float) -> int:
    """
    Overview:
        Find the highest non-zero index i, sum_{j}leaf[j] <= ``prefixsum`` (where 0 <= j < i)
        and sum_{j}leaf[j] > ``prefixsum`` (where 0 <= j < i+1)
    Arguments:
        - tree (:obj:`np.ndarray`): The tree array.
        - capacity (:obj:`int`): Capacity of the tree (the number of the leaf nodes).
        - prefixsum (:obj:`float`): The target prefixsum.
        - neutral_element (:obj:`float`): The value of the neutral element, which is used to init \
            all nodes value in the tree.
    Returns:
        - idx (:obj:`int`): Relative leaf index (absolute index minus ``capacity``).
    Raises:
        - ValueError: If the descent falls back past every leaf because all leaves equal ``neutral_element``.
    """
    # Descend from the root to a *leaf*. Viewing the leaves as consecutive intervals [s_0, s_1), [s_1, s_2), ...
    # where s_k is the sum of the first k leaves, this finds the interval containing ``prefixsum``: go left when
    # the left subtree's sum exceeds the remaining prefixsum, otherwise subtract that sum and go right.
    idx = 1  # start from root node
    while idx < capacity:
        child_base = 2 * idx
        if tree[child_base] > prefixsum:
            idx = child_base
        else:
            prefixsum -= tree[child_base]
            idx = child_base + 1
    # Special case: the descent ended on the last leaf and that leaf is the neutral element (0), typically
    # because the caller passed ``prefixsum >= root value`` (e.g. ``find_prefixsum_idx(root_value)``).
    # Fall back to the last leaf that is not the neutral element.
    if idx == 2 * capacity - 1 and tree[idx] == neutral_element:
        tmp = idx
        while tmp >= capacity and tree[tmp] == neutral_element:
            tmp -= 1
        # ``tmp`` is still a leaf index iff a non-neutral leaf was found. Leaf 0 (``tmp == capacity``) is a
        # valid answer too, so the test must be ``>=``, not ``!=``.
        if tmp >= capacity:
            idx = tmp
        else:
            raise ValueError("All elements in tree are the neutral_element(0), can't find non-zero element")
    assert (tree[idx] != neutral_element)
    return idx - capacity