|
from typing import List, Dict, Tuple |
|
from ditk import logging |
|
from copy import deepcopy |
|
from easydict import EasyDict |
|
from torch.utils.data import Dataset |
|
from dataclasses import dataclass |
|
|
|
import pickle |
|
import easydict |
|
import torch |
|
import numpy as np |
|
|
|
from ding.utils.bfs_helper import get_vi_sequence |
|
from ding.utils import DATASET_REGISTRY, import_module, DatasetNormalizer |
|
from ding.rl_utils import discount_cumsum |
|
|
|
|
|
@dataclass |
|
class DatasetStatistics: |
|
""" |
|
Overview: |
|
        Dataset statistics: the observation mean and std, and the action bounds of an offline RL dataset.
|
""" |
|
mean: np.ndarray |
|
std: np.ndarray |
|
action_bounds: np.ndarray |
|
|
|
|
|
@DATASET_REGISTRY.register('naive') |
|
class NaiveRLDataset(Dataset): |
|
""" |
|
Overview: |
|
Naive RL dataset, which is used for offline RL algorithms. |
|
Interfaces: |
|
``__init__``, ``__len__``, ``__getitem__`` |
|
""" |
|
|
|
def __init__(self, cfg) -> None: |
|
""" |
|
Overview: |
|
Initialization method. |
|
Arguments: |
|
- cfg (:obj:`dict`): Config dict. |
|
""" |
|
|
|
assert type(cfg) in [str, EasyDict], "invalid cfg type: {}".format(type(cfg)) |
|
if isinstance(cfg, EasyDict): |
|
self._data_path = cfg.policy.collect.data_path |
|
elif isinstance(cfg, str): |
|
self._data_path = cfg |
|
with open(self._data_path, 'rb') as f: |
|
self._data: List[Dict[str, torch.Tensor]] = pickle.load(f) |
|
|
|
def __len__(self) -> int: |
|
""" |
|
Overview: |
|
Get the length of the dataset. |
|
""" |
|
|
|
return len(self._data) |
|
|
|
def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]: |
|
""" |
|
Overview: |
|
Get the item of the dataset. |
|
""" |
|
|
|
return self._data[idx] |
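
# Illustrative usage of ``NaiveRLDataset`` (a minimal sketch; ``./expert.pkl``
# is a hypothetical pickled list of transition dicts, e.g. one produced by
# ``naive_save`` at the end of this file):
#
#     from torch.utils.data import DataLoader
#     dataset = NaiveRLDataset('./expert.pkl')
#     loader = DataLoader(dataset, batch_size=32, shuffle=True)
#     batch = next(iter(loader))  # dict of batched tensors, e.g. batch['obs']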
|
|
|
|
|
@DATASET_REGISTRY.register('d4rl') |
|
class D4RLDataset(Dataset): |
|
""" |
|
Overview: |
|
D4RL dataset, which is used for offline RL algorithms. |
|
Interfaces: |
|
``__init__``, ``__len__``, ``__getitem__`` |
|
Properties: |
|
- mean (:obj:`np.ndarray`): Mean of the dataset. |
|
- std (:obj:`np.ndarray`): Std of the dataset. |
|
- action_bounds (:obj:`np.ndarray`): Action bounds of the dataset. |
|
- statistics (:obj:`dict`): Statistics of the dataset. |
|
""" |
|
|
|
def __init__(self, cfg: dict) -> None: |
|
""" |
|
Overview: |
|
Initialization method. |
|
Arguments: |
|
- cfg (:obj:`dict`): Config dict. |
|
""" |
|
|
|
import gym |
|
try: |
|
import d4rl |
|
except ImportError: |
|
import sys |
|
logging.warning("not found d4rl env, please install it, refer to https://github.com/rail-berkeley/d4rl") |
|
sys.exit(1) |
|
|
|
|
|
data_path = cfg.policy.collect.get('data_path', None) |
|
env_id = cfg.env.env_id |
|
|
|
|
|
if data_path: |
|
d4rl.set_dataset_path(data_path) |
|
env = gym.make(env_id) |
|
dataset = d4rl.qlearning_dataset(env) |
|
self._cal_statistics(dataset, env) |
|
try: |
|
if cfg.env.norm_obs.use_norm and cfg.env.norm_obs.offline_stats.use_offline_stats: |
|
dataset = self._normalize_states(dataset) |
|
except (KeyError, AttributeError): |
|
|
|
pass |
|
self._data = [] |
|
self._load_d4rl(dataset) |
|
|
|
@property |
|
def data(self) -> List: |
|
return self._data |
|
|
|
def __len__(self) -> int: |
|
""" |
|
Overview: |
|
Get the length of the dataset. |
|
""" |
|
|
|
return len(self._data) |
|
|
|
def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]: |
|
""" |
|
Overview: |
|
Get the item of the dataset. |
|
""" |
|
|
|
return self._data[idx] |
|
|
|
def _load_d4rl(self, dataset: Dict[str, np.ndarray]) -> None: |
|
""" |
|
Overview: |
|
Load the d4rl dataset. |
|
Arguments: |
|
- dataset (:obj:`Dict[str, np.ndarray]`): The d4rl dataset. |
|
""" |
|
|
|
for i in range(len(dataset['observations'])): |
|
trans_data = {} |
|
trans_data['obs'] = torch.from_numpy(dataset['observations'][i]) |
|
trans_data['next_obs'] = torch.from_numpy(dataset['next_observations'][i]) |
|
trans_data['action'] = torch.from_numpy(dataset['actions'][i]) |
|
trans_data['reward'] = torch.tensor(dataset['rewards'][i]) |
|
trans_data['done'] = dataset['terminals'][i] |
|
self._data.append(trans_data) |
|
|
|
def _cal_statistics(self, dataset, env, eps=1e-3, add_action_buffer=True): |
|
""" |
|
Overview: |
|
Calculate the statistics of the dataset. |
|
Arguments: |
|
- dataset (:obj:`Dict[str, np.ndarray]`): The d4rl dataset. |
|
- env (:obj:`gym.Env`): The environment. |
|
            - eps (:obj:`float`): Epsilon added to the std to avoid division by zero.

            - add_action_buffer (:obj:`bool`): Whether to expand the action bounds by a 5% buffer.
|
""" |
|
|
|
self._mean = dataset['observations'].mean(0) |
|
self._std = dataset['observations'].std(0) + eps |
|
action_max = dataset['actions'].max(0) |
|
action_min = dataset['actions'].min(0) |
|
if add_action_buffer: |
|
action_buffer = 0.05 * (action_max - action_min) |
|
action_max = (action_max + action_buffer).clip(max=env.action_space.high) |
|
action_min = (action_min - action_buffer).clip(min=env.action_space.low) |
|
self._action_bounds = np.stack([action_min, action_max], axis=0) |
|
|
|
def _normalize_states(self, dataset): |
|
""" |
|
Overview: |
|
Normalize the states. |
|
Arguments: |
|
- dataset (:obj:`Dict[str, np.ndarray]`): The d4rl dataset. |
|
""" |
|
|
|
dataset['observations'] = (dataset['observations'] - self._mean) / self._std |
|
dataset['next_observations'] = (dataset['next_observations'] - self._mean) / self._std |
|
return dataset |
|
|
|
@property |
|
def mean(self): |
|
""" |
|
Overview: |
|
Get the mean of the dataset. |
|
""" |
|
|
|
return self._mean |
|
|
|
@property |
|
def std(self): |
|
""" |
|
Overview: |
|
Get the std of the dataset. |
|
""" |
|
|
|
return self._std |
|
|
|
@property |
|
def action_bounds(self) -> np.ndarray: |
|
""" |
|
Overview: |
|
Get the action bounds of the dataset. |
|
""" |
|
|
|
return self._action_bounds |
|
|
|
@property |
|
def statistics(self) -> dict: |
|
""" |
|
Overview: |
|
Get the statistics of the dataset. |
|
""" |
|
|
|
return DatasetStatistics(mean=self.mean, std=self.std, action_bounds=self.action_bounds) |
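
# Illustrative usage of ``D4RLDataset`` (a minimal sketch; assumes the d4rl
# package is installed and that the config carries the fields read above):
#
#     cfg = EasyDict({
#         'env': {'env_id': 'hopper-medium-v2'},
#         'policy': {'collect': {'data_path': None}},
#     })
#     dataset = D4RLDataset(cfg)
#     stats = dataset.statistics   # DatasetStatistics(mean, std, action_bounds)
#     transition = dataset[0]      # dict with obs/next_obs/action/reward/done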
|
|
|
|
|
@DATASET_REGISTRY.register('hdf5') |
|
class HDF5Dataset(Dataset): |
|
""" |
|
Overview: |
|
        HDF5 dataset, which is saved in the HDF5 file format and used for offline RL algorithms.
|
The hdf5 format is a common format for storing large numerical arrays in Python. |
|
For more details, please refer to https://support.hdfgroup.org/HDF5/. |
|
Interfaces: |
|
``__init__``, ``__len__``, ``__getitem__`` |
|
Properties: |
|
- mean (:obj:`np.ndarray`): Mean of the dataset. |
|
- std (:obj:`np.ndarray`): Std of the dataset. |
|
- action_bounds (:obj:`np.ndarray`): Action bounds of the dataset. |
|
- statistics (:obj:`dict`): Statistics of the dataset. |
|
""" |
|
|
|
def __init__(self, cfg: dict) -> None: |
|
""" |
|
Overview: |
|
Initialization method. |
|
Arguments: |
|
- cfg (:obj:`dict`): Config dict. |
|
""" |
|
|
|
try: |
|
import h5py |
|
except ImportError: |
|
import sys |
|
logging.warning("not found h5py package, please install it trough `pip install h5py ") |
|
sys.exit(1) |
|
data_path = cfg.policy.collect.get('data_path', None) |
|
if 'dataset' in cfg: |
|
self.context_len = cfg.dataset.context_len |
|
else: |
|
self.context_len = 0 |
|
data = h5py.File(data_path, 'r') |
|
self._load_data(data) |
|
self._cal_statistics() |
|
try: |
|
if cfg.env.norm_obs.use_norm and cfg.env.norm_obs.offline_stats.use_offline_stats: |
|
self._normalize_states() |
|
except (KeyError, AttributeError): |
|
|
|
pass |
|
|
|
def __len__(self) -> int: |
|
""" |
|
Overview: |
|
Get the length of the dataset. |
|
""" |
|
|
|
return len(self._data['obs']) - self.context_len |
|
|
|
def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]: |
|
""" |
|
Overview: |
|
Get the item of the dataset. |
|
Arguments: |
|
- idx (:obj:`int`): The index of the dataset. |
|
""" |
|
|
|
if self.context_len == 0: |
|
return {k: self._data[k][idx] for k in self._data.keys()} |
|
else: |
|
block_size = self.context_len |
|
            # The sample covers the fixed-length window [idx, idx + block_size).

            done_idx = idx + block_size
|
states = torch.as_tensor( |
|
np.array(self._data['obs'][idx:done_idx]), dtype=torch.float32 |
|
).view(block_size, -1) |
|
actions = torch.as_tensor(self._data['action'][idx:done_idx], dtype=torch.long) |
|
rtgs = torch.as_tensor(self._data['reward'][idx:done_idx, 0], dtype=torch.float32) |
|
timesteps = torch.as_tensor(range(idx, done_idx), dtype=torch.int64) |
|
traj_mask = torch.ones(self.context_len, dtype=torch.long) |
|
return timesteps, states, actions, rtgs, traj_mask |
|
|
|
def _load_data(self, dataset: Dict[str, np.ndarray]) -> None: |
|
""" |
|
Overview: |
|
Load the dataset. |
|
Arguments: |
|
- dataset (:obj:`Dict[str, np.ndarray]`): The dataset. |
|
""" |
|
|
|
self._data = {} |
|
for k in dataset.keys(): |
|
logging.info(f'Load {k} data.') |
|
self._data[k] = dataset[k][:] |
|
|
|
def _cal_statistics(self, eps: float = 1e-3): |
|
""" |
|
Overview: |
|
Calculate the statistics of the dataset. |
|
Arguments: |
|
            - eps (:obj:`float`): Epsilon added to the std to avoid division by zero.
|
""" |
|
|
|
self._mean = self._data['obs'].mean(0) |
|
self._std = self._data['obs'].std(0) + eps |
|
action_max = self._data['action'].max(0) |
|
action_min = self._data['action'].min(0) |
|
buffer = 0.05 * (action_max - action_min) |
|
        action_max = action_max.astype(float) + buffer

        action_min = action_min.astype(float) - buffer
|
self._action_bounds = np.stack([action_min, action_max], axis=0) |
|
|
|
def _normalize_states(self): |
|
""" |
|
Overview: |
|
Normalize the states. |
|
""" |
|
|
|
self._data['obs'] = (self._data['obs'] - self._mean) / self._std |
|
self._data['next_obs'] = (self._data['next_obs'] - self._mean) / self._std |
|
|
|
@property |
|
def mean(self): |
|
""" |
|
Overview: |
|
Get the mean of the dataset. |
|
""" |
|
|
|
return self._mean |
|
|
|
@property |
|
def std(self): |
|
""" |
|
Overview: |
|
Get the std of the dataset. |
|
""" |
|
|
|
return self._std |
|
|
|
@property |
|
def action_bounds(self) -> np.ndarray: |
|
""" |
|
Overview: |
|
Get the action bounds of the dataset. |
|
""" |
|
|
|
return self._action_bounds |
|
|
|
@property |
|
def statistics(self) -> dict: |
|
""" |
|
Overview: |
|
Get the statistics of the dataset. |
|
""" |
|
|
|
return DatasetStatistics(mean=self.mean, std=self.std, action_bounds=self.action_bounds) |
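
# Illustrative usage of ``HDF5Dataset`` (a minimal sketch; assumes a file
# ``./expert_demos.hdf5`` written by ``hdf5_save`` at the end of this file,
# i.e. containing the keys obs/next_obs/action/reward/done):
#
#     cfg = EasyDict({'policy': {'collect': {'data_path': './expert_demos.hdf5'}}})
#     dataset = HDF5Dataset(cfg)
#     print(len(dataset), dataset[0]['obs'].shape)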
|
|
|
|
|
@DATASET_REGISTRY.register('d4rl_trajectory') |
|
class D4RLTrajectoryDataset(Dataset): |
|
""" |
|
Overview: |
|
D4RL trajectory dataset, which is used for offline RL algorithms. |
|
Interfaces: |
|
``__init__``, ``__len__``, ``__getitem__`` |
|
""" |
|
|
|
|
|
REF_MIN_SCORE = { |
|
'halfcheetah': -280.178953, |
|
'walker2d': 1.629008, |
|
'hopper': -20.272305, |
|
} |
|
|
|
REF_MAX_SCORE = { |
|
'halfcheetah': 12135.0, |
|
'walker2d': 4592.3, |
|
'hopper': 3234.3, |
|
} |
|
|
|
|
|
D4RL_DATASET_STATS = { |
|
'halfcheetah-medium-v2': { |
|
'state_mean': [ |
|
-0.06845773756504059, 0.016414547339081764, -0.18354906141757965, -0.2762460708618164, |
|
-0.34061527252197266, -0.09339715540409088, -0.21321271359920502, -0.0877423882484436, |
|
5.173007488250732, -0.04275195300579071, -0.036108363419771194, 0.14053793251514435, |
|
0.060498327016830444, 0.09550975263118744, 0.06739100068807602, 0.005627387668937445, |
|
0.013382787816226482 |
|
], |
|
'state_std': [ |
|
0.07472999393939972, 0.3023499846458435, 0.30207309126853943, 0.34417077898979187, 0.17619241774082184, |
|
0.507205605506897, 0.2567007839679718, 0.3294812738895416, 1.2574149370193481, 0.7600541710853577, |
|
1.9800915718078613, 6.565362453460693, 7.466367721557617, 4.472222805023193, 10.566964149475098, |
|
5.671932697296143, 7.4982590675354 |
|
] |
|
}, |
|
'halfcheetah-medium-replay-v2': { |
|
'state_mean': [ |
|
-0.12880703806877136, 0.3738119602203369, -0.14995987713336945, -0.23479078710079193, |
|
-0.2841278612613678, -0.13096535205841064, -0.20157982409000397, -0.06517726927995682, |
|
3.4768247604370117, -0.02785065770149231, -0.015035249292850494, 0.07697279006242752, |
|
0.01266712136566639, 0.027325302362442017, 0.02316424623131752, 0.010438721626996994, |
|
-0.015839405357837677 |
|
], |
|
'state_std': [ |
|
0.17019015550613403, 1.284424901008606, 0.33442774415016174, 0.3672759234905243, 0.26092398166656494, |
|
0.4784106910228729, 0.3181420564651489, 0.33552637696266174, 2.0931615829467773, 0.8037433624267578, |
|
1.9044333696365356, 6.573209762573242, 7.572863578796387, 5.069749355316162, 9.10555362701416, |
|
6.085654258728027, 7.25300407409668 |
|
] |
|
}, |
|
'halfcheetah-medium-expert-v2': { |
|
'state_mean': [ |
|
-0.05667462572455406, 0.024369969964027405, -0.061670560389757156, -0.22351515293121338, |
|
-0.2675151228904724, -0.07545716315507889, -0.05809682980179787, -0.027675075456500053, |
|
8.110626220703125, -0.06136331334710121, -0.17986927926540375, 0.25175222754478455, 0.24186332523822784, |
|
0.2519369423389435, 0.5879552960395813, -0.24090635776519775, -0.030184272676706314 |
|
], |
|
'state_std': [ |
|
0.06103534251451492, 0.36054104566574097, 0.45544400811195374, 0.38476887345314026, 0.2218363732099533, |
|
0.5667523741722107, 0.3196682929992676, 0.2852923572063446, 3.443821907043457, 0.6728139519691467, |
|
1.8616976737976074, 9.575807571411133, 10.029894828796387, 5.903450012207031, 12.128185272216797, |
|
6.4811787605285645, 6.378620147705078 |
|
] |
|
}, |
|
'walker2d-medium-v2': { |
|
'state_mean': [ |
|
1.218966007232666, 0.14163373410701752, -0.03704913705587387, -0.13814310729503632, 0.5138224363327026, |
|
-0.04719110205769539, -0.47288352251052856, 0.042254164814949036, 2.3948874473571777, |
|
-0.03143199160695076, 0.04466355964541435, -0.023907244205474854, -0.1013401448726654, |
|
0.09090937674045563, -0.004192637279629707, -0.12120571732521057, -0.5497063994407654 |
|
], |
|
'state_std': [ |
|
0.12311358004808426, 0.3241879940032959, 0.11456084251403809, 0.2623065710067749, 0.5640279054641724, |
|
0.2271878570318222, 0.3837319612503052, 0.7373676896095276, 1.2387926578521729, 0.798020601272583, |
|
1.5664079189300537, 1.8092705011367798, 3.025604248046875, 4.062486171722412, 1.4586567878723145, |
|
3.7445690631866455, 5.5851287841796875 |
|
] |
|
}, |
|
'walker2d-medium-replay-v2': { |
|
'state_mean': [ |
|
1.209364652633667, 0.13264022767543793, -0.14371201395988464, -0.2046516090631485, 0.5577612519264221, |
|
-0.03231537342071533, -0.2784661054611206, 0.19130706787109375, 1.4701707363128662, |
|
-0.12504704296588898, 0.0564953051507473, -0.09991033375263214, -0.340340256690979, 0.03546293452382088, |
|
-0.08934258669614792, -0.2992438077926636, -0.5984178185462952 |
|
], |
|
'state_std': [ |
|
0.11929835379123688, 0.3562574088573456, 0.25852200388908386, 0.42075422406196594, 0.5202291011810303, |
|
0.15685082972049713, 0.36770978569984436, 0.7161387801170349, 1.3763766288757324, 0.8632221817970276, |
|
2.6364643573760986, 3.0134117603302, 3.720684051513672, 4.867283821105957, 2.6681625843048096, |
|
3.845186948776245, 5.4768385887146 |
|
] |
|
}, |
|
'walker2d-medium-expert-v2': { |
|
'state_mean': [ |
|
1.2294334173202515, 0.16869689524173737, -0.07089081406593323, -0.16197483241558075, |
|
0.37101927399635315, -0.012209027074277401, -0.42461398243904114, 0.18986578285694122, |
|
3.162475109100342, -0.018092676997184753, 0.03496946766972542, -0.013921679928898811, |
|
-0.05937029421329498, -0.19549426436424255, -0.0019200450042262673, -0.062483321875333786, |
|
-0.27366524934768677 |
|
], |
|
'state_std': [ |
|
0.09932824969291687, 0.25981399416923523, 0.15062759816646576, 0.24249176681041718, 0.6758718490600586, |
|
0.1650741547346115, 0.38140663504600525, 0.6962361335754395, 1.3501490354537964, 0.7641991376876831, |
|
1.534574270248413, 2.1785972118377686, 3.276582717895508, 4.766193866729736, 1.1716983318328857, |
|
4.039782524108887, 5.891613960266113 |
|
] |
|
}, |
|
'hopper-medium-v2': { |
|
'state_mean': [ |
|
1.311279058456421, -0.08469521254301071, -0.5382719039916992, -0.07201576232910156, 0.04932365566492081, |
|
2.1066856384277344, -0.15017354488372803, 0.008783451281487942, -0.2848185896873474, |
|
-0.18540096282958984, -0.28461286425590515 |
|
], |
|
'state_std': [ |
|
0.17790751159191132, 0.05444620922207832, 0.21297138929367065, 0.14530418813228607, 0.6124444007873535, |
|
0.8517446517944336, 1.4515252113342285, 0.6751695871353149, 1.5362390279769897, 1.616074562072754, |
|
5.607253551483154 |
|
] |
|
}, |
|
'hopper-medium-replay-v2': { |
|
'state_mean': [ |
|
1.2305138111114502, -0.04371410980820656, -0.44542956352233887, -0.09370097517967224, |
|
0.09094487875699997, 1.3694725036621094, -0.19992674887180328, -0.022861352190375328, |
|
-0.5287045240402222, -0.14465883374214172, -0.19652697443962097 |
|
], |
|
'state_std': [ |
|
0.1756512075662613, 0.0636928603053093, 0.3438323438167572, 0.19566889107227325, 0.5547984838485718, |
|
1.051029920578003, 1.158307671546936, 0.7963128685951233, 1.4802359342575073, 1.6540331840515137, |
|
5.108601093292236 |
|
] |
|
}, |
|
'hopper-medium-expert-v2': { |
|
'state_mean': [ |
|
1.3293815851211548, -0.09836531430482864, -0.5444297790527344, -0.10201650857925415, |
|
0.02277466468513012, 2.3577215671539307, -0.06349576264619827, -0.00374026270583272, |
|
-0.1766270101070404, -0.11862941086292267, -0.12097819894552231 |
|
], |
|
'state_std': [ |
|
0.17012375593185425, 0.05159067362546921, 0.18141433596611023, 0.16430604457855225, 0.6023368239402771, |
|
0.7737284898757935, 1.4986555576324463, 0.7483318448066711, 1.7953159809112549, 2.0530025959014893, |
|
5.725032806396484 |
|
] |
|
}, |
|
} |
|
|
|
def __init__(self, cfg: dict) -> None: |
|
""" |
|
Overview: |
|
Initialization method. |
|
Arguments: |
|
- cfg (:obj:`dict`): Config dict. |
|
""" |
|
|
|
dataset_path = cfg.dataset.data_dir_prefix |
|
rtg_scale = cfg.dataset.rtg_scale |
|
self.context_len = cfg.dataset.context_len |
|
self.env_type = cfg.dataset.env_type |
|
|
|
if 'hdf5' in dataset_path: |
|
try: |
|
import h5py |
|
import collections |
|
except ImportError: |
|
import sys |
|
logging.warning("not found h5py package, please install it trough `pip install h5py ") |
|
sys.exit(1) |
|
dataset = h5py.File(dataset_path, 'r') |
|
|
|
N = dataset['rewards'].shape[0] |
|
data_ = collections.defaultdict(list) |
|
|
|
use_timeouts = False |
|
if 'timeouts' in dataset: |
|
use_timeouts = True |
|
|
|
episode_step = 0 |
|
paths = [] |
|
for i in range(N): |
|
done_bool = bool(dataset['terminals'][i]) |
|
if use_timeouts: |
|
final_timestep = dataset['timeouts'][i] |
|
else: |
|
final_timestep = (episode_step == 1000 - 1) |
|
for k in ['observations', 'actions', 'rewards', 'terminals']: |
|
data_[k].append(dataset[k][i]) |
|
if done_bool or final_timestep: |
|
episode_step = 0 |
|
episode_data = {} |
|
for k in data_: |
|
episode_data[k] = np.array(data_[k]) |
|
paths.append(episode_data) |
|
data_ = collections.defaultdict(list) |
|
episode_step += 1 |
|
|
|
self.trajectories = paths |
|
|
|
|
|
states = [] |
|
for traj in self.trajectories: |
|
traj_len = traj['observations'].shape[0] |
|
states.append(traj['observations']) |
|
|
|
traj['returns_to_go'] = discount_cumsum(traj['rewards'], 1.0) / rtg_scale |
|
|
|
|
|
states = np.concatenate(states, axis=0) |
|
self.state_mean, self.state_std = np.mean(states, axis=0), np.std(states, axis=0) + 1e-6 |
|
|
|
|
|
for traj in self.trajectories: |
|
traj['observations'] = (traj['observations'] - self.state_mean) / self.state_std |
|
|
|
elif 'pkl' in dataset_path: |
|
if 'dqn' in dataset_path: |
|
|
|
with open(dataset_path, 'rb') as f: |
|
self.trajectories = pickle.load(f) |
|
|
|
if isinstance(self.trajectories[0], list): |
|
|
|
trajectories_tmp = [] |
|
|
|
original_keys = ['obs', 'next_obs', 'action', 'reward'] |
|
keys = ['observations', 'next_observations', 'actions', 'rewards'] |
|
trajectories_tmp = [ |
|
{ |
|
key: np.stack( |
|
[ |
|
self.trajectories[eps_index][transition_index][o_key] |
|
for transition_index in range(len(self.trajectories[eps_index])) |
|
], |
|
axis=0 |
|
) |
|
for key, o_key in zip(keys, original_keys) |
|
} for eps_index in range(len(self.trajectories)) |
|
] |
|
self.trajectories = trajectories_tmp |
|
|
|
states = [] |
|
for traj in self.trajectories: |
|
|
|
states.append(traj['observations']) |
|
|
|
traj['returns_to_go'] = discount_cumsum(traj['rewards'], 1.0) / rtg_scale |
|
|
|
|
|
states = np.concatenate(states, axis=0) |
|
self.state_mean, self.state_std = np.mean(states, axis=0), np.std(states, axis=0) + 1e-6 |
|
|
|
|
|
for traj in self.trajectories: |
|
traj['observations'] = (traj['observations'] - self.state_mean) / self.state_std |
|
else: |
|
|
|
with open(dataset_path, 'rb') as f: |
|
self.trajectories = pickle.load(f) |
|
|
|
states = [] |
|
for traj in self.trajectories: |
|
states.append(traj['observations']) |
|
|
|
traj['returns_to_go'] = discount_cumsum(traj['rewards'], 1.0) / rtg_scale |
|
|
|
|
|
states = np.concatenate(states, axis=0) |
|
self.state_mean, self.state_std = np.mean(states, axis=0), np.std(states, axis=0) + 1e-6 |
|
|
|
|
|
for traj in self.trajectories: |
|
traj['observations'] = (traj['observations'] - self.state_mean) / self.state_std |
|
else: |
|
|
|
obss = [] |
|
actions = [] |
|
returns = [0] |
|
done_idxs = [] |
|
stepwise_returns = [] |
|
|
|
transitions_per_buffer = np.zeros(50, dtype=int) |
|
num_trajectories = 0 |
|
while len(obss) < cfg.dataset.num_steps: |
|
buffer_num = np.random.choice(np.arange(50 - cfg.dataset.num_buffers, 50), 1)[0] |
|
i = transitions_per_buffer[buffer_num] |
|
frb = FixedReplayBuffer( |
|
data_dir=cfg.dataset.data_dir_prefix + '/1/replay_logs', |
|
replay_suffix=buffer_num, |
|
observation_shape=(84, 84), |
|
stack_size=4, |
|
update_horizon=1, |
|
gamma=0.99, |
|
observation_dtype=np.uint8, |
|
batch_size=32, |
|
replay_capacity=100000 |
|
) |
|
if frb._loaded_buffers: |
|
done = False |
|
curr_num_transitions = len(obss) |
|
trajectories_to_load = cfg.dataset.trajectories_per_buffer |
|
while not done: |
|
states, ac, ret, next_states, next_action, next_reward, terminal, indices = \ |
|
frb.sample_transition_batch(batch_size=1, indices=[i]) |
|
states = states.transpose((0, 3, 1, 2))[0] |
|
obss.append(states) |
|
actions.append(ac[0]) |
|
stepwise_returns.append(ret[0]) |
|
if terminal[0]: |
|
done_idxs.append(len(obss)) |
|
returns.append(0) |
|
if trajectories_to_load == 0: |
|
done = True |
|
else: |
|
trajectories_to_load -= 1 |
|
returns[-1] += ret[0] |
|
i += 1 |
|
if i >= 100000: |
|
obss = obss[:curr_num_transitions] |
|
actions = actions[:curr_num_transitions] |
|
stepwise_returns = stepwise_returns[:curr_num_transitions] |
|
returns[-1] = 0 |
|
i = transitions_per_buffer[buffer_num] |
|
done = True |
|
num_trajectories += (cfg.dataset.trajectories_per_buffer - trajectories_to_load) |
|
transitions_per_buffer[buffer_num] = i |
|
|
|
actions = np.array(actions) |
|
returns = np.array(returns) |
|
stepwise_returns = np.array(stepwise_returns) |
|
done_idxs = np.array(done_idxs) |
|
|
|
|
|
start_index = 0 |
|
rtg = np.zeros_like(stepwise_returns) |
|
for i in done_idxs: |
|
i = int(i) |
|
curr_traj_returns = stepwise_returns[start_index:i] |
|
for j in range(i - 1, start_index - 1, -1): |
|
rtg_j = curr_traj_returns[j - start_index:i - start_index] |
|
rtg[j] = sum(rtg_j) |
|
start_index = i |
|
|
|
|
|
start_index = 0 |
|
timesteps = np.zeros(len(actions) + 1, dtype=int) |
|
for i in done_idxs: |
|
i = int(i) |
|
timesteps[start_index:i + 1] = np.arange(i + 1 - start_index) |
|
start_index = i + 1 |
|
|
|
self.obss = obss |
|
self.actions = actions |
|
self.done_idxs = done_idxs |
|
self.rtgs = rtg |
|
self.timesteps = timesteps |
|
|
|
|
|
def get_max_timestep(self) -> int: |
|
""" |
|
Overview: |
|
Get the max timestep of the dataset. |
|
""" |
|
|
|
return max(self.timesteps) |
|
|
|
def get_state_stats(self) -> Tuple[np.ndarray, np.ndarray]: |
|
""" |
|
Overview: |
|
Get the state mean and std of the dataset. |
|
""" |
|
|
|
return deepcopy(self.state_mean), deepcopy(self.state_std) |
|
|
|
def get_d4rl_dataset_stats(self, env_d4rl_name: str) -> Dict[str, list]: |
|
""" |
|
Overview: |
|
Get the d4rl dataset stats. |
|
Arguments: |
|
- env_d4rl_name (:obj:`str`): The d4rl env name. |
|
""" |
|
|
|
return self.D4RL_DATASET_STATS[env_d4rl_name] |
|
|
|
def __len__(self) -> int: |
|
""" |
|
Overview: |
|
Get the length of the dataset. |
|
""" |
|
|
|
if self.env_type != 'atari': |
|
return len(self.trajectories) |
|
else: |
|
return len(self.obss) - self.context_len |
|
|
|
def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: |
|
""" |
|
Overview: |
|
Get the item of the dataset. |
|
Arguments: |
|
- idx (:obj:`int`): The index of the dataset. |
|
""" |
|
|
|
if self.env_type != 'atari': |
|
traj = self.trajectories[idx] |
|
traj_len = traj['observations'].shape[0] |
|
|
|
if traj_len > self.context_len: |
|
|
|
si = np.random.randint(0, traj_len - self.context_len) |
|
|
|
states = torch.from_numpy(traj['observations'][si:si + self.context_len]) |
|
actions = torch.from_numpy(traj['actions'][si:si + self.context_len]) |
|
returns_to_go = torch.from_numpy(traj['returns_to_go'][si:si + self.context_len]) |
|
timesteps = torch.arange(start=si, end=si + self.context_len, step=1) |
|
|
|
|
|
traj_mask = torch.ones(self.context_len, dtype=torch.long) |
|
|
|
else: |
|
padding_len = self.context_len - traj_len |
|
|
|
|
|
states = torch.from_numpy(traj['observations']) |
|
states = torch.cat( |
|
[states, torch.zeros(([padding_len] + list(states.shape[1:])), dtype=states.dtype)], dim=0 |
|
) |
|
|
|
actions = torch.from_numpy(traj['actions']) |
|
actions = torch.cat( |
|
[actions, torch.zeros(([padding_len] + list(actions.shape[1:])), dtype=actions.dtype)], dim=0 |
|
) |
|
|
|
returns_to_go = torch.from_numpy(traj['returns_to_go']) |
|
returns_to_go = torch.cat( |
|
[ |
|
returns_to_go, |
|
torch.zeros(([padding_len] + list(returns_to_go.shape[1:])), dtype=returns_to_go.dtype) |
|
], |
|
dim=0 |
|
) |
|
|
|
timesteps = torch.arange(start=0, end=self.context_len, step=1) |
|
|
|
traj_mask = torch.cat( |
|
[torch.ones(traj_len, dtype=torch.long), |
|
torch.zeros(padding_len, dtype=torch.long)], dim=0 |
|
) |
|
return timesteps, states, actions, returns_to_go, traj_mask |
|
else: |
|
block_size = self.context_len |
|
done_idx = idx + block_size |
|
for i in self.done_idxs: |
|
if i > idx: |
|
done_idx = min(int(i), done_idx) |
|
break |
|
idx = done_idx - block_size |
|
states = torch.as_tensor( |
|
np.array(self.obss[idx:done_idx]), dtype=torch.float32 |
|
).view(block_size, -1) |
|
states = states / 255. |
|
actions = torch.as_tensor(self.actions[idx:done_idx], dtype=torch.long).unsqueeze(1) |
|
rtgs = torch.as_tensor(self.rtgs[idx:done_idx], dtype=torch.float32).unsqueeze(1) |
|
timesteps = torch.as_tensor(self.timesteps[idx:idx + 1], dtype=torch.int64).unsqueeze(1) |
|
traj_mask = torch.ones(self.context_len, dtype=torch.long) |
|
return timesteps, states, actions, rtgs, traj_mask |
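
# Illustrative usage of ``D4RLTrajectoryDataset`` for Decision-Transformer-style
# training (a minimal sketch; ``./hopper-medium-v2.pkl`` is a hypothetical
# pickled list of trajectory dicts, and the field values are only examples):
#
#     cfg = EasyDict({
#         'dataset': {
#             'data_dir_prefix': './hopper-medium-v2.pkl',
#             'rtg_scale': 1000,
#             'context_len': 20,
#             'env_type': 'mujoco',
#         }
#     })
#     dataset = D4RLTrajectoryDataset(cfg)
#     timesteps, states, actions, returns_to_go, traj_mask = dataset[0]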
|
|
|
|
|
@DATASET_REGISTRY.register('d4rl_diffuser') |
|
class D4RLDiffuserDataset(Dataset): |
|
""" |
|
Overview: |
|
D4RL diffuser dataset, which is used for offline RL algorithms. |
|
Interfaces: |
|
``__init__``, ``__len__``, ``__getitem__`` |
|
""" |
|
|
|
def __init__(self, dataset_path: str, context_len: int, rtg_scale: float) -> None: |
|
""" |
|
Overview: |
|
Initialization method of D4RLDiffuserDataset. |
|
Arguments: |
|
- dataset_path (:obj:`str`): The dataset path. |
|
- context_len (:obj:`int`): The length of the context. |
|
- rtg_scale (:obj:`float`): The scale of the returns to go. |
|
""" |
|
|
|
self.context_len = context_len |
|
|
|
|
|
with open(dataset_path, 'rb') as f: |
|
self.trajectories = pickle.load(f) |
|
|
|
if isinstance(self.trajectories[0], list): |
|
|
|
trajectories_tmp = [] |
|
|
|
original_keys = ['obs', 'next_obs', 'action', 'reward'] |
|
keys = ['observations', 'next_observations', 'actions', 'rewards'] |
|
            # Convert each trajectory from a list of transition dicts into a

            # dict of stacked arrays, renaming the keys to the d4rl convention.

            trajectories_tmp = [

                {

                    key: np.stack(

                        [

                            self.trajectories[eps_index][transition_index][o_key]

                            for transition_index in range(len(self.trajectories[eps_index]))

                        ],

                        axis=0

                    )

                    for key, o_key in zip(keys, original_keys)

                } for eps_index in range(len(self.trajectories))

            ]

            self.trajectories = trajectories_tmp
|
|
|
states = [] |
|
for traj in self.trajectories: |
|
traj_len = traj['observations'].shape[0] |
|
states.append(traj['observations']) |
|
|
|
traj['returns_to_go'] = discount_cumsum(traj['rewards'], 1.0) / rtg_scale |
|
|
|
|
|
states = np.concatenate(states, axis=0) |
|
self.state_mean, self.state_std = np.mean(states, axis=0), np.std(states, axis=0) + 1e-6 |
|
|
|
|
|
for traj in self.trajectories: |
|
traj['observations'] = (traj['observations'] - self.state_mean) / self.state_std |
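
# Illustrative usage of ``D4RLDiffuserDataset`` (a minimal sketch;
# ``./hopper-medium-v2.pkl`` is a hypothetical pickled trajectory file):
#
#     dataset = D4RLDiffuserDataset('./hopper-medium-v2.pkl', context_len=20, rtg_scale=1000)
#     print(dataset.state_mean.shape, dataset.state_std.shape)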
|
|
|
|
|
class FixedReplayBuffer(object): |
|
""" |
|
Overview: |
|
        Object composed of a list of OutOfGraphReplayBuffers.
|
Interfaces: |
|
``__init__``, ``get_transition_elements``, ``sample_transition_batch`` |
|
""" |
|
|
|
def __init__(self, data_dir, replay_suffix, *args, **kwargs): |
|
""" |
|
Overview: |
|
Initialize the FixedReplayBuffer class. |
|
Arguments: |
|
            - data_dir (:obj:`str`): Log directory from which to load the replay buffer.
|
- replay_suffix (:obj:`int`): If not None, then only load the replay buffer \ |
|
corresponding to the specific suffix in data directory. |
|
- args (:obj:`list`): Arbitrary extra arguments. |
|
- kwargs (:obj:`dict`): Arbitrary keyword arguments. |
|
|
|
""" |
|
|
|
self._args = args |
|
self._kwargs = kwargs |
|
self._data_dir = data_dir |
|
self._loaded_buffers = False |
|
self.add_count = np.array(0) |
|
self._replay_suffix = replay_suffix |
|
if not self._loaded_buffers: |
|
if replay_suffix is not None: |
|
assert replay_suffix >= 0, 'Please pass a non-negative replay suffix' |
|
self.load_single_buffer(replay_suffix) |
|
else: |
|
pass |
|
|
|
|
|
def load_single_buffer(self, suffix): |
|
""" |
|
Overview: |
|
Load a single replay buffer. |
|
Arguments: |
|
- suffix (:obj:`int`): The suffix of the replay buffer. |
|
""" |
|
|
|
replay_buffer = self._load_buffer(suffix) |
|
if replay_buffer is not None: |
|
self._replay_buffers = [replay_buffer] |
|
self.add_count = replay_buffer.add_count |
|
self._num_replay_buffers = 1 |
|
self._loaded_buffers = True |
|
|
|
def _load_buffer(self, suffix): |
|
""" |
|
Overview: |
|
            Load an OutOfGraphReplayBuffer from the data directory.
|
Arguments: |
|
- suffix (:obj:`int`): The suffix of the replay buffer. |
|
""" |
|
|
|
try: |
|
from dopamine.replay_memory import circular_replay_buffer |
|
STORE_FILENAME_PREFIX = circular_replay_buffer.STORE_FILENAME_PREFIX |
|
|
|
replay_buffer = circular_replay_buffer.OutOfGraphReplayBuffer(*self._args, **self._kwargs) |
|
replay_buffer.load(self._data_dir, suffix) |
|
|
|
return replay_buffer |
|
|
|
        except Exception as e:

            raise IOError('Failed to load the replay buffer from {}'.format(self._data_dir)) from e
|
|
|
def get_transition_elements(self): |
|
""" |
|
Overview: |
|
Returns the transition elements. |
|
""" |
|
|
|
return self._replay_buffers[0].get_transition_elements() |
|
|
|
def sample_transition_batch(self, batch_size=None, indices=None): |
|
""" |
|
Overview: |
|
Returns a batch of transitions (including any extra contents). |
|
Arguments: |
|
- batch_size (:obj:`int`): The batch size. |
|
- indices (:obj:`list`): The indices of the batch. |
|
""" |
|
|
|
buffer_index = np.random.randint(self._num_replay_buffers) |
|
return self._replay_buffers[buffer_index].sample_transition_batch(batch_size=batch_size, indices=indices) |
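
# Illustrative usage of ``FixedReplayBuffer`` (a minimal sketch; assumes the
# dopamine package is installed and that a dopamine-format Atari replay log
# exists under the hypothetical directory ``./Pong/1/replay_logs``):
#
#     frb = FixedReplayBuffer(
#         data_dir='./Pong/1/replay_logs', replay_suffix=0,
#         observation_shape=(84, 84), stack_size=4, update_horizon=1,
#         gamma=0.99, observation_dtype=np.uint8, batch_size=32,
#         replay_capacity=100000,
#     )
#     if frb._loaded_buffers:
#         batch = frb.sample_transition_batch(batch_size=1, indices=[0])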
|
|
|
|
|
class PCDataset(Dataset): |
|
""" |
|
Overview: |
|
Dataset for Procedure Cloning. |
|
Interfaces: |
|
``__init__``, ``__len__``, ``__getitem__`` |
|
""" |
|
|
|
def __init__(self, all_data): |
|
""" |
|
Overview: |
|
Initialization method of PCDataset. |
|
Arguments: |
|
- all_data (:obj:`tuple`): The tuple of all data. |
|
""" |
|
|
|
self._data = all_data |
|
|
|
def __getitem__(self, item): |
|
""" |
|
Overview: |
|
Get the item of the dataset. |
|
Arguments: |
|
- item (:obj:`int`): The index of the dataset. |
|
""" |
|
|
|
return {'obs': self._data[0][item], 'bfs_in': self._data[1][item], 'bfs_out': self._data[2][item]} |
|
|
|
def __len__(self): |
|
""" |
|
Overview: |
|
Get the length of the dataset. |
|
""" |
|
|
|
return self._data[0].shape[0] |
|
|
|
|
|
def load_bfs_datasets(train_seeds=1, test_seeds=5): |
|
""" |
|
Overview: |
|
Load BFS datasets. |
|
Arguments: |
|
- train_seeds (:obj:`int`): The number of train seeds. |
|
- test_seeds (:obj:`int`): The number of test seeds. |
|
""" |
|
|
|
from dizoo.maze.envs import Maze |
|
|
|
def load_env(seed): |
|
ccc = easydict.EasyDict({'size': 16}) |
|
e = Maze(ccc) |
|
e.seed(seed) |
|
e.reset() |
|
return e |
|
|
|
envs = [load_env(i) for i in range(train_seeds + test_seeds)] |
|
|
|
observations_train = [] |
|
observations_test = [] |
|
bfs_input_maps_train = [] |
|
bfs_input_maps_test = [] |
|
bfs_output_maps_train = [] |
|
bfs_output_maps_test = [] |
|
for idx, env in enumerate(envs): |
|
if idx < train_seeds: |
|
observations = observations_train |
|
bfs_input_maps = bfs_input_maps_train |
|
bfs_output_maps = bfs_output_maps_train |
|
else: |
|
observations = observations_test |
|
bfs_input_maps = bfs_input_maps_test |
|
bfs_output_maps = bfs_output_maps_test |
|
|
|
start_obs = env.process_states(env._get_obs(), env.get_maze_map()) |
|
_, track_back = get_vi_sequence(env, start_obs) |
|
env_observations = torch.stack([track_back[i][0] for i in range(len(track_back))], dim=0) |
|
|
|
for i in range(env_observations.shape[0]): |
|
bfs_sequence, _ = get_vi_sequence(env, env_observations[i].numpy().astype(np.int32)) |
|
            bfs_input_map = env.n_action * np.ones([env.size, env.size], dtype=np.int64)
|
|
|
for j in range(bfs_sequence.shape[0]): |
|
bfs_input_maps.append(torch.from_numpy(bfs_input_map)) |
|
bfs_output_maps.append(torch.from_numpy(bfs_sequence[j])) |
|
observations.append(env_observations[i]) |
|
bfs_input_map = bfs_sequence[j] |
|
|
|
train_data = PCDataset( |
|
( |
|
torch.stack(observations_train, dim=0), |
|
torch.stack(bfs_input_maps_train, dim=0), |
|
torch.stack(bfs_output_maps_train, dim=0), |
|
) |
|
) |
|
test_data = PCDataset( |
|
( |
|
torch.stack(observations_test, dim=0), |
|
torch.stack(bfs_input_maps_test, dim=0), |
|
torch.stack(bfs_output_maps_test, dim=0), |
|
) |
|
) |
|
|
|
return train_data, test_data |
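
# Illustrative usage of ``load_bfs_datasets`` (a minimal sketch; requires the
# ``dizoo`` Maze environment):
#
#     train_data, test_data = load_bfs_datasets(train_seeds=1, test_seeds=5)
#     sample = train_data[0]  # dict with 'obs', 'bfs_in' and 'bfs_out' tensors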
|
|
|
|
|
@DATASET_REGISTRY.register('bco') |
|
class BCODataset(Dataset): |
|
""" |
|
Overview: |
|
Dataset for Behavioral Cloning from Observation. |
|
Interfaces: |
|
``__init__``, ``__len__``, ``__getitem__`` |
|
Properties: |
|
- obs (:obj:`np.ndarray`): The observation array. |
|
- action (:obj:`np.ndarray`): The action array. |
|
""" |
|
|
|
def __init__(self, data=None): |
|
""" |
|
Overview: |
|
Initialization method of BCODataset. |
|
Arguments: |
|
- data (:obj:`dict`): The data dict. |
|
""" |
|
|
|
if data is None: |
|
raise ValueError('Dataset can not be empty!') |
|
else: |
|
self._data = data |
|
|
|
def __len__(self): |
|
""" |
|
Overview: |
|
Get the length of the dataset. |
|
""" |
|
|
|
return len(self._data['obs']) |
|
|
|
def __getitem__(self, idx): |
|
""" |
|
Overview: |
|
Get the item of the dataset. |
|
Arguments: |
|
- idx (:obj:`int`): The index of the dataset. |
|
""" |
|
|
|
return {k: self._data[k][idx] for k in self._data.keys()} |
|
|
|
@property |
|
def obs(self): |
|
""" |
|
Overview: |
|
Get the observation array. |
|
""" |
|
|
|
return self._data['obs'] |
|
|
|
@property |
|
def action(self): |
|
""" |
|
Overview: |
|
Get the action array. |
|
""" |
|
|
|
return self._data['action'] |
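
# Illustrative usage of ``BCODataset`` (a minimal sketch with random
# placeholder data):
#
#     data = {'obs': torch.randn(100, 4), 'action': torch.randint(0, 2, (100, ))}
#     dataset = BCODataset(data)
#     print(len(dataset), dataset[0]['obs'].shape, dataset.action.shape)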
|
|
|
|
|
@DATASET_REGISTRY.register('diffuser_traj') |
|
class SequenceDataset(torch.utils.data.Dataset): |
|
""" |
|
Overview: |
|
Dataset for diffuser. |
|
Interfaces: |
|
``__init__``, ``__len__``, ``__getitem__`` |
|
""" |
|
|
|
def __init__(self, cfg): |
|
""" |
|
Overview: |
|
Initialization method of SequenceDataset. |
|
Arguments: |
|
- cfg (:obj:`dict`): The config dict. |
|
""" |
|
|
|
import gym |
|
|
|
env_id = cfg.env.env_id |
|
data_path = cfg.policy.collect.get('data_path', None) |
|
env = gym.make(env_id) |
|
|
|
dataset = env.get_dataset() |
|
|
|
self.returns_scale = cfg.env.returns_scale |
|
self.horizon = cfg.env.horizon |
|
self.max_path_length = cfg.env.max_path_length |
|
self.discount = cfg.policy.learn.discount_factor |
|
self.discounts = self.discount ** np.arange(self.max_path_length)[:, None] |
|
self.use_padding = cfg.env.use_padding |
|
self.include_returns = cfg.env.include_returns |
|
self.env_id = cfg.env.env_id |
|
itr = self.sequence_dataset(env, dataset) |
|
self.n_episodes = 0 |
|
|
|
fields = {} |
|
for k in dataset.keys(): |
|
if 'metadata' in k: |
|
continue |
|
fields[k] = [] |
|
fields['path_lengths'] = [] |
|
|
|
for i, episode in enumerate(itr): |
|
path_length = len(episode['observations']) |
|
assert path_length <= self.max_path_length |
|
fields['path_lengths'].append(path_length) |
|
for key, val in episode.items(): |
|
if key not in fields: |
|
fields[key] = [] |
|
if val.ndim < 2: |
|
val = np.expand_dims(val, axis=-1) |
|
shape = (self.max_path_length, val.shape[-1]) |
|
arr = np.zeros(shape, dtype=np.float32) |
|
arr[:path_length] = val |
|
fields[key].append(arr) |
|
if episode['terminals'].any() and cfg.env.termination_penalty and 'timeouts' in episode: |
|
assert not episode['timeouts'].any(), 'Penalized a timeout episode for early termination' |
|
fields['rewards'][-1][path_length - 1] += cfg.env.termination_penalty |
|
self.n_episodes += 1 |
|
|
|
for k in fields.keys(): |
|
fields[k] = np.array(fields[k]) |
|
|
|
self.normalizer = DatasetNormalizer(fields, cfg.policy.normalizer, path_lengths=fields['path_lengths']) |
|
self.indices = self.make_indices(fields['path_lengths'], self.horizon) |
|
|
|
self.observation_dim = cfg.env.obs_dim |
|
self.action_dim = cfg.env.action_dim |
|
self.fields = fields |
|
self.normalize() |
|
self.normed = False |
|
if cfg.env.normed: |
|
self.vmin, self.vmax = self._get_bounds() |
|
self.normed = True |
|
|
|
|
|
|
|
|
|
def sequence_dataset(self, env, dataset=None): |
|
""" |
|
Overview: |
|
Sequence the dataset. |
|
Arguments: |
|
- env (:obj:`gym.Env`): The gym env. |
|
""" |
|
|
|
import collections |
|
N = dataset['rewards'].shape[0] |
|
if 'maze2d' in env.spec.id: |
|
dataset = self.maze2d_set_terminals(env, dataset) |
|
data_ = collections.defaultdict(list) |
|
|
|
|
|
|
|
use_timeouts = 'timeouts' in dataset |
|
|
|
episode_step = 0 |
|
for i in range(N): |
|
done_bool = bool(dataset['terminals'][i]) |
|
if use_timeouts: |
|
final_timestep = dataset['timeouts'][i] |
|
else: |
|
final_timestep = (episode_step == env._max_episode_steps - 1) |
|
|
|
for k in dataset: |
|
if 'metadata' in k: |
|
continue |
|
data_[k].append(dataset[k][i]) |
|
|
|
if done_bool or final_timestep: |
|
episode_step = 0 |
|
episode_data = {} |
|
for k in data_: |
|
episode_data[k] = np.array(data_[k]) |
|
if 'maze2d' in env.spec.id: |
|
episode_data = self.process_maze2d_episode(episode_data) |
|
yield episode_data |
|
data_ = collections.defaultdict(list) |
|
|
|
episode_step += 1 |
|
|
|
def maze2d_set_terminals(self, env, dataset): |
|
""" |
|
Overview: |
|
Set the terminals for maze2d. |
|
Arguments: |
|
- env (:obj:`gym.Env`): The gym env. |
|
- dataset (:obj:`dict`): The dataset dict. |
|
""" |
|
|
|
goal = env.get_target() |
|
threshold = 0.5 |
|
|
|
xy = dataset['observations'][:, :2] |
|
distances = np.linalg.norm(xy - goal, axis=-1) |
|
at_goal = distances < threshold |
|
timeouts = np.zeros_like(dataset['timeouts']) |
|
|
|
|
|
|
|
|
|
timeouts[:-1] = at_goal[:-1] * ~at_goal[1:] |
|
|
|
timeout_steps = np.where(timeouts)[0] |
|
path_lengths = timeout_steps[1:] - timeout_steps[:-1] |
|
|
|
print( |
|
f'[ utils/preprocessing ] Segmented {env.spec.id} | {len(path_lengths)} paths | ' |
|
f'min length: {path_lengths.min()} | max length: {path_lengths.max()}' |
|
) |
|
|
|
dataset['timeouts'] = timeouts |
|
return dataset |
|
|
|
def process_maze2d_episode(self, episode): |
|
""" |
|
Overview: |
|
            Process a maze2d episode by adding a ``next_observations`` field.
|
Arguments: |
|
- episode (:obj:`dict`): The episode dict. |
|
""" |
|
|
|
assert 'next_observations' not in episode |
|
length = len(episode['observations']) |
|
next_observations = episode['observations'][1:].copy() |
|
for key, val in episode.items(): |
|
episode[key] = val[:-1] |
|
episode['next_observations'] = next_observations |
|
return episode |
|
|
|
def normalize(self, keys=['observations', 'actions']): |
|
""" |
|
Overview: |
|
            Normalize the dataset fields that will be predicted by the diffusion model.
|
Arguments: |
|
- keys (:obj:`list`): The list of keys. |
|
""" |
|
|
|
for key in keys: |
|
array = self.fields[key].reshape(self.n_episodes * self.max_path_length, -1) |
|
normed = self.normalizer.normalize(array, key) |
|
self.fields[f'normed_{key}'] = normed.reshape(self.n_episodes, self.max_path_length, -1) |
|
|
|
def make_indices(self, path_lengths, horizon): |
|
""" |
|
Overview: |
|
Make indices for sampling from dataset. Each index maps to a datapoint. |
|
Arguments: |
|
- path_lengths (:obj:`np.ndarray`): The path length array. |
|
- horizon (:obj:`int`): The horizon. |
|
""" |
|
|
|
indices = [] |
|
for i, path_length in enumerate(path_lengths): |
|
max_start = min(path_length - 1, self.max_path_length - horizon) |
|
if not self.use_padding: |
|
max_start = min(max_start, path_length - horizon) |
|
for start in range(max_start): |
|
end = start + horizon |
|
indices.append((i, start, end)) |
|
indices = np.array(indices) |
|
return indices |
|
|
|
def get_conditions(self, observations): |
|
""" |
|
Overview: |
|
Get the conditions on current observation for planning. |
|
Arguments: |
|
- observations (:obj:`np.ndarray`): The observation array. |
|
""" |
|
|
|
if 'maze2d' in self.env_id: |
|
return {'condition_id': [0, self.horizon - 1], 'condition_val': [observations[0], observations[-1]]} |
|
else: |
|
return {'condition_id': [0], 'condition_val': [observations[0]]} |
|
|
|
def __len__(self): |
|
""" |
|
Overview: |
|
Get the length of the dataset. |
|
""" |
|
|
|
return len(self.indices) |
|
|
|
def _get_bounds(self): |
|
""" |
|
Overview: |
|
Get the bounds of the dataset. |
|
""" |
|
|
|
print('[ datasets/sequence ] Getting value dataset bounds...', end=' ', flush=True) |
|
vmin = np.inf |
|
vmax = -np.inf |
|
for i in range(len(self.indices)): |
|
value = self.__getitem__(i)['returns'].item() |
|
vmin = min(value, vmin) |
|
vmax = max(value, vmax) |
|
print('✓') |
|
return vmin, vmax |
|
|
|
def normalize_value(self, value): |
|
""" |
|
Overview: |
|
Normalize the value. |
|
Arguments: |
|
- value (:obj:`np.ndarray`): The value array. |
|
""" |
|
|
|
|
|
normed = (value - self.vmin) / (self.vmax - self.vmin) |
|
|
|
normed = normed * 2 - 1 |
|
return normed |
|
|
|
def __getitem__(self, idx, eps=1e-4): |
|
""" |
|
Overview: |
|
Get the item of the dataset. |
|
Arguments: |
|
- idx (:obj:`int`): The index of the dataset. |
|
- eps (:obj:`float`): The epsilon. |
|
""" |
|
|
|
path_ind, start, end = self.indices[idx] |
|
|
|
observations = self.fields['normed_observations'][path_ind, start:end] |
|
actions = self.fields['normed_actions'][path_ind, start:end] |
|
done = self.fields['terminals'][path_ind, start:end] |
|
|
|
|
|
trajectories = np.concatenate([actions, observations], axis=-1) |
|
|
|
if self.include_returns: |
|
rewards = self.fields['rewards'][path_ind, start:] |
|
discounts = self.discounts[:len(rewards)] |
|
returns = (discounts * rewards).sum() |
|
if self.normed: |
|
returns = self.normalize_value(returns) |
|
returns = np.array([returns / self.returns_scale], dtype=np.float32) |
|
batch = { |
|
'trajectories': trajectories, |
|
'returns': returns, |
|
'done': done, |
|
'action': actions, |
|
} |
|
else: |
|
batch = { |
|
'trajectories': trajectories, |
|
'done': done, |
|
'action': actions, |
|
} |
|
|
|
batch.update(self.get_conditions(observations)) |
|
return batch |
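
# Illustrative usage of ``SequenceDataset`` (a minimal sketch; assumes d4rl is
# installed, and the field values, including the normalizer name, are only
# assumptions for illustration):
#
#     cfg = EasyDict({
#         'env': {
#             'env_id': 'hopper-medium-v2', 'returns_scale': 400.0, 'horizon': 32,
#             'max_path_length': 1000, 'use_padding': True, 'include_returns': True,
#             'obs_dim': 11, 'action_dim': 3, 'termination_penalty': -100, 'normed': False,
#         },
#         'policy': {
#             'collect': {'data_path': None},
#             'learn': {'discount_factor': 0.99},
#             'normalizer': 'GaussianNormalizer',
#         },
#     })
#     dataset = SequenceDataset(cfg)
#     batch = dataset[0]  # dict with 'trajectories', 'done', 'action' (+ returns/conditions)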
|
|
|
|
|
def hdf5_save(exp_data, expert_data_path): |
|
""" |
|
Overview: |
|
        Save the expert data to an HDF5 file named ``<expert_data_path>_demos.hdf5``.

    Arguments:

        - exp_data (:obj:`list`): The expert data, a list of transition dicts.

        - expert_data_path (:obj:`str`): The base path; its ``.pkl`` suffix is stripped to name the HDF5 file.
|
""" |
|
|
|
try: |
|
import h5py |
|
except ImportError: |
|
import sys |
|
logging.warning("not found h5py package, please install it trough 'pip install h5py' ") |
|
sys.exit(1) |
|
    dataset = h5py.File('%s_demos.hdf5' % expert_data_path.replace('.pkl', ''), 'w')
|
dataset.create_dataset('obs', data=np.array([d['obs'].numpy() for d in exp_data]), compression='gzip') |
|
dataset.create_dataset('action', data=np.array([d['action'].numpy() for d in exp_data]), compression='gzip') |
|
dataset.create_dataset('reward', data=np.array([d['reward'].numpy() for d in exp_data]), compression='gzip') |
|
dataset.create_dataset('done', data=np.array([d['done'] for d in exp_data]), compression='gzip') |
|
dataset.create_dataset('next_obs', data=np.array([d['next_obs'].numpy() for d in exp_data]), compression='gzip') |
|
|
|
|
|
def naive_save(exp_data, expert_data_path): |
|
""" |
|
Overview: |
|
        Save the expert data to a pickle file at ``expert_data_path``.
|
""" |
|
|
|
with open(expert_data_path, 'wb') as f: |
|
pickle.dump(exp_data, f) |
|
|
|
|
|
def offline_data_save_type(exp_data, expert_data_path, data_type='naive'): |
|
""" |
|
Overview: |
|
        Save the offline data with the helper selected by ``data_type``.

    Arguments:

        - exp_data (:obj:`list`): The expert data to save.

        - expert_data_path (:obj:`str`): The path to save the data to.

        - data_type (:obj:`str`): The save format, ``'naive'`` (pickle) or ``'hdf5'``.
|
""" |
|
|
|
globals()[data_type + '_save'](exp_data, expert_data_path) |
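
# Illustrative usage of the save helpers (a minimal sketch; ``exp_data`` is a
# list of transition dicts like those collected by a DI-engine collector):
#
#     exp_data = [{
#         'obs': torch.zeros(4), 'next_obs': torch.zeros(4),
#         'action': torch.zeros(1), 'reward': torch.zeros(1), 'done': False,
#     }]
#     offline_data_save_type(exp_data, './expert.pkl', data_type='naive')
#     offline_data_save_type(exp_data, './expert.pkl', data_type='hdf5')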
|
|
|
|
|
def create_dataset(cfg, **kwargs) -> Dataset: |
|
""" |
|
Overview: |
|
        Create a dataset instance according to ``cfg.policy.collect.data_type`` via the dataset registry.
|
""" |
|
|
|
cfg = EasyDict(cfg) |
|
import_module(cfg.get('import_names', [])) |
|
return DATASET_REGISTRY.build(cfg.policy.collect.data_type, cfg=cfg, **kwargs) |
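
# Illustrative usage of ``create_dataset`` (a minimal sketch; the registry key
# in ``cfg.policy.collect.data_type`` selects one of the dataset classes
# registered above, e.g. 'naive', 'hdf5', 'd4rl' or 'd4rl_trajectory'):
#
#     cfg = {'policy': {'collect': {'data_type': 'naive', 'data_path': './expert.pkl'}}}
#     dataset = create_dataset(cfg)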
|
|