# remdm-minihack / src/buffer.py
# Published by MathisW78, commit "Demo notebook payload (source + checkpoint + assets)"
# Revision f748552 (verified)
"""Replay buffer with offline-protected FIFO eviction.
Ported from minihack_reference/src/buffer.py. Stores observation-action
windows of fixed length ``seq_len``. Offline data is pinned at the front
and never evicted; online samples use FIFO.
"""
from __future__ import annotations
import numpy as np
class ReplayBuffer:
    """Fixed-capacity buffer with offline-protected FIFO eviction.

    Offline samples (loaded via ``load_offline_data``) are pinned and never
    evicted. Online samples added via ``add`` are FIFO-evicted so that the
    total count stays within ``capacity``. Loading offline data also evicts
    the oldest online samples if the pinned set grows.

    Each stored window is a ``(local, global, actions)`` triple. Windows
    produced by ``_slice_trajectory`` hold buffer-owned copies, so later
    mutation of caller-provided arrays does not affect stored data.

    Args:
        capacity: Maximum total number of windows.
        seq_len: Action-sequence window length.
        pad_token: Token used to pad short sequences.
    """

    def __init__(
        self, capacity: int, seq_len: int, pad_token: int,
    ) -> None:
        self._capacity = capacity
        self._seq_len = seq_len
        self._pad_token = pad_token
        # Each element: (local [9,9], global [21,79], actions [seq_len])
        self._offline: list[tuple[np.ndarray, np.ndarray, np.ndarray]] = []
        self._online: list[tuple[np.ndarray, np.ndarray, np.ndarray]] = []
        # Stacked array cache for fast sampling. Rebuilt lazily by
        # ``_ensure_cache`` and released by ``_invalidate_cache``.
        self._cache_valid = False
        self._cached_local: np.ndarray | None = None
        self._cached_global: np.ndarray | None = None
        self._cached_actions: np.ndarray | None = None

    # ── Offline data ─────────────────────────────────────────────

    def load_offline_data(
        self,
        data: dict | list,
        allowed_envs: list[str],
        metadata: dict | None = None,
    ) -> None:
        """Load pre-collected trajectories and slice into windows.

        Supports two dataset formats:

        **New format** (dict): ``{"trajectories": [...]}`` where each entry
        is a dict with ``"local"``, ``"global"``, ``"actions"``, ``"env_id"``.
        A dict without a ``"trajectories"`` key is treated as a single
        trajectory.

        **Legacy format** (list): Flat list of ``((local, global), action_seq)``
        tuples produced by the reference pipeline (pre-windowed, already
        ``seq_len``-length). Env filtering uses an optional *metadata* dict
        with a ``"samples_per_env"`` key mapping env IDs to sample counts.

        Offline windows beyond ``capacity`` are dropped (the earliest ones
        are kept). If needed, the oldest online windows are evicted so the
        total stays within ``capacity``.

        Args:
            data: Dataset in new dict format or legacy list format.
            allowed_envs: Only samples from these env IDs are kept.
            metadata: Optional sidecar metadata for legacy format env
                filtering. Ignored for the new format.
        """
        if isinstance(data, list):
            self._load_legacy_offline_data(data, allowed_envs, metadata)
            return
        allowed = set(allowed_envs)
        for traj in data.get("trajectories", [data]):
            if self._offline_is_full():
                # Anything added now would be truncated away anyway.
                break
            if traj.get("env_id", "") not in allowed:
                continue
            self._offline.extend(self._slice_trajectory(traj))
        self._finish_offline_load()

    def _load_legacy_offline_data(
        self,
        data: list,
        allowed_envs: list[str],
        metadata: dict | None = None,
    ) -> None:
        """Load reference-format datasets (pre-windowed tuples).

        Args:
            data: List of ``((local_crop, global_map), action_seq)`` tuples.
                ``local_crop`` is ``[9, 9]``, ``global_map`` is ``[21, 79]``,
                ``action_seq`` is a sequence of length ``seq_len``.
            allowed_envs: Env IDs to retain.
            metadata: Optional dict with ``"samples_per_env"`` key mapping
                env IDs to per-env sample counts for precise filtering.
        """
        allowed = set(allowed_envs)
        sample_to_env: list[str] | None = None
        if metadata and "samples_per_env" in metadata:
            # Build a per-sample env_id index from the metadata. This assumes
            # samples are grouped by env ID in sorted env-ID order.
            counts = metadata["samples_per_env"]
            sample_to_env = []
            for env_id in sorted(counts):
                sample_to_env.extend([env_id] * counts[env_id])
        for i, sample in enumerate(data):
            if self._offline_is_full():
                break
            if (
                sample_to_env is not None
                and i < len(sample_to_env)
                and sample_to_env[i] not in allowed
            ):
                continue
            # Samples not covered by metadata (or all samples when there is
            # no metadata) are kept; the caller is responsible for
            # pre-filtering them.
            self._offline.append(self._unpack_legacy_sample(sample))
        self._finish_offline_load()

    @staticmethod
    def _unpack_legacy_sample(
        sample: tuple,
    ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Convert a legacy ``((local, global), action_seq)`` sample.

        Args:
            sample: Tuple of ``(state, action_seq)`` where state is
                ``(local_crop, global_map)``.

        Returns:
            ``(local [9,9], global [21,79], actions [seq_len])`` as
            numpy int16/int64 arrays.
        """
        (local, glb), action_seq = sample
        return (
            np.asarray(local, dtype=np.int16),
            np.asarray(glb, dtype=np.int16),
            np.asarray(action_seq, dtype=np.int64),
        )

    # ── Sampling cache ───────────────────────────────────────────

    def _invalidate_cache(self) -> None:
        """Mark the stacked array cache as stale and release its arrays.

        Dropping the stale arrays right away avoids holding two full copies
        of the buffer in memory while the cache is rebuilt.
        """
        self._cache_valid = False
        self._cached_local = None
        self._cached_global = None
        self._cached_actions = None

    def _ensure_cache(self) -> None:
        """Rebuild stacked arrays from offline + online windows.

        The stacked arrays take their shape and dtype from the first stored
        window; later windows are cast into them on assignment. Does nothing
        if the cache is already valid or the buffer is empty.
        """
        if self._cache_valid:
            return
        combined = self._offline + self._online
        if not combined:
            return
        n = len(combined)
        first_local, first_global, first_actions = combined[0]
        local = np.empty((n, *first_local.shape), dtype=first_local.dtype)
        glb = np.empty((n, *first_global.shape), dtype=first_global.dtype)
        actions = np.empty(
            (n, *first_actions.shape), dtype=first_actions.dtype,
        )
        for i, (l, g, a) in enumerate(combined):
            local[i] = l
            glb[i] = g
            actions[i] = a
        self._cached_local = local
        self._cached_global = glb
        self._cached_actions = actions
        self._cache_valid = True

    # ── Online data ──────────────────────────────────────────────

    def add(self, trajectory: dict) -> None:
        """Add a trajectory, sliced into overlapping windows.

        FIFO-evicts the oldest online samples when over capacity.

        Args:
            trajectory: Dict with ``"local"`` ``[T,9,9]``,
                ``"global"`` ``[T,21,79]``, ``"actions"`` ``[T]``.
        """
        self._online.extend(self._slice_trajectory(trajectory))
        self._evict_online()
        self._invalidate_cache()

    # ── Sampling ─────────────────────────────────────────────────

    def sample(
        self, batch_size: int,
    ) -> tuple[np.ndarray, np.ndarray, np.ndarray] | None:
        """Random sample (with replacement) from offline + online combined.

        Args:
            batch_size: Number of windows to sample.

        Returns:
            ``(local [B,9,9], global [B,21,79], actions [B,seq_len])``
            as numpy arrays, or ``None`` if the buffer is empty.
        """
        if len(self) == 0:
            return None
        self._ensure_cache()
        if self._cached_local is None:
            return None
        indices = np.random.randint(0, len(self), size=batch_size)
        return (
            self._cached_local[indices],
            self._cached_global[indices],
            self._cached_actions[indices],
        )

    # ── Properties ───────────────────────────────────────────────

    def __len__(self) -> int:
        """Total number of windows (offline + online)."""
        return len(self._offline) + len(self._online)

    @property
    def n_offline(self) -> int:
        """Number of pinned offline windows."""
        return len(self._offline)

    @property
    def offline_size(self) -> int:
        """Number of pinned offline windows (alias)."""
        return len(self._offline)

    # ── Internals ────────────────────────────────────────────────

    def _offline_is_full(self) -> bool:
        """Return True once the pinned offline set has reached capacity."""
        return len(self._offline) >= self._capacity

    def _finish_offline_load(self) -> None:
        """Truncate offline data to capacity, rebalance online, drop cache."""
        if len(self._offline) > self._capacity:
            del self._offline[self._capacity:]
        self._evict_online()
        self._invalidate_cache()

    def _evict_online(self) -> None:
        """FIFO-evict the oldest online windows until total <= capacity."""
        max_online = max(self._capacity - len(self._offline), 0)
        excess = len(self._online) - max_online
        if excess > 0:
            del self._online[:excess]

    def _slice_trajectory(
        self, traj: dict,
    ) -> list[tuple[np.ndarray, np.ndarray, np.ndarray]]:
        """Slice a trajectory into overlapping ``seq_len`` windows.

        One window is produced per timestep ``t`` of ``traj["actions"]``. It
        pairs the observation at ``t`` (clamped to the last available
        observation) with actions ``t .. t + seq_len - 1``, right-padded
        with ``pad_token`` past the end of the trajectory. Every returned
        array is a fresh copy, and action windows are always ``int64``.

        Args:
            traj: Trajectory dict with ``"local"``, ``"global"``,
                ``"actions"`` arrays.

        Returns:
            List of ``(local, global, actions)`` tuples.
        """
        local_arr = np.asarray(traj["local"])
        global_arr = np.asarray(traj["global"])
        actions_arr = np.asarray(traj["actions"])
        n_steps = len(actions_arr)
        seq_len = self._seq_len
        # Pad once so every window is a plain slice of a buffer-owned array.
        # int64 matches the legacy loader and the padding dtype.
        padded = np.full(n_steps + seq_len, self._pad_token, dtype=np.int64)
        padded[:n_steps] = actions_arr
        last_local = len(local_arr) - 1
        last_global = len(global_arr) - 1
        return [
            (
                local_arr[min(t, last_local)].copy(),
                global_arr[min(t, last_global)].copy(),
                # .copy() so each window is compact and independent.
                padded[t:t + seq_len].copy(),
            )
            for t in range(n_steps)
        ]