import collections
import datetime
import io
import os
import pathlib
import random
import traceback
import uuid
from pathlib import Path

import numpy as np
import torch
from gym.spaces import Dict
from torch.utils.data import DataLoader, IterableDataset
from tqdm import tqdm

import tools.utils as utils

# Sentinel passed to ReplayBuffer.add when an environment worker fails mid-episode.
SIG_FAILURE = -1
def get_length(filename):
    # Episode file stems look like '{idx}-{timestamp}-{identifier}-{length}'
    # (or use '_' as the separator); the length is the last field.
    if "-" in str(filename):
        return int(str(filename).split('-')[-1])
    return int(str(filename).split('_')[-1])


def get_idx(filename):
    # The episode index is the first field of the stem.
    if "-" in str(filename):
        return int(str(filename).split('-')[0])
    return int(str(filename).split('_')[0])
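
# Example (hypothetical stem, following the '{idx}-{timestamp}-{identifier}-{length}'
# pattern produced by ReplayBuffer.save_episode below):
#   get_idx('12-20210101T120000-abc123-500')    -> 12
#   get_length('12-20210101T120000-abc123-500') -> 500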
def on_fn():
    # Module-level factory instead of a lambda, so the defaultdict stays picklable.
    return collections.defaultdict(list)
class ReplayBuffer(IterableDataset):
    def __init__(
            self, data_specs, meta_specs, directory, length=20, capacity=0, ongoing=False, minlen=1, maxlen=0,
            prioritize_ends=False, device='cuda', load_first=False, save_episodes=True, ignore_extra_keys=False,
            load_recursive=False, min_t_sampling=0, **kwargs):
        self._directory = pathlib.Path(directory).expanduser()
        self._directory.mkdir(parents=True, exist_ok=True)
        self._capacity = capacity
        self._ongoing = ongoing
        self._minlen = minlen
        self._maxlen = maxlen
        self._prioritize_ends = prioritize_ends
        self._ignore_extra_keys = ignore_extra_keys
        self._min_t_sampling = min_t_sampling
        self._save_episodes = save_episodes
        self._last_added_idx = 0
        self._episode_lens = np.array([])
        # key -> list of per-episode value sequences
        self._complete_eps = {}
        self._data_specs = data_specs
        self._meta_specs = meta_specs
        for spec_group in [data_specs, meta_specs]:
            for spec in spec_group:
                if type(spec) in [dict, Dict]:
                    for k, v in spec.items():
                        self._complete_eps[k] = []
                else:
                    self._complete_eps[spec.name] = []
        # Load previously saved episodes from disk.
        if type(directory) == str:
            directory = Path(directory)
        self._loaded_episodes = 0
        self._loaded_steps = 0
        for f in tqdm(load_filenames(self._directory, capacity, minlen, load_first=load_first, load_recursive=load_recursive)):
            self.store_episode(filename=f)
        try:
            self._total_episodes, self._total_steps = count_episodes(directory)
        except Exception:
            print("Couldn't count episodes")
            print("Loaded episodes: ", self._loaded_episodes)
            print("Loaded steps: ", self._loaded_steps)
            self._total_episodes, self._total_steps = self._loaded_episodes, self._loaded_steps
        # worker -> key -> value_sequence
        self._length = length
        self._ongoing_eps = collections.defaultdict(on_fn)
        self.device = device
        if not (self._minlen <= self._length <= self._maxlen):
            print("Sampling sequences with fixed length ", length)
            self._minlen = self._maxlen = self._length = length
    def __len__(self):
        return self._total_steps
    def preallocate_memory(self, max_size):
        # Reserve zero-filled arrays upfront; `add` releases one slot per stored
        # value, so peak memory is paid at startup rather than during collection.
        self._preallocated_mem = collections.defaultdict(list)
        for spec in self._data_specs:
            if type(spec) in [dict, Dict]:
                for k, v in spec.items():
                    for _ in range(max_size):
                        self._preallocated_mem[k].append(np.empty(list(v.shape), v.dtype))
                        self._preallocated_mem[k][-1].fill(0.)
            else:
                for _ in range(max_size):
                    self._preallocated_mem[spec.name].append(np.empty(list(spec.shape), spec.dtype))
                    self._preallocated_mem[spec.name][-1].fill(0.)
    def stats(self):
        return {
            'total_steps': self._total_steps,
            'total_episodes': self._total_episodes,
            'loaded_steps': self._loaded_steps,
            'loaded_episodes': self._loaded_episodes,
        }
    def add(self, time_step, meta, idx=0):
        # A worker signals an environment failure by sending SIG_FAILURE;
        # discard whatever partial episode it had accumulated.
        if time_step == SIG_FAILURE:
            episode = self._ongoing_eps[idx]
            episode.clear()
            print("Discarding episode from process", idx)
            return
        episode = self._ongoing_eps[idx]

        def add_to_episode(name, data, spec):
            value = data[name]
            if np.isscalar(value):
                value = np.full(spec.shape, value, spec.dtype)
            assert spec.shape == value.shape and spec.dtype == value.dtype, \
                f"for '{name}' expected ({spec.dtype}, {spec.shape}), received ({value.dtype}, {value.shape})"
            # Release one slot of pre-allocated memory per stored value.
            if getattr(self, '_preallocated_mem', False):
                if len(self._preallocated_mem[name]) > 0:
                    tmp = self._preallocated_mem[name].pop()
                    del tmp
                else:
                    # Out of pre-allocated memory
                    del self._preallocated_mem
            episode[name].append(value)

        for spec in self._data_specs:
            if type(spec) in [dict, Dict]:
                for k, v in spec.items():
                    add_to_episode(k, time_step, v)
            else:
                add_to_episode(spec.name, time_step, spec)
        for spec in self._meta_specs:
            if type(spec) in [dict, Dict]:
                for k, v in spec.items():
                    add_to_episode(k, meta, v)
            else:
                add_to_episode(spec.name, meta, spec)
        # Dict-like time steps carry an 'is_last' flag; dm_env-style ones expose .last().
        if type(time_step) in [dict, Dict]:
            if time_step['is_last']:
                self.add_episode(episode)
                episode.clear()
        else:
            if time_step.last():
                self.add_episode(episode)
                episode.clear()
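
    # Hedged sketch of the expected calling pattern (names are illustrative;
    # `step` must carry every key in data_specs, `meta` every key in meta_specs):
    #
    #   for step in rollout:                            # one environment episode
    #       buffer.add(step, meta={}, idx=worker_id)
    #   buffer.add(SIG_FAILURE, None, idx=worker_id)    # on environment failure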
    def add_episode(self, episode):
        length = eplen(episode)
        if length < self._minlen:
            print(f'Skipping short episode of length {length}.')
            return
        self._total_steps += length
        self._total_episodes += 1
        episode = {key: convert(value) for key, value in episode.items()}
        if self._save_episodes:
            self.save_episode(self._directory, episode)
        self.store_episode(episode=episode)
    def store_episode(self, filename=None, episode=None, run_checks=True):
        if filename is not None:
            episode = load_episode(filename)
            if episode:
                if len(episode['reward'].shape) == 1:
                    episode['reward'] = episode['reward'].reshape(-1, 1)
                if 'discount' not in episode:
                    episode['discount'] = (1 - episode['is_terminal']).reshape(-1, 1).astype(np.float32)
        if not episode:
            return False
        if run_checks:
            for spec_set in [self._data_specs, self._meta_specs]:
                for spec in spec_set:
                    if type(spec) in [dict, Dict]:
                        for k, v in spec.items():
                            value = episode[k][0]
                            assert v.shape == value.shape and v.dtype == value.dtype, \
                                f"for '{k}' expected ({v.dtype}, {v.shape}), received ({value.dtype}, {value.shape})"
                    else:
                        value = episode[spec.name][0]
                        assert spec.shape == value.shape and spec.dtype == value.dtype, \
                            f"for '{spec.name}' expected ({spec.dtype}, {spec.shape}), received ({value.dtype}, {value.shape})"
        length = eplen(episode)
        if run_checks:
            for k in episode:
                assert len(episode[k]) == length, f'Found {len(episode[k])} values for {k} VS eplen: {length}'
        # Enforce the capacity limit by evicting the oldest episodes first
        # (capacity=0 means unlimited, matching load_filenames).
        if self._capacity:
            while self._loaded_steps + length > self._capacity:
                for k in self._complete_eps:
                    self._complete_eps[k].pop(0)
                removed_len, self._episode_lens = self._episode_lens[0], self._episode_lens[1:]
                self._loaded_steps -= removed_len
                self._loaded_episodes -= 1
        # Add the episode.
        for k, v in episode.items():
            if k not in self._complete_eps:
                if self._ignore_extra_keys:
                    continue
                else:
                    raise KeyError("Extra key ", k)
            self._complete_eps[k].append(v)
        self._episode_lens = np.append(self._episode_lens, length)
        self._loaded_steps += length
        self._loaded_episodes += 1
        return True
    def __iter__(self):
        while True:
            sequences, batch_size, batch_length = self._loaded_episodes, self.batch_size, self._length
            # Sample one episode index per batch element, then a valid start time
            # inside each episode so that a full window of `batch_length` fits.
            b_indices = np.random.randint(0, sequences, size=batch_size)
            low = np.full(batch_size, self._min_t_sampling)
            high = (self._episode_lens[b_indices] - batch_length + 1).astype(int)
            t_indices = np.random.randint(low, high, size=batch_size)
            t_ranges = np.repeat(np.expand_dims(np.arange(0, batch_length), 0), batch_size, axis=0) + np.expand_dims(t_indices, 1)
            # chunk: key -> tensor of shape (batch_size, batch_length, *spec.shape)
            chunk = {}
            for k in self._complete_eps:
                chunk[k] = np.stack([self._complete_eps[k][b][t] for b, t in zip(b_indices, t_ranges)])
            for k in chunk:
                chunk[k] = torch.as_tensor(chunk[k], device=self.device)
            yield chunk
    def save_episode(self, directory, episode):
        idx = self._total_episodes
        timestamp = datetime.datetime.now().strftime('%Y%m%dT%H%M%S')
        identifier = str(uuid.uuid4().hex)
        length = eplen(episode)
        filename = directory / f'{idx}-{timestamp}-{identifier}-{length}.npz'
        # Serialize to an in-memory buffer first, then flush to disk in one write.
        with io.BytesIO() as f1:
            np.savez_compressed(f1, **episode)
            f1.seek(0)
            with filename.open('wb') as f2:
                f2.write(f1.read())
        return filename
def load_episode(filename):
    try:
        with filename.open('rb') as f:
            episode = np.load(f, allow_pickle=True)
            episode = {k: episode[k] for k in episode.keys()}
    except Exception as e:
        print(f'Could not load episode {str(filename)}: {e}')
        return False
    return episode
def count_episodes(directory):
    filenames = list(directory.glob('*.npz'))
    if len(filenames) == 0:
        return 0, 0
    # Each stem ends in '-{length}' (or '_{length}'); a length-N file holds
    # N - 1 environment steps, and the first field is the episode index.
    if "-" in str(filenames[0]):
        num_steps = sum(int(str(n).split('-')[-1][:-4]) - 1 for n in filenames)
        last_episode = max(int(n.stem.split('-')[0]) for n in filenames)
    else:
        num_steps = sum(int(str(n).split('_')[-1][:-4]) - 1 for n in filenames)
        last_episode = max(int(n.stem.split('_')[0]) for n in filenames)
    return last_episode, num_steps
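
# Worked example (hypothetical files): with '0-20210101T000000-abc-500.npz' and
# '1-20210101T000100-def-300.npz' in the directory, count_episodes returns
# (1, (500 - 1) + (300 - 1)) == (1, 798).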
def load_filenames(directory, capacity=None, minlen=1, load_first=False, load_recursive=False):
    # Returns episode filenames in temporally sorted order. When a capacity is
    # given, only enough episodes to reach it are kept: the oldest ones when
    # load_first is set, otherwise the most recent ones.
    if load_recursive:
        filenames = sorted(directory.glob('**/*.npz'))
    else:
        filenames = sorted(directory.glob('*.npz'))
    if capacity:
        num_steps = 0
        num_episodes = 0
        ordered_filenames = filenames if load_first else reversed(filenames)
        for filename in ordered_filenames:
            if "-" in str(filename):
                length = int(str(filename).split('-')[-1][:-4])
            else:
                length = int(str(filename).split('_')[-1][:-4])
            num_steps += length
            num_episodes += 1
            if num_steps >= capacity:
                break
        if load_first:
            filenames = filenames[:num_episodes]
        else:
            filenames = filenames[-num_episodes:]
    return filenames
def convert(value):
    # Normalize dtypes so stored episodes are consistent across environments.
    value = np.array(value)
    if np.issubdtype(value.dtype, np.floating):
        return value.astype(np.float32)
    elif np.issubdtype(value.dtype, np.signedinteger):
        return value.astype(np.int32)
    elif np.issubdtype(value.dtype, np.uint8):
        return value.astype(np.uint8)
    return value
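
# Examples: convert(np.float64(1.0)) yields a float32 array; convert(np.int64(3))
# yields int32; uint8 image arrays keep their dtype.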
def eplen(episode):
    return len(episode['action'])
def make_replay_loader(buffer, batch_size):
    buffer.batch_size = batch_size
    return DataLoader(
        buffer,
        batch_size=None,
        # NOTE: do not use any workers, as they would not share this process's
        # replay buffer (supporting them requires a different implementation).
    )
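
# Minimal usage sketch (hypothetical specs; in practice they come from a
# dm_env-style environment and expose .name, .shape and .dtype, and
# './replay_dir' is assumed to already hold saved episodes):
#
#   from dm_env import specs
#   data_specs = (
#       specs.Array((9,), np.float32, 'observation'),
#       specs.Array((3,), np.float32, 'action'),
#       specs.Array((1,), np.float32, 'reward'),
#       specs.Array((1,), np.float32, 'discount'),
#   )
#   buffer = ReplayBuffer(data_specs, meta_specs=(), directory='./replay_dir',
#                         length=20, capacity=1_000_000)
#   loader = make_replay_loader(buffer, batch_size=16)
#   batch = next(iter(loader))  # key -> tensor of shape (16, 20, *spec.shape)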