|
from functools import partial, lru_cache |
|
from typing import Callable, Optional |
|
|
|
import numpy as np |
|
|
|
import ding |
|
from .default_helper import one_time_warning |
|
|
|
|
|
@lru_cache()
def njit():
    """
    Overview:
        Build the decorator used as ``@njit()`` on the segment tree kernels in this module.
        When ``ding.enable_numba`` is true and a numba release >= 0.53 can be imported, ``numba.njit`` is
        returned so the decorated functions get compiled. Otherwise a pass-through decorator is returned,
        and the functions run as plain Python/NumPy code.
        The result is memoised by ``lru_cache``, so the checks (and warnings) run once per process.
    Returns:
        - decorator (:obj:`Callable`): ``numba.njit``, or a decorator that returns its argument unchanged.
    """

    def _passthrough(func):
        # Fallback decorator: return the function itself so its name/docstring are preserved.
        return func

    if not ding.enable_numba:
        return _passthrough

    try:
        import numba
        from numba import njit as _numba_njit
    except ImportError:
        one_time_warning("If you want to use numba to speed up segment tree, please install numba first")
        return _passthrough

    import re

    version = numba.__version__
    # Compare (major, minor) as a tuple. Looking at the minor number alone would wrongly
    # reject numba 1.x releases (e.g. "1.0.0" has minor 0).
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        one_time_warning("Cannot parse numba version '{}', DI-engine disables numba.".format(version))
        return _passthrough
    if (int(match.group(1)), int(match.group(2))) < (0, 53):
        one_time_warning(
            "Due to your numba version < 0.53.0, DI-engine disables it. And you can install "
            "numba>=0.53.0 if you want to speed up something"
        )
        return _passthrough
    return _numba_njit
|
|
|
|
|
class SegmentTree:
    """
    Overview:
        Segment tree data structure, implemented by the tree-like array. Only the leaf nodes hold real values;
        each non-leaf node stores the result of ``operation`` applied to its left and right child.
    Interfaces:
        ``__init__``, ``reduce``, ``__setitem__``, ``__getitem__``
    """

    # Default neutral element for every supported operation; the keys are also the set of accepted operations.
    _DEFAULT_NEUTRAL_ELEMENTS = {'sum': 0., 'min': np.inf, 'max': -np.inf}

    def __init__(self, capacity: int, operation: str, neutral_element: Optional[float] = None) -> None:
        """
        Overview:
            Initialize the segment tree. Tree's root node is at index 1; leaves occupy ``[capacity, 2 * capacity)``.
        Arguments:
            - capacity (:obj:`int`): Capacity of the tree (the number of the leaf nodes), should be a power of 2.
            - operation (:obj:`str`): Name of the operation used to build the tree, one of 'sum', 'min', 'max'.
            - neutral_element (:obj:`float` or :obj:`None`): The value of the neutral element, used to init \
                all node values in the tree. If ``None``, the default for ``operation`` is used \
                (0 for 'sum', +inf for 'min', -inf for 'max').
        Raises:
            - ValueError: If ``operation`` is not one of 'sum', 'min', 'max'.
        """
        assert capacity > 0 and capacity & (capacity - 1) == 0
        # Validate the operation even when a neutral element is supplied: the tree kernels only
        # implement these three operations, so anything else would yield a silently broken tree.
        if operation not in ('sum', 'min', 'max'):
            raise ValueError(
                "operation argument should be in 'min', 'max', 'sum', but got {!r}.".format(operation)
            )
        self.capacity = capacity
        self.operation = operation

        if neutral_element is None:
            neutral_element = self._DEFAULT_NEUTRAL_ELEMENTS[operation]
        self.neutral_element = neutral_element

        self.value = np.full([capacity * 2], neutral_element)
        self._compile()

    def reduce(self, start: int = 0, end: Optional[int] = None) -> float:
        """
        Overview:
            Reduce the tree in range ``[start, end)``.
        Arguments:
            - start (:obj:`int`): Start index (relative index, the first leaf node is 0), default set to 0.
            - end (:obj:`int` or :obj:`None`): End index (relative index), default set to ``self.capacity``.
        Returns:
            - reduce_result (:obj:`float`): The reduce result value, which depends on data type and operation.
        """
        if end is None:
            end = self.capacity
        # Out-of-range relative indices would silently address the wrong (internal) nodes once shifted.
        assert 0 <= start < end <= self.capacity, (start, end)

        start += self.capacity
        end += self.capacity
        return _reduce(self.value, start, end, self.neutral_element, self.operation)

    def __setitem__(self, idx: int, val: float) -> None:
        """
        Overview:
            Set ``leaf[idx] = val``, then update the related ancestor nodes.
        Arguments:
            - idx (:obj:`int`): Leaf node index (relative index); ``capacity`` is added to get the absolute index.
            - val (:obj:`float`): The value that will be assigned to ``leaf[idx]``.
        """
        assert (0 <= idx < self.capacity), idx
        _setitem(self.value, idx + self.capacity, val, self.operation)

    def __getitem__(self, idx: int) -> float:
        """
        Overview:
            Get ``leaf[idx]``.
        Arguments:
            - idx (:obj:`int`): Leaf node index (relative index); ``capacity`` is added to get the absolute index.
        Returns:
            - val (:obj:`float`): The value of ``leaf[idx]``.
        """
        assert (0 <= idx < self.capacity)
        return self.value[idx + self.capacity]

    def _compile(self) -> None:
        """
        Overview:
            Warm up (compile, when numba is active) the tree kernels for float64, float32 and int64 arrays.
        """
        f64 = np.array([0, 1], dtype=np.float64)
        f32 = np.array([0, 1], dtype=np.float32)
        i64 = np.array([0, 1], dtype=np.int64)
        for d in [f64, f32, i64]:
            _setitem(d, 0, 3.0, 'sum')
            _reduce(d, 0, 1, 0.0, 'min')
            _find_prefixsum_idx(d, 1, 0.5, 0.0)
|
|
|
|
|
class SumSegmentTree(SegmentTree):
    """
    Overview:
        Segment tree whose internal nodes hold the sum of their children. It is a ``SegmentTree``
        built with ``operation='sum'``, plus a prefix-sum search used for proportional sampling.
    Interfaces:
        ``__init__``, ``find_prefixsum_idx``
    """

    def __init__(self, capacity: int) -> None:
        """
        Overview:
            Build an (initially all-zero) sum segment tree.
        Arguments:
            - capacity (:obj:`int`): Number of leaf nodes in the tree.
        """
        super().__init__(capacity, operation='sum')

    def find_prefixsum_idx(self, prefixsum: float, trust_caller: bool = True) -> int:
        """
        Overview:
            Locate the leaf index ``i`` such that the sum of leaves ``[0, i)`` is <= ``prefixsum``
            while the sum of leaves ``[0, i + 1)`` is > ``prefixsum``.
        Arguments:
            - prefixsum (:obj:`float`): The target prefix sum.
            - trust_caller (:obj:`bool`): When False, first assert that ``prefixsum`` lies in \
                ``[0, total_sum + 1e-5]`` (``total_sum`` obtained via ``reduce``). Default True skips the check.
        Returns:
            - idx (:obj:`int`): The matching leaf index (relative index).
        """
        if trust_caller:
            return _find_prefixsum_idx(self.value, self.capacity, prefixsum, self.neutral_element)
        assert 0 <= prefixsum <= self.reduce() + 1e-5, prefixsum
        return _find_prefixsum_idx(self.value, self.capacity, prefixsum, self.neutral_element)
|
|
|
|
|
class MinSegmentTree(SegmentTree):
    """
    Overview:
        Min segment tree, which is inherited from ``SegmentTree``. Init by passing ``operation='min'``.
    Interfaces:
        ``__init__``
    """

    def __init__(self, capacity: int) -> None:
        """
        Overview:
            Initialize a min segment tree by passing ``operation='min'``; all nodes start at +inf.
        Arguments:
            - capacity (:obj:`int`): Capacity of the tree (the number of the leaf nodes).
        """
        super().__init__(capacity, operation='min')
|
|
|
|
|
@njit()
def _setitem(tree: np.ndarray, idx: int, val: float, operation: str) -> None:
    """
    Overview:
        Set ``tree[idx] = val``, then walk up to the root recomputing every ancestor from its two children.
    Arguments:
        - tree (:obj:`np.ndarray`): The tree array (root at index 1, children of ``i`` at ``2i`` and ``2i + 1``).
        - idx (:obj:`int`): The absolute index of the leaf node.
        - val (:obj:`float`): The value that will be assigned to ``tree[idx]``.
        - operation (:obj:`str`): The operation used to build the tree, one of 'sum', 'min', 'max'.
    Raises:
        - ValueError: If ``operation`` is not supported (checked before the tree is modified).
    """
    # Reject unknown operations up front; previously they silently left ancestors stale.
    if operation != 'sum' and operation != 'min' and operation != 'max':
        raise ValueError("operation should be one of 'sum', 'min' or 'max'")

    tree[idx] = val
    while idx > 1:
        idx = idx >> 1
        left = tree[2 * idx]
        right = tree[2 * idx + 1]
        if operation == 'sum':
            tree[idx] = left + right
        elif operation == 'min':
            # Scalar min avoids building a temporary list at every level.
            tree[idx] = min(left, right)
        else:
            # 'max' is accepted by SegmentTree, so it must be maintained here too.
            tree[idx] = max(left, right)
|
|
|
|
|
@njit()
def _reduce(tree: np.ndarray, start: int, end: int, neutral_element: float, operation: str) -> float:
    """
    Overview:
        Reduce the tree over the absolute leaf range ``[start, end)`` using a bottom-up traversal.
        At each level, a boundary node that is a right child (``start`` odd) or whose right neighbour
        is excluded (``end`` odd) is folded into the result before moving to the parent level.
    Arguments:
        - tree (:obj:`np.ndarray`): The tree array.
        - start (:obj:`int`): Start index (absolute index, i.e. relative index + capacity).
        - end (:obj:`int`): End index (absolute index, exclusive).
        - neutral_element (:obj:`float`): The neutral element of ``operation``, used as the initial result.
        - operation (:obj:`str`): The operation used to build the tree, one of 'sum', 'min', 'max'.
    Returns:
        - result (:obj:`float`): The reduction of the selected leaves.
    Raises:
        - ValueError: If ``operation`` is not supported.
    """
    # Unknown operations used to return ``neutral_element`` silently.
    if operation != 'sum' and operation != 'min' and operation != 'max':
        raise ValueError("operation should be one of 'sum', 'min' or 'max'")

    result = neutral_element
    while start < end:
        if start & 1:
            if operation == 'sum':
                result = result + tree[start]
            elif operation == 'min':
                result = min(result, tree[start])
            else:
                result = max(result, tree[start])
            start += 1
        if end & 1:
            end -= 1
            if operation == 'sum':
                result = result + tree[end]
            elif operation == 'min':
                result = min(result, tree[end])
            else:
                result = max(result, tree[end])
        start = start >> 1
        end = end >> 1
    return result
|
|
|
|
|
@njit()
def _find_prefixsum_idx(tree: np.ndarray, capacity: int, prefixsum: float, neutral_element: float) -> int:
    """
    Overview:
        Find the leaf index i such that sum_{j}leaf[j] <= ``prefixsum`` (where 0 <= j < i)
        and sum_{j}leaf[j] > ``prefixsum`` (where 0 <= j < i+1), by descending from the root.
    Arguments:
        - tree (:obj:`np.ndarray`): The sum tree array.
        - capacity (:obj:`int`): Capacity of the tree (the number of the leaf nodes).
        - prefixsum (:obj:`float`): The target prefixsum.
        - neutral_element (:obj:`float`): The neutral element of the sum tree (0), marking empty leaves.
    Returns:
        - idx (:obj:`int`): The found leaf index (relative index).
    Raises:
        - ValueError: If the descent overshoots and every leaf equals ``neutral_element``.
    """
    idx = 1
    while idx < capacity:
        left_child = 2 * idx
        if tree[left_child] > prefixsum:
            idx = left_child
        else:
            prefixsum -= tree[left_child]
            idx = left_child + 1

    # If ``prefixsum`` reaches or exceeds the total (e.g. from float rounding), the descent always goes
    # right and ends at the last leaf, which may be empty. In that case fall back to the
    # right-most non-neutral leaf.
    if idx == 2 * capacity - 1 and tree[idx] == neutral_element:
        tmp = idx
        while tmp >= capacity and tree[tmp] == neutral_element:
            tmp -= 1
        # ``tmp == capacity`` is leaf 0 and is a valid answer; only ``tmp < capacity`` means
        # the scan ran past the leaves without finding a non-neutral value.
        if tmp >= capacity:
            idx = tmp
        else:
            raise ValueError("All elements in tree are the neutral_element(0), can't find non-zero element")
    assert (tree[idx] != neutral_element)
    return idx - capacity
|
|