# genrl/tools/replay.py
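"""Episode replay buffer backed by on-disk .npz files.

Episodes are accumulated step by step, persisted as compressed .npz files, and
served as fixed-length chunks through a PyTorch IterableDataset/DataLoader.
"""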
import collections
import datetime
import io
import pathlib
import uuid
import numpy as np
from gym.spaces import Dict
from torch.utils.data import IterableDataset, DataLoader
import torch
import tools.utils as utils
from pathlib import Path
from tqdm import tqdm
SIG_FAILURE = -1
def get_length(filename):
    # The episode length is the last '-' or '_' separated field of the filename.
    if "-" in str(filename):
        length = int(str(filename).split('-')[-1])
    else:
        length = int(str(filename).split('_')[-1])
    return length
def get_idx(filename):
    # The episode index is the first '-' or '_' separated field of the filename.
    if "-" in str(filename):
        idx = int(str(filename).split('-')[0])
    else:
        idx = int(str(filename).split('_')[0])
    return idx
def on_fn():
    # Defined at module level instead of a lambda so the defaultdict stays picklable.
    return collections.defaultdict(list)
class ReplayBuffer(IterableDataset):
def __init__(
self, data_specs, meta_specs, directory, length=20, capacity=0, ongoing=False, minlen=1, maxlen=0,
prioritize_ends=False, device='cuda', load_first=False, save_episodes=True, ignore_extra_keys=False, load_recursive=False, min_t_sampling=0, **kwargs):
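        """Store complete episodes in memory and serve fixed-length training chunks.

        data_specs/meta_specs describe the step-level keys to store; capacity
        bounds the number of steps held in memory (0 = unlimited); length is the
        chunk length yielded by __iter__; min_t_sampling is the earliest time
        index a chunk may start at. Note: `ongoing` and `prioritize_ends` are
        stored but unused in this file.
        """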
self._directory = pathlib.Path(directory).expanduser()
self._directory.mkdir(parents=True, exist_ok=True)
self._capacity = capacity
self._ongoing = ongoing
self._minlen = minlen
self._maxlen = maxlen
self._prioritize_ends = prioritize_ends
self._ignore_extra_keys = ignore_extra_keys
self._min_t_sampling = min_t_sampling
        self._save_episodes = save_episodes
        self._last_added_idx = 0
        self._episode_lens = np.array([], dtype=np.int64)
        # key -> list of per-episode value sequences
        self._complete_eps = {}
self._data_specs = data_specs
self._meta_specs = meta_specs
for spec_group in [data_specs, meta_specs]:
for spec in spec_group:
if type(spec) in [dict, Dict]:
for k,v in spec.items():
self._complete_eps[k] = []
else:
self._complete_eps[spec.name] = []
        # Load existing episodes from disk.
        if isinstance(directory, str):
            directory = Path(directory)
self._loaded_episodes = 0
self._loaded_steps = 0
for f in tqdm(load_filenames(self._directory, capacity, minlen, load_first=load_first, load_recursive=load_recursive)):
self.store_episode(filename=f)
        try:
            self._total_episodes, self._total_steps = count_episodes(directory)
        except Exception:
            print("Couldn't count episodes")
            print("Loaded episodes: ", self._loaded_episodes)
            print("Loaded steps: ", self._loaded_steps)
            self._total_episodes, self._total_steps = self._loaded_episodes, self._loaded_steps
        self._length = length
        # worker idx -> key -> value sequence of the episode currently being collected
        self._ongoing_eps = collections.defaultdict(on_fn)
self.device = device
        if not (self._minlen <= self._length <= self._maxlen):
            print("Sampling sequences with fixed length ", length)
            self._minlen = self._maxlen = self._length = length
def __len__(self):
return self._total_steps
def preallocate_memory(self, max_size):
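        # Reserve zero-filled arrays matching the data specs up front; add() later
        # releases one reserved array per stored value, so the total footprint stays
        # roughly constant as real data replaces the placeholders.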
self._preallocated_mem = collections.defaultdict(list)
for spec in self._data_specs:
if type(spec) in [dict, Dict]:
for k,v in spec.items():
for _ in range(max_size):
self._preallocated_mem[k].append(np.empty(list(v.shape), v.dtype))
self._preallocated_mem[k][-1].fill(0.)
            else:
                for _ in range(max_size):
                    self._preallocated_mem[spec.name].append(np.empty(list(spec.shape), spec.dtype))
                    self._preallocated_mem[spec.name][-1].fill(0.)
@property
def stats(self):
return {
'total_steps': self._total_steps,
'total_episodes': self._total_episodes,
'loaded_steps': self._loaded_steps,
'loaded_episodes': self._loaded_episodes,
}
def add(self, time_step, meta, idx=0):
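        """Append one time step (and its meta values) to the ongoing episode of
        worker `idx`; the episode is finalized once the step is marked last."""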
        # A failed environment signals SIG_FAILURE: discard the ongoing episode.
        if time_step == SIG_FAILURE:
            episode = self._ongoing_eps[idx]
            episode.clear()
            print("Discarding episode from process", idx)
            return
episode = self._ongoing_eps[idx]
def add_to_episode(name, data, spec):
value = data[name]
if np.isscalar(value):
value = np.full(spec.shape, value, spec.dtype)
            assert spec.shape == value.shape and spec.dtype == value.dtype, f"for ({name}) expected ({spec.dtype}, {spec.shape}), received ({value.dtype}, {value.shape})"
            # Release one reserved array as real data comes in (see preallocate_memory).
            if getattr(self, '_preallocated_mem', False):
                if len(self._preallocated_mem[name]) > 0:
                    tmp = self._preallocated_mem[name].pop()
                    del tmp
                else:
                    # Reserve exhausted: drop the bookkeeping dict entirely.
                    del self._preallocated_mem
episode[name].append(value)
for spec in self._data_specs:
if type(spec) in [dict, Dict]:
for k,v in spec.items():
add_to_episode(k, time_step, v)
else:
add_to_episode(spec.name, time_step, spec)
for spec in self._meta_specs:
if type(spec) in [dict, Dict]:
for k,v in spec.items():
add_to_episode(k, meta, v)
else:
add_to_episode(spec.name, meta, spec)
        # dm_env-style TimeSteps expose .last(); plain dict steps carry an 'is_last' flag.
        if type(time_step) in [dict, Dict]:
            is_last = time_step['is_last']
        else:
            is_last = time_step.last()
        if is_last:
            self.add_episode(episode)
            episode.clear()
def add_episode(self, episode):
length = eplen(episode)
if length < self._minlen:
print(f'Skipping short episode of length {length}.')
return
self._total_steps += length
self._total_episodes += 1
episode = {key: convert(value) for key, value in episode.items()}
        if self._save_episodes:
            self.save_episode(self._directory, episode)
        self.store_episode(episode=episode)
    def store_episode(self, filename=None, episode=None, run_checks=True):
        if filename is not None:
            episode = load_episode(filename)
            if not episode:
                return False
            # Normalize loaded episodes: rewards as column vectors, default discount.
            if len(episode['reward'].shape) == 1:
                episode['reward'] = episode['reward'].reshape(-1, 1)
            if 'discount' not in episode:
                episode['discount'] = (1 - episode['is_terminal']).reshape(-1, 1).astype(np.float32)
        if not episode:
            return False
        if run_checks:
            for spec_set in [self._data_specs, self._meta_specs]:
                for spec in spec_set:
                    if type(spec) in [dict, Dict]:
                        for k, v in spec.items():
                            value = episode[k][0]
                            assert v.shape == value.shape and v.dtype == value.dtype, f"for ({k}) expected ({v.dtype}, {v.shape}), received ({value.dtype}, {value.shape})"
                    else:
                        value = episode[spec.name][0]
                        assert spec.shape == value.shape and spec.dtype == value.dtype, f"for ({spec.name}) expected ({spec.dtype}, {spec.shape}), received ({value.dtype}, {value.shape})"
length = eplen(episode)
if run_checks:
for k in episode:
                assert len(episode[k]) == length, f'Key {k}: found length {len(episode[k])} VS eplen: {length}'
        # Enforce the capacity limit (capacity == 0 means unlimited) by evicting
        # the oldest episodes first.
        while self._capacity and self._loaded_steps + length > self._capacity:
            for k in self._complete_eps:
                self._complete_eps[k].pop(0)
            removed_len, self._episode_lens = self._episode_lens[0], self._episode_lens[1:]
            self._loaded_steps -= removed_len
            self._loaded_episodes -= 1
        # Add the new episode key by key.
for k,v in episode.items():
if k not in self._complete_eps:
if self._ignore_extra_keys: continue
else: raise KeyError("Extra key ", k)
self._complete_eps[k].append(v)
self._episode_lens = np.append(self._episode_lens, length)
self._loaded_steps += length
self._loaded_episodes += 1
return True
def __iter__(self):
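        """Yield batches of training chunks forever: pick `batch_size` episodes
        uniformly at random, then a random `self._length`-step window from each."""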
while True:
            sequences, batch_size, batch_length = self._loaded_episodes, self.batch_size, self._length
            # Pick an episode per batch element, then a valid window start within it.
            b_indices = np.random.randint(0, sequences, size=batch_size)
            t_indices = np.random.randint(np.full(batch_size, self._min_t_sampling), self._episode_lens[b_indices] - batch_length + 1, size=batch_size)
            # (batch_size, batch_length) matrix of time indices for each sampled window.
            t_ranges = np.repeat(np.expand_dims(np.arange(0, batch_length), 0), batch_size, axis=0) + np.expand_dims(t_indices, 1)
chunk = {}
for k in self._complete_eps:
chunk[k] = np.stack([self._complete_eps[k][b][t] for b,t in zip(b_indices, t_ranges)])
for k in chunk:
chunk[k] = torch.as_tensor(chunk[k], device=self.device)
yield chunk
@utils.retry
def save_episode(self, directory, episode):
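        """Write an episode to `directory` as `<idx>-<timestamp>-<uuid>-<length>.npz`,
        staging the bytes in memory so the file lands on disk in a single write."""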
idx = self._total_episodes
timestamp = datetime.datetime.now().strftime('%Y%m%dT%H%M%S')
identifier = str(uuid.uuid4().hex)
length = eplen(episode)
filename = directory / f'{idx}-{timestamp}-{identifier}-{length}.npz'
with io.BytesIO() as f1:
np.savez_compressed(f1, **episode)
f1.seek(0)
with filename.open('wb') as f2:
f2.write(f1.read())
return filename
def load_episode(filename):
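    """Load a saved .npz episode into a dict of arrays; return False on failure."""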
try:
with filename.open('rb') as f:
episode = np.load(f, allow_pickle=True)
episode = {k: episode[k] for k in episode.keys()}
except Exception as e:
print(f'Could not load episode {str(filename)}: {e}')
return False
return episode
def count_episodes(directory):
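    """Return (index of the newest episode on disk, total number of transitions),
    reading both counts from the episode filenames alone."""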
filenames = list(directory.glob('*.npz'))
num_episodes = len(filenames)
    if num_episodes == 0:
        return 0, 0
    if "-" in str(filenames[0]):
        num_steps = sum(int(str(n).split('-')[-1][:-4]) - 1 for n in filenames)
        last_episode = max(int(n.stem.split('-')[0]) for n in filenames)
    else:
        num_steps = sum(int(str(n).split('_')[-1][:-4]) - 1 for n in filenames)
        last_episode = max(int(n.stem.split('_')[0]) for n in filenames)
return last_episode, num_steps
def load_filenames(directory, capacity=None, minlen=1, load_first=False, load_recursive=False):
    # The returned list of filenames is guaranteed to be in temporally sorted order.
if load_recursive:
filenames = sorted(directory.glob('**/*.npz'))
else:
filenames = sorted(directory.glob('*.npz'))
if capacity:
num_steps = 0
num_episodes = 0
ordered_filenames = filenames if load_first else reversed(filenames)
for filename in ordered_filenames:
if "-" in str(filename):
length = int(str(filename).split('-')[-1][:-4])
else:
length = int(str(filename).split('_')[-1][:-4])
num_steps += length
num_episodes += 1
if num_steps >= capacity:
break
if load_first:
filenames = filenames[:num_episodes]
else:
filenames = filenames[-num_episodes:]
return filenames
def convert(value):
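    # Normalize dtypes for compact, consistent storage: floats -> float32,
    # signed ints -> int32, uint8 (e.g. images) kept as uint8.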
value = np.array(value)
if np.issubdtype(value.dtype, np.floating):
return value.astype(np.float32)
elif np.issubdtype(value.dtype, np.signedinteger):
return value.astype(np.int32)
elif np.issubdtype(value.dtype, np.uint8):
return value.astype(np.uint8)
return value
def eplen(episode):
    # Episode length is the number of recorded actions.
    return len(episode['action'])
def make_replay_loader(buffer, batch_size):
    buffer.batch_size = batch_size
    # NOTE: num_workers must stay 0: worker processes would only receive copies of
    # the replay buffer (supporting them would require a different implementation).
    return DataLoader(buffer, batch_size=None)
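
# The sketch below is a minimal, hypothetical usage example (not part of the
# original module): it assumes dm_env is installed for Array specs, uses made-up
# spec names/shapes, and exercises add(), saving, and chunk sampling end to end.
if __name__ == '__main__':
    import tempfile
    from dm_env import specs

    demo_specs = (
        specs.Array((3,), np.float32, 'observation'),
        specs.Array((1,), np.float32, 'action'),
        specs.Array((1,), np.float32, 'reward'),
        specs.Array((1,), np.float32, 'discount'),
        specs.Array((), bool, 'is_last'),
    )
    demo_buffer = ReplayBuffer(demo_specs, meta_specs=(), directory=tempfile.mkdtemp(),
                               length=8, capacity=10_000, minlen=1, maxlen=32, device='cpu')
    # Feed one 10-step episode of random data as plain dict time steps.
    for t in range(10):
        step = {'observation': np.random.randn(3).astype(np.float32),
                'action': np.zeros(1, np.float32),
                'reward': np.zeros(1, np.float32),
                'discount': np.ones(1, np.float32),
                'is_last': t == 9}
        demo_buffer.add(step, meta={})
    loader = make_replay_loader(demo_buffer, batch_size=4)
    batch = next(iter(loader))
    print({k: tuple(v.shape) for k, v in batch.items()})  # (4, 8, ...) tensors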