from typing import Union, Mapping, List, NamedTuple, Tuple, Callable, Optional, Any, Dict, Iterator
import copy
from ditk import logging
import random
from functools import lru_cache  # in Python 3.9+, this can be changed to functools.cache
import numpy as np
import torch
import treetensor.torch as ttorch
def get_shape0(data: Union[List, Dict, torch.Tensor, ttorch.Tensor]) -> int:
"""
Overview:
        Get ``shape[0]`` of the input's torch tensor or treetensor.
    Arguments:
        - data (:obj:`Union[List, Dict, torch.Tensor, ttorch.Tensor]`): data to be analysed.
    Returns:
        - shape[0] (:obj:`int`): first dimension length of data, usually the batch size.
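    Examples:
        >>> # A minimal sketch: a batch of 4 transitions (tensor shapes are illustrative).
        >>> get_shape0(torch.randn(4, 3))
        4
        >>> get_shape0({'obs': torch.randn(4, 3), 'action': torch.randn(4, 1)})
        4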
"""
if isinstance(data, list) or isinstance(data, tuple):
return get_shape0(data[0])
elif isinstance(data, dict):
for k, v in data.items():
return get_shape0(v)
elif isinstance(data, torch.Tensor):
return data.shape[0]
elif isinstance(data, ttorch.Tensor):
def fn(t):
item = list(t.values())[0]
if np.isscalar(item[0]):
return item[0]
else:
return fn(item)
return fn(data.shape)
else:
raise TypeError("Error in getting shape0, not support type: {}".format(data))
def lists_to_dicts(
data: Union[List[Union[dict, NamedTuple]], Tuple[Union[dict, NamedTuple]]],
recursive: bool = False,
) -> Union[Mapping[object, object], NamedTuple]:
"""
Overview:
Transform a list of dicts to a dict of lists.
Arguments:
        - data (:obj:`Union[List[Union[dict, NamedTuple]], Tuple[Union[dict, NamedTuple]]]`): \
            A list of dicts that needs to be transformed.
        - recursive (:obj:`bool`): whether to recursively deal with dict elements.
    Returns:
        - newdata (:obj:`Union[Mapping[object, object], NamedTuple]`): A dict of lists as the result.
Example:
>>> from ding.utils import *
>>> lists_to_dicts([{1: 1, 10: 3}, {1: 2, 10: 4}])
{1: [1, 2], 10: [3, 4]}
"""
if len(data) == 0:
raise ValueError("empty data")
if isinstance(data[0], dict):
if recursive:
new_data = {}
for k in data[0].keys():
if isinstance(data[0][k], dict) and k != 'prev_state':
tmp = [data[b][k] for b in range(len(data))]
new_data[k] = lists_to_dicts(tmp)
else:
new_data[k] = [data[b][k] for b in range(len(data))]
else:
new_data = {k: [data[b][k] for b in range(len(data))] for k in data[0].keys()}
elif isinstance(data[0], tuple) and hasattr(data[0], '_fields'): # namedtuple
new_data = type(data[0])(*list(zip(*data)))
else:
raise TypeError("not support element type: {}".format(type(data[0])))
return new_data
def dicts_to_lists(data: Mapping[object, List[object]]) -> List[Mapping[object, object]]:
"""
Overview:
Transform a dict of lists to a list of dicts.
Arguments:
        - data (:obj:`Mapping[object, list]`): A dict of lists that needs to be transformed.
    Returns:
        - newdata (:obj:`List[Mapping[object, object]]`): A list of dicts as the result.
Example:
>>> from ding.utils import *
>>> dicts_to_lists({1: [1, 2], 10: [3, 4]})
[{1: 1, 10: 3}, {1: 2, 10: 4}]
"""
new_data = [v for v in data.values()]
new_data = [{k: v for k, v in zip(data.keys(), t)} for t in list(zip(*new_data))]
return new_data
def override(cls: type) -> Callable[[
Callable,
], Callable]:
"""
Overview:
Annotation for documenting method overrides.
Arguments:
- cls (:obj:`type`): The superclass that provides the overridden method. If this
cls does not actually have the method, an error is raised.
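    Examples:
        >>> # A minimal sketch with hypothetical classes ``Base`` and ``Child``.
        >>> class Base:
        >>>     def step(self):
        >>>         pass
        >>> class Child(Base):
        >>>     @override(Base)
        >>>     def step(self):
        >>>         return 1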
"""
def check_override(method: Callable) -> Callable:
if method.__name__ not in dir(cls):
raise NameError("{} does not override any method of {}".format(method, cls))
return method
return check_override
def squeeze(data: object) -> object:
"""
Overview:
Squeeze data from tuple, list or dict to single object
Arguments:
- data (:obj:`object`): data to be squeezed
Example:
>>> a = (4, )
>>> a = squeeze(a)
>>> print(a)
        4
"""
if isinstance(data, tuple) or isinstance(data, list):
if len(data) == 1:
return data[0]
else:
return tuple(data)
elif isinstance(data, dict):
if len(data) == 1:
return list(data.values())[0]
return data
default_get_set = set()
def default_get(
data: dict,
name: str,
default_value: Optional[Any] = None,
default_fn: Optional[Callable] = None,
judge_fn: Optional[Callable] = None
) -> Any:
"""
Overview:
        Get the value of ``name`` from ``data``. If ``name`` exists in ``data``, \
        return ``data[name]``; otherwise, return a default value that is generated by \
        ``default_fn`` (or taken directly from ``default_value``) and, if given, checked \
        by ``judge_fn`` to be legal. The first time a default is used for ``name``, a \
        warning is logged and ``name`` is added to ``default_get_set``.
    Arguments:
        - data (:obj:`dict`): Data input dictionary.
        - name (:obj:`str`): Key name.
        - default_value (:obj:`Optional[Any]`): The default value, used when ``default_fn`` is None.
        - default_fn (:obj:`Optional[Callable]`): A function that generates the default value, \
            taking precedence over ``default_value``.
        - judge_fn (:obj:`Optional[Callable]`): A predicate that the default value must satisfy.
    Returns:
        - ret (:obj:`Any`): The value at ``name``, or the default value if ``name`` is missing.
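    Examples:
        >>> # A minimal sketch: the key is missing, so the default value is returned.
        >>> default_get({'a': 1}, 'b', default_value=2)
        2
        >>> default_get({'a': 1}, 'a', default_value=2)
        1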
"""
if name in data:
return data[name]
else:
assert default_value is not None or default_fn is not None
value = default_fn() if default_fn is not None else default_value
if judge_fn:
            assert judge_fn(value), "default value({}) is not accepted by judge_fn".format(type(value))
if name not in default_get_set:
logging.warning("{} use default value {}".format(name, value))
default_get_set.add(name)
return value
def list_split(data: list, step: int) -> List[list]:
"""
Overview:
        Split a list of data by step.
    Arguments:
        - data (:obj:`list`): List of data for splitting.
        - step (:obj:`int`): Number of elements in each split chunk.
    Returns:
        - ret (:obj:`list`): List of split data chunks.
        - residual (:obj:`list`): Residual list. This value is ``None`` when ``step`` divides ``len(data)`` evenly.
Example:
>>> list_split([1,2,3,4],2)
([[1, 2], [3, 4]], None)
>>> list_split([1,2,3,4],3)
([[1, 2, 3]], [4])
"""
if len(data) < step:
return [], data
ret = []
divide_num = len(data) // step
for i in range(divide_num):
start, end = i * step, (i + 1) * step
ret.append(data[start:end])
if divide_num * step < len(data):
residual = data[divide_num * step:]
else:
residual = None
return ret, residual
def error_wrapper(fn, default_ret, warning_msg=""):
"""
Overview:
        Wrap the function, so that any Exception raised inside it is caught and ``default_ret`` \
        is returned instead.
    Arguments:
        - fn (:obj:`Callable`): the function to be wrapped.
        - default_ret (:obj:`obj`): the default return value when an Exception occurs in the function.
        - warning_msg (:obj:`str`): the warning message printed (once) when an Exception occurs; \
            an empty string means no warning.
Returns:
- wrapper (:obj:`Callable`): the wrapped function
Examples:
        >>> # Used to check for FakeLink (refer to utils.linklink_dist_helper.py)
>>> def get_rank(): # Get the rank of linklink model, return 0 if use FakeLink.
>>> if is_fake_link:
>>> return 0
>>> return error_wrapper(link.get_rank, 0)()
"""
def wrapper(*args, **kwargs):
try:
ret = fn(*args, **kwargs)
except Exception as e:
ret = default_ret
            if warning_msg != "":
                # one_time_warning takes a single message string, so join the parts here
                one_time_warning(warning_msg + "\ndefault_ret = {}\terror = {}".format(default_ret, e))
return ret
return wrapper
class LimitedSpaceContainer:
"""
Overview:
        A simulator of a space with limited capacity.
    Interfaces:
        ``__init__``, ``get_residual_space``, ``acquire_space``, ``release_space``, \
        ``increase_space``, ``decrease_space``
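    Examples:
        >>> # A minimal sketch: a container with capacity 2.
        >>> c = LimitedSpaceContainer(0, 2)
        >>> c.acquire_space()
        True
        >>> c.get_residual_space()
        1
        >>> c.acquire_space()
        False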
"""
def __init__(self, min_val: int, max_val: int) -> None:
"""
Overview:
Set ``min_val`` and ``max_val`` of the container, also set ``cur`` to ``min_val`` for initialization.
Arguments:
- min_val (:obj:`int`): Min volume of the container, usually 0.
- max_val (:obj:`int`): Max volume of the container.
"""
self.min_val = min_val
self.max_val = max_val
assert (max_val >= min_val)
self.cur = self.min_val
def get_residual_space(self) -> int:
"""
Overview:
            Get all residual pieces of space. Set ``cur`` to ``max_val``.
        Returns:
            - ret (:obj:`int`): Residual space, calculated by ``max_val`` - ``cur``.
"""
ret = self.max_val - self.cur
self.cur = self.max_val
return ret
def acquire_space(self) -> bool:
"""
Overview:
            Try to get one piece of space. If there is one, return True; otherwise return False.
Returns:
- flag (:obj:`bool`): Whether there is any piece of residual space.
"""
if self.cur < self.max_val:
self.cur += 1
return True
else:
return False
def release_space(self) -> None:
"""
Overview:
Release only one piece of space. Decrement ``cur``, but ensure it won't be negative.
"""
self.cur = max(self.min_val, self.cur - 1)
def increase_space(self) -> None:
"""
Overview:
Increase one piece in space. Increment ``max_val``.
"""
self.max_val += 1
def decrease_space(self) -> None:
"""
Overview:
Decrease one piece in space. Decrement ``max_val``.
"""
self.max_val -= 1
def deep_merge_dicts(original: dict, new_dict: dict) -> dict:
"""
Overview:
Merge two dicts by calling ``deep_update``
Arguments:
- original (:obj:`dict`): Dict 1.
- new_dict (:obj:`dict`): Dict 2.
Returns:
        - merged_dict (:obj:`dict`): A new dict that is the result of deeply merging ``original`` and ``new_dict``.
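    Examples:
        >>> # A minimal sketch of a nested merge.
        >>> deep_merge_dicts({'a': {'b': 1}}, {'a': {'c': 2}})
        {'a': {'b': 1, 'c': 2}}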
"""
original = original or {}
new_dict = new_dict or {}
merged = copy.deepcopy(original)
if new_dict: # if new_dict is neither empty dict nor None
deep_update(merged, new_dict, True, [])
return merged
def deep_update(
original: dict,
new_dict: dict,
new_keys_allowed: bool = False,
whitelist: Optional[List[str]] = None,
override_all_if_type_changes: Optional[List[str]] = None
):
"""
Overview:
Update original dict with values from new_dict recursively.
Arguments:
- original (:obj:`dict`): Dictionary with default values.
- new_dict (:obj:`dict`): Dictionary with values to be updated
- new_keys_allowed (:obj:`bool`): Whether new keys are allowed.
- whitelist (:obj:`Optional[List[str]]`):
List of keys that correspond to dict
values where new subkeys can be introduced. This is only at the top
level.
- override_all_if_type_changes(:obj:`Optional[List[str]]`):
List of top level
keys with value=dict, for which we always simply override the
entire value (:obj:`dict`), if the "type" key in that value dict changes.
    .. note::
        If a new key is introduced in ``new_dict``, an error is raised unless ``new_keys_allowed``
        is True. Further, for sub-dicts, new subkeys can be introduced if the key is in the
        whitelist.
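    Examples:
        >>> # A minimal sketch: a changed ``type`` forces a full override of the sub-dict.
        >>> cfg = {'model': {'type': 'mlp', 'hidden': 128}}
        >>> deep_update(cfg, {'model': {'type': 'conv'}}, False, [], ['model'])
        {'model': {'type': 'conv'}}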
"""
whitelist = whitelist or []
override_all_if_type_changes = override_all_if_type_changes or []
for k, value in new_dict.items():
if k not in original and not new_keys_allowed:
            raise RuntimeError("Unknown config parameter `{}`. Base config has: {}.".format(k, original.keys()))
# Both original value and new one are dicts.
if isinstance(original.get(k), dict) and isinstance(value, dict):
            # Check the old type vs the new one. If different, override the entire value.
if k in override_all_if_type_changes and \
"type" in value and "type" in original[k] and \
value["type"] != original[k]["type"]:
original[k] = value
# Whitelisted key -> ok to add new subkeys.
elif k in whitelist:
deep_update(original[k], value, True)
# Non-whitelisted key.
else:
deep_update(original[k], value, new_keys_allowed)
# Original value not a dict OR new value not a dict:
# Override entire value.
else:
original[k] = value
return original
def flatten_dict(data: dict, delimiter: str = "/") -> dict:
"""
Overview:
Flatten the dict, see example
Arguments:
- data (:obj:`dict`): Original nested dict
- delimiter (str): Delimiter of the keys of the new dict
Returns:
- data (:obj:`dict`): Flattened nested dict
Example:
>>> a
{'a': {'b': 100}}
>>> flatten_dict(a)
{'a/b': 100}
"""
data = copy.deepcopy(data)
while any(isinstance(v, dict) for v in data.values()):
remove = []
add = {}
for key, value in data.items():
if isinstance(value, dict):
for subkey, v in value.items():
add[delimiter.join([key, subkey])] = v
remove.append(key)
data.update(add)
for k in remove:
del data[k]
return data
def set_pkg_seed(seed: int, use_cuda: bool = True) -> None:
"""
Overview:
        Side effect function to set seed for ``random``, ``numpy random``, and ``torch's manual seed``. \
        This is usually used in an entry script, in the section that sets the random seed for all \
        packages and instances.
    Arguments:
        - seed (:obj:`int`): Set seed.
        - use_cuda (:obj:`bool`): Whether to use CUDA.
Examples:
>>> # ../entry/xxxenv_xxxpolicy_main.py
>>> ...
        >>> # Set random seed for all packages and instances
>>> collector_env.seed(seed)
>>> evaluator_env.seed(seed, dynamic_seed=False)
>>> set_pkg_seed(seed, use_cuda=cfg.policy.cuda)
>>> ...
        >>> # Set up RL Policy, etc.
>>> ...
"""
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if use_cuda and torch.cuda.is_available():
torch.cuda.manual_seed(seed)
@lru_cache()
def one_time_warning(warning_msg: str) -> None:
"""
Overview:
        Log a warning message only once for each distinct message (deduplicated via ``lru_cache``).
Arguments:
- warning_msg (:obj:`str`): Warning message.
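    Examples:
        >>> one_time_warning("fake link is used")  # logged
        >>> one_time_warning("fake link is used")  # cached by lru_cache, not logged again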
"""
logging.warning(warning_msg)
def split_fn(data, indices, start, end):
"""
Overview:
        Split data by indices.
    Arguments:
        - data (:obj:`Union[List, Dict, torch.Tensor, ttorch.Tensor]`): data to be split.
        - indices (:obj:`np.ndarray`): indices used to select elements.
        - start (:obj:`int`): start position in ``indices``.
        - end (:obj:`int`): end position in ``indices``.
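    Examples:
        >>> # A minimal sketch: pick rows 0..1 of every tensor in a dict.
        >>> split_fn({'obs': torch.arange(4)}, np.arange(4), 0, 2)
        {'obs': tensor([0, 1])}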
"""
if data is None:
return None
elif isinstance(data, list):
return [split_fn(d, indices, start, end) for d in data]
elif isinstance(data, dict):
return {k1: split_fn(v1, indices, start, end) for k1, v1 in data.items()}
elif isinstance(data, str):
return data
else:
return data[indices[start:end]]
def split_data_generator(data: dict, split_size: int, shuffle: bool = True) -> Iterator[dict]:
"""
Overview:
        Split data into batches of size ``split_size`` and yield them one by one.
    Arguments:
        - data (:obj:`dict`): data to be split.
        - split_size (:obj:`int`): the size of each yielded batch.
        - shuffle (:obj:`bool`): whether to shuffle the indices before splitting.
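    Examples:
        >>> # A minimal sketch: split a batch of 4 samples into chunks of 2.
        >>> data = {'obs': torch.randn(4, 3), 'action': torch.randn(4, 1)}
        >>> for batch in split_data_generator(data, 2, shuffle=False):
        >>>     print(batch['obs'].shape)
        torch.Size([2, 3])
        torch.Size([2, 3])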
"""
assert isinstance(data, dict), type(data)
length = []
for k, v in data.items():
if v is None:
continue
elif k in ['prev_state', 'prev_actor_state', 'prev_critic_state']:
length.append(len(v))
elif isinstance(v, list) or isinstance(v, tuple):
if isinstance(v[0], str):
# some buffer data contains useless string infos, such as 'buffer_id',
# which should not be split, so we just skip it
continue
else:
length.append(get_shape0(v[0]))
elif isinstance(v, dict):
length.append(len(v[list(v.keys())[0]]))
else:
length.append(len(v))
assert len(length) > 0
# assert len(set(length)) == 1, "data values must have the same length: {}".format(length)
# if continuous action, data['logit'] is list of length 2
length = length[0]
assert split_size >= 1
if shuffle:
indices = np.random.permutation(length)
else:
indices = np.arange(length)
for i in range(0, length, split_size):
if i + split_size > length:
i = length - split_size
batch = split_fn(data, indices, i, i + split_size)
yield batch
class RunningMeanStd(object):
"""
Overview:
        Incrementally maintain the running mean, variance and count of a data stream.
Interfaces:
``__init__``, ``update``, ``reset``, ``new_shape``
Properties:
- ``mean``, ``std``, ``_epsilon``, ``_shape``, ``_mean``, ``_var``, ``_count``
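    Examples:
        >>> # A minimal sketch: track running statistics of sampled data.
        >>> rms = RunningMeanStd(shape=(1, ))
        >>> rms.update(np.random.randn(10, 1))
        >>> rms.mean, rms.std  # torch tensors on the configured device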
"""
def __init__(self, epsilon=1e-4, shape=(), device=torch.device('cpu')):
"""
Overview:
Initialize ``self.`` See ``help(type(self))`` for accurate \
signature; setup the properties.
        Arguments:
            - epsilon (:obj:`Float`): the initial count, which also keeps the std output stable.
            - shape (:obj:`np.array`): the shape used for the ``mean`` and ``var`` attributes.
            - device (:obj:`torch.device`): the device on which the ``mean`` and ``std`` tensors are returned.
"""
self._epsilon = epsilon
self._shape = shape
self._device = device
self.reset()
def update(self, x):
"""
Overview:
            Update the running mean, variance and count with a batch of data.
        Arguments:
            - x (:obj:`np.ndarray`): the batch of data, whose first dimension is the batch size.
"""
batch_mean = np.mean(x, axis=0)
batch_var = np.var(x, axis=0)
batch_count = x.shape[0]
new_count = batch_count + self._count
mean_delta = batch_mean - self._mean
new_mean = self._mean + mean_delta * batch_count / new_count
        # this method for calculating the new variance might be numerically unstable
m_a = self._var * self._count
m_b = batch_var * batch_count
m2 = m_a + m_b + np.square(mean_delta) * self._count * batch_count / new_count
new_var = m2 / new_count
self._mean = new_mean
self._var = new_var
self._count = new_count
def reset(self):
"""
Overview:
            Reset the properties ``_mean``, ``_var`` and ``_count`` to their initial values.
"""
if len(self._shape) > 0:
self._mean = np.zeros(self._shape, 'float32')
self._var = np.ones(self._shape, 'float32')
else:
self._mean, self._var = 0., 1.
self._count = self._epsilon
@property
    def mean(self) -> Union[float, torch.Tensor]:
"""
Overview:
            Get the property ``mean``, converted to a ``torch.Tensor`` when it is not a scalar.
"""
if np.isscalar(self._mean):
return self._mean
else:
return torch.FloatTensor(self._mean).to(self._device)
@property
    def std(self) -> Union[float, torch.Tensor]:
"""
Overview:
            Get the property ``std``, calculated from ``self._var`` plus a small constant for numerical stability.
"""
std = np.sqrt(self._var + 1e-8)
if np.isscalar(std):
return std
else:
return torch.FloatTensor(std).to(self._device)
@staticmethod
def new_shape(obs_shape, act_shape, rew_shape):
"""
Overview:
            Get the new shapes of observation, action, and reward; unchanged in this case.
        Arguments:
            - obs_shape (:obj:`Any`), act_shape (:obj:`Any`), rew_shape (:obj:`Any`)
        Returns:
            - obs_shape (:obj:`Any`), act_shape (:obj:`Any`), rew_shape (:obj:`Any`)
"""
return obs_shape, act_shape, rew_shape
def make_key_as_identifier(data: Dict[str, Any]) -> Dict[str, Any]:
"""
Overview:
        Make the keys of a dict into legal Python identifier strings so that they are \
        compatible with some Python magic methods such as ``__getattr__``.
Arguments:
- data (:obj:`Dict[str, Any]`): The original dict data.
    Returns:
        - new_data (:obj:`Dict[str, Any]`): The new dict data with legal identifier keys.
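    Examples:
        >>> # A minimal sketch: '1step' and 'a.b' are not legal identifiers.
        >>> make_key_as_identifier({'1step': 1, 'a.b': 2})
        {'_1step': 1, 'a_b': 2}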
"""
def legalization(s: str) -> str:
if s[0].isdigit():
s = '_' + s
return s.replace('.', '_')
new_data = {}
for k in data:
new_k = legalization(k)
new_data[new_k] = data[k]
return new_data
def remove_illegal_item(data: Dict[str, Any]) -> Dict[str, Any]:
"""
Overview:
        Remove illegal items, such as ``str`` values, from the dict data, since they are not compatible with Tensor.
Arguments:
- data (:obj:`Dict[str, Any]`): The original dict data.
    Returns:
        - new_data (:obj:`Dict[str, Any]`): The new dict data without illegal items.
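    Examples:
        >>> # A minimal sketch: string values are dropped.
        >>> remove_illegal_item({'reward': 1.0, 'buffer_id': 'agent0'})
        {'reward': 1.0}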
"""
new_data = {}
for k, v in data.items():
if isinstance(v, str):
continue
new_data[k] = data[k]
return new_data