# PyDreamerV1/utils/buffer.py
"""
Author: Minh Pham-Dinh
Created: Jan 26th, 2024
Last Modified: Feb 5th, 2024
Email: mhpham26@colby.edu
Description:
File containing the ReplayBuffer that will be used in Dreamer.
The implementation is based on:
Hafner et al., "Dream to Control: Learning Behaviors by Latent Imagination," 2019.
[Online]. Available: https://arxiv.org/abs/1912.01603
"""
import numpy as np
import torch
from addict import Dict
class ReplayBuffer:
    def __init__(self, capacity, obs_size, action_size):
        self.obs_size = obs_size
        self.action_size = action_size
        # following the SimpleDreamer implementation, image observations (H, W, C)
        # are stored as uint8 to save memory; low-dimensional observations as float32
        state_type = np.float32 if len(self.obs_size) < 3 else np.uint8
        self.observation = np.zeros((capacity, ) + self.obs_size, dtype=state_type)
        self.actions = np.zeros((capacity, ) + self.action_size, dtype=np.float32)
        self.rewards = np.zeros((capacity, 1), dtype=np.float32)
        self.dones = np.zeros((capacity, 1), dtype=np.float32)
        self.pointer = 0
        self.full = False
        print(f'''
        -----------initialized memory----------
        obs_buffer_shape: {self.observation.shape}
        actions_buffer_shape: {self.actions.shape}
        rewards_buffer_shape: {self.rewards.shape}
        dones_buffer_shape: {self.dones.shape}
        ----------------------------------------
        ''')
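    # Illustrative shapes (arbitrary numbers, not from the original code): capacity=1000
    # with obs_size=(64, 64, 3) and action_size=(6,) allocates observation
    # (1000, 64, 64, 3) uint8, actions (1000, 6) float32, and rewards/dones
    # (1000, 1) float32 buffers.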
    def add(self, obs, action, reward, done):
        """Add a single transition to the buffer.
        Args:
            obs (np.array): current observation
            action (np.array): action taken
            reward (float): reward received after taking the action
            done (bool): whether the episode terminated or was truncated
        """
        self.observation[self.pointer] = obs
        self.actions[self.pointer] = action
        self.rewards[self.pointer] = reward
        self.dones[self.pointer] = done
        self.pointer = (self.pointer + 1) % self.observation.shape[0]
        if self.pointer == 0:
            self.full = True
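    # Example of the circular write (illustrative numbers): with capacity=5 the
    # pointer advances 0 -> 1 -> 2 -> 3 -> 4, wraps back to 0 on the fifth add and
    # sets `full` to True; later adds overwrite the oldest transitions in place.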
    def sample(self, batch_size, seq_len, device):
        """
        Samples batches of experiences of fixed sequence length from the replay buffer,
        taking the circular nature of the buffer into account so that sequences never
        cross the write pointer when the buffer is full.
        Sampled sequences are temporally continuous: they never cross the pointer,
        which, once the buffer is full, marks the boundary between the newest and the
        oldest data. Sequences may still wrap around the physical end of the array,
        since indices are taken modulo the capacity.
        Args:
            batch_size (int): The number of sequences to sample.
            seq_len (int): The length of each sequence to sample.
            device (torch.device): The device on which the sampled data will be loaded.
        Raises:
            Exception: If there is not enough data in the buffer to sample a full sequence.
        Returns:
            Dict: A dictionary containing the sampled sequences of observations, actions,
            rewards, and dones. Observations have shape (batch_size, seq_len, *obs_size),
            actions have shape (batch_size, seq_len, *action_size), and rewards and dones
            each have shape (batch_size, seq_len, 1).
        Notes:
            - The method handles different cases depending on whether the buffer is full
              and on the pointer's position, to guarantee valid start indices.
            - When the buffer is not full, sequences can start anywhere from index 0 up
              to the last index at which a sequence of length `seq_len` still ends
              before the current pointer.
            - When the buffer is full, start indices are restricted so that no sequence
              crosses the pointer, i.e. the boundary between the newest and oldest data.
            - This guarantees that sampled sequences respect the temporal order and
              continuity required by algorithms that learn from sequences of experiences.
        """
        # ensure there's enough data to sample a full sequence
        if self.pointer < seq_len and not self.full:
            raise Exception('not enough data to sample')
        # compute the valid start indices: a start is valid only if the seq_len
        # indices it spans do not cross the write pointer
        if self.full:
            if self.pointer - seq_len < 0:
                # the pointer is closer to the array start than seq_len; valid starts
                # run from the pointer up to capacity - (seq_len - pointer)
                valid_range = np.arange(self.pointer, self.observation.shape[0] + (self.pointer - seq_len) + 1)
            else:
                # starts either end before the pointer or begin at or after it
                range_1 = np.arange(0, self.pointer - seq_len + 1)
                range_2 = np.arange(self.pointer, self.observation.shape[0])
                valid_range = np.concatenate((range_1, range_2), -1)
        else:
            valid_range = np.arange(0, self.pointer - seq_len + 1)
        # draw batch_size start indices, then expand each into seq_len consecutive
        # indices; the modulo lets sequences wrap around the physical end of the array
        start_index = np.random.choice(valid_range, (batch_size, 1))
        offsets = np.arange(seq_len)
        sample_idcs = (start_index + offsets) % self.observation.shape[0]
        batch = Dict()
        batch.obs = torch.from_numpy(self.observation[sample_idcs]).to(device)
        batch.actions = torch.from_numpy(self.actions[sample_idcs]).to(device)
        batch.rewards = torch.from_numpy(self.rewards[sample_idcs]).to(device)
        batch.dones = torch.from_numpy(self.dones[sample_idcs]).to(device)
        return batch
    def clear(self):
        self.pointer = 0
        self.full = False
    def __len__(self):
        # once the buffer has wrapped, every slot holds valid data
        return self.observation.shape[0] if self.full else self.pointer
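# Minimal usage sketch (illustrative only; the sizes, hyperparameters, and random
# transitions below are placeholders, not part of the original training code).
if __name__ == "__main__":
    rng = np.random.default_rng(0)
    buffer = ReplayBuffer(capacity=100, obs_size=(4,), action_size=(2,))
    # fill the buffer with random transitions
    for step in range(60):
        obs = rng.standard_normal(4).astype(np.float32)
        action = rng.standard_normal(2).astype(np.float32)
        reward = float(rng.standard_normal())
        done = (step % 20 == 19)
        buffer.add(obs, action, reward, done)
    # sample a batch of contiguous sequences and check the shapes
    batch = buffer.sample(batch_size=8, seq_len=10, device=torch.device('cpu'))
    print(batch.obs.shape)      # torch.Size([8, 10, 4])
    print(batch.actions.shape)  # torch.Size([8, 10, 2])
    print(batch.rewards.shape)  # torch.Size([8, 10, 1])
    print(batch.dones.shape)    # torch.Size([8, 10, 1])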