# remdm-minihack / src/diffusion/sampling.py
# Hugging Face upload metadata: MathisW78 — "Demo notebook payload
# (source + checkpoint + assets)", commit f748552 (verified).
"""ReMDM reverse denoising with remasking strategies.
Ported from the Craftax JAX implementation (src/diffusion/sampling.py).
Implements MaskGIT-style progressive unmasking with optional stochastic
remasking (ReMDM) using three strategy variants.
"""
from __future__ import annotations
from types import SimpleNamespace
import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor
from torch.distributions import Categorical
from src.diffusion.schedules import get_schedule
# NLE hazard glyph IDs and char codes (walls, locked doors, lava, water)
_HAZARD_GLYPHS: frozenset[int] = frozenset({2359, 2360, 2389, 2390})
_HAZARD_CHARS: frozenset[int] = frozenset(
{ord("|"), ord("-"), ord("+"), ord("L"), ord("W")}
)
# Cardinal action → (dy, dx) offsets
_CARDINAL_OFFSETS: dict[int, tuple[int, int]] = {
0: (-1, 0), 1: (0, 1), 2: (1, 0), 3: (0, -1),
}
_N_PHYSICS_CHECK = 8 # only inspect the first N plan positions
def _check_hazard(local_crop: np.ndarray, action: int) -> bool:
"""Return True if *action* from the agent's centre steps into a hazard.
Args:
local_crop: ``[crop_size, crop_size]`` glyph array.
action: Cardinal action index (0=N, 1=E, 2=S, 3=W).
Returns:
``True`` when the target cell contains a hazard glyph.
"""
if action not in _CARDINAL_OFFSETS:
return False
cs = local_crop.shape[0]
cy, cx = cs // 2, cs // 2
dy, dx = _CARDINAL_OFFSETS[action]
ny, nx = cy + dy, cx + dx
if not (0 <= ny < cs and 0 <= nx < cs):
return True
glyph = int(local_crop[ny, nx])
return glyph in _HAZARD_GLYPHS or glyph in _HAZARD_CHARS
def top_k_filter(logits: Tensor, k: int) -> Tensor:
    """Keep only the top-k logits per position; set the rest to ``-inf``.

    Args:
        logits: Raw logits. Shape ``[..., V]``.
        k: Number of top entries to keep. Values outside ``(0, V)``
            disable filtering and return *logits* unchanged.

    Returns:
        Logits where every entry below the k-th largest value (per
        trailing-dim slice) is replaced by ``-inf``.
    """
    vocab = logits.shape[-1]
    if not 0 < k < vocab:
        return logits
    # The smallest retained value is the per-position cut-off.
    kth_value = logits.topk(k, dim=-1).values[..., -1:]
    return logits.masked_fill(logits < kth_value, float("-inf"))
def _compute_remask_prob(
strategy: str,
eta: float,
sigma_max: float,
confidence: Tensor | None,
) -> Tensor | float:
"""Compute per-token remasking probability.
Args:
strategy: One of ``"rescale"``, ``"cap"``, ``"conf"``.
eta: Base remasking strength hyperparameter.
sigma_max: ``1 - alpha_t(ratio)`` at current step.
confidence: Per-token confidence scores. Shape ``[B, L]``.
Required only for the ``"conf"`` strategy.
Returns:
Scalar or ``[B, L]`` tensor of remasking probabilities.
"""
if strategy == "rescale":
return eta * sigma_max
if strategy == "cap":
return min(eta, sigma_max)
if strategy == "conf":
assert confidence is not None, "conf strategy requires confidence"
return eta * sigma_max * (1.0 - confidence)
raise ValueError(f"Unknown remask strategy: {strategy}")
@torch.no_grad()
def remdm_sample(
    model: torch.nn.Module,
    local_obs: Tensor,
    global_obs: Tensor,
    cfg: SimpleNamespace,
    device: torch.device | str,
    physics_aware: bool = True,
    blind_global: bool = False,
    return_analytics: bool = False,
    num_steps: int | None = None,
) -> Tensor | tuple[Tensor, list, list[float], list[int]]:
    """Generate action sequences via iterative ReMDM denoising.

    Args:
        model: Denoising model with forward signature
            ``(local_obs, global_obs, action_seq, t_discrete) -> dict``.
        local_obs: Local crop observations. Shape ``[B, 9, 9]``.
        global_obs: Global map observations. Shape ``[B, 21, 79]``.
        cfg: Config namespace with ``seq_len``, ``mask_token``,
            ``action_dim``, ``diffusion_steps_eval``,
            ``num_diffusion_steps`` (training-time step count used to
            discretise ``t``), ``temperature``, ``top_k``, ``eta``,
            ``remask_strategy``, ``noise_schedule``.
        device: Torch device.
        physics_aware: If ``True``, soft-penalise hazardous cardinal actions
            by overriding their confidence to ``0.001`` before commitment
            ranking. Only checks the first ``_N_PHYSICS_CHECK`` positions.
        blind_global: If ``True``, zero out the global map observation
            (local-only ablation).
        return_analytics: If ``True``, also return per-step analytics as
            ``(seq, path_per_step, tracking_confidence, tracking_masked)``.
            Analytics track only batch element 0.
        num_steps: Override number of denoising steps (default uses
            ``cfg.diffusion_steps_eval``).

    Returns:
        When ``return_analytics=False`` (default): fully committed action
        sequence of shape ``[B, seq_len]``, int64, with no MASK tokens.
        When ``return_analytics=True``: tuple
        ``(seq, path_per_step, tracking_confidence, tracking_masked_count)``
        where ``path_per_step`` is a list of ``[seq_len]`` numpy arrays,
        ``tracking_confidence`` a list of per-step avg unmasked confidence
        floats, and ``tracking_masked_count`` a list of masked-token counts.

    Raises:
        AssertionError: If any MASK token survives the final step.
    """
    B = local_obs.shape[0]
    seq_len = cfg.seq_len
    mask_token = cfg.mask_token
    action_dim = cfg.action_dim
    K = num_steps if num_steps is not None else cfg.diffusion_steps_eval
    schedule_fn = get_schedule(cfg.noise_schedule)
    min_keep = max(1, int(seq_len * 0.10))  # Safety Net: always unmask ≥10%
    local_obs = local_obs.to(device)
    global_obs = global_obs.to(device)
    if blind_global:
        global_obs = torch.zeros_like(global_obs)
    # Pre-compute numpy local crops for physics checks (CPU, batch loop)
    local_np: np.ndarray | None = None  # [B, crop, crop]
    if physics_aware:
        local_np = local_obs.cpu().numpy()
    # Analytics buffers (only populated when return_analytics=True)
    path_per_step: list[np.ndarray] = []
    tracking_confidence: list[float] = []
    tracking_masked_count: list[int] = []
    # Start fully masked
    seq = torch.full(
        (B, seq_len), mask_token, dtype=torch.long, device=device
    )
    for k in range(1, K + 1):
        # ratio goes 1/K .. 1.0; it drives both the unmask quota and the
        # (reversed) discrete timestep fed to the model.
        ratio = k / K
        # Pass as tensor (not Python int) to avoid torch.compile recompilation
        t_discrete = torch.full(
            (B,), int(cfg.num_diffusion_steps * (1.0 - ratio)),
            dtype=torch.long, device=device,
        )
        # Forward pass
        out = model(local_obs, global_obs, seq, t_discrete)
        logits = out["actions"]  # [B, seq_len, vocab]
        # Mask invalid action tokens (indices >= action_dim); this also
        # prevents the MASK token itself from being sampled.
        logits[:, :, action_dim:] = float("-inf")
        # Temperature scaling
        logits = logits / cfg.temperature
        # Top-K filtering
        logits = top_k_filter(logits, cfg.top_k)
        # Sample predictions
        probs = F.softmax(logits, dim=-1)  # [B, seq_len, action_dim]
        preds = Categorical(probs=probs).sample()  # [B, seq_len]
        # Confidence: probability of the sampled token
        conf = probs.gather(
            -1, preds.unsqueeze(-1)
        ).squeeze(-1)  # [B, seq_len]
        # Physics softener: demote hazardous cardinal actions to conf=0.001
        # so they rank last for commitment (and, under the "conf" strategy,
        # are remasked almost surely).
        if physics_aware and local_np is not None:
            preds_np = preds.cpu().numpy()  # [B, seq_len]
            conf_override = conf.clone()
            for b in range(B):
                crop_b = np.asarray(local_np[b])  # [crop, crop]
                for pos in range(min(_N_PHYSICS_CHECK, seq_len)):
                    action = int(preds_np[b, pos])
                    if _check_hazard(crop_b, action):
                        conf_override[b, pos] = 0.001
            conf = conf_override
        # Snapshot the mask BEFORE this step's unmask/remask mutations.
        is_masked = seq == mask_token  # [B, seq_len]
        if k < K:
            # MaskGIT progressive unmasking with min-keep guarantee
            n_unmask = max(min_keep, max(1, int(seq_len * ratio)))
            # Set confidence of non-masked positions to -1 so they
            # are not selected for unmasking
            unmask_scores = conf.clone()
            unmask_scores[~is_masked] = -1.0
            # For each batch element, unmask top-confidence masked positions
            _, topk_indices = unmask_scores.topk(
                n_unmask, dim=-1
            )  # [B, n_unmask]
            # Build scatter mask for positions to unmask
            unmask_mask = torch.zeros_like(seq, dtype=torch.bool)
            unmask_mask.scatter_(1, topk_indices, True)
            unmask_mask = unmask_mask & is_masked  # only unmask masked pos
            seq = torch.where(unmask_mask, preds, seq)
            # ReMDM stochastic remasking of committed (non-masked) positions.
            # Note: tokens unmasked THIS step are eligible for remasking too.
            is_committed = seq != mask_token  # [B, seq_len]
            alpha_t_ratio = schedule_fn(
                torch.tensor(ratio, device=device)
            )
            # sigma_max = 1 - alpha_t(ratio): the schedule's noise headroom.
            sigma_max = (1.0 - alpha_t_ratio).item()
            remask_prob = _compute_remask_prob(
                cfg.remask_strategy, cfg.eta, sigma_max, conf
            )
            # "conf" yields a per-token tensor; "rescale"/"cap" a scalar.
            if isinstance(remask_prob, Tensor):
                do_remask = (
                    torch.rand_like(conf) < remask_prob
                ) & is_committed
            else:
                do_remask = (
                    torch.rand(B, seq_len, device=device) < remask_prob
                ) & is_committed
            seq = torch.where(do_remask, mask_token, seq)
        else:
            # Final step: commit all remaining MASK tokens
            seq = torch.where(is_masked, preds, seq)
        # Analytics tracking (batch element 0 only)
        if return_analytics:
            path_per_step.append(seq[0].cpu().numpy().copy())
            still_masked = (seq[0] == mask_token)
            unmasked_conf = conf[0][~still_masked]
            avg_conf = (
                unmasked_conf.mean().item()
                if unmasked_conf.numel() > 0 else 0.0
            )
            tracking_confidence.append(avg_conf)
            tracking_masked_count.append(int(still_masked.sum().item()))
    assert (seq != mask_token).all(), (
        "remdm_sample produced MASK tokens in final output"
    )
    if return_analytics:
        return seq, path_per_step, tracking_confidence, tracking_masked_count
    return seq
@torch.no_grad()
def greedy_sample(
model: torch.nn.Module,
local_obs: Tensor,
global_obs: Tensor,
cfg: SimpleNamespace,
device: torch.device | str,
blind_global: bool = False,
num_steps: int | None = None,
) -> Tensor:
"""Greedy (argmax) MaskGIT sampling — no temperature, top-K, or remasking.
Used by ``DataCollector`` during DAgger for deterministic rollouts,
matching the reference ``run_model_episode`` behaviour.
Args:
model: Denoising model.
local_obs: Shape ``[B, 9, 9]``.
global_obs: Shape ``[B, 21, 79]``.
cfg: Config namespace.
device: Torch device.
blind_global: Zero out global map (local-only ablation).
Returns:
Fully committed action sequence ``[B, seq_len]``, int64.
"""
B = local_obs.shape[0]
seq_len = cfg.seq_len
mask_token = cfg.mask_token
action_dim = cfg.action_dim
K = num_steps if num_steps is not None else cfg.diffusion_steps_eval
local_obs = local_obs.to(device)
global_obs = global_obs.to(device)
if blind_global:
global_obs = torch.zeros_like(global_obs)
seq = torch.full(
(B, seq_len), mask_token, dtype=torch.long, device=device,
)
for k in range(1, K + 1):
ratio = k / K
t_discrete = torch.full(
(B,), int(cfg.num_diffusion_steps * (1.0 - ratio)),
dtype=torch.long, device=device,
)
out = model(local_obs, global_obs, seq, t_discrete)
logits = out["actions"] # [B, seq_len, vocab]
# Mask invalid action tokens
logits[:, :, action_dim:] = float("-inf")
# Greedy: argmax over softmax (no temperature, no top-K)
probs = F.softmax(logits, dim=-1) # [B, seq_len, action_dim]
confidences, preds = probs.max(dim=-1) # [B, seq_len] each
# MaskGIT progressive unmasking by confidence
num_to_unmask = max(1, int(seq_len * ratio))
is_masked = seq == mask_token # [B, seq_len]
# Score only masked positions for unmasking
scores = confidences.clone()
scores[~is_masked] = -1.0
_, topk_idx = scores.topk(num_to_unmask, dim=-1)
unmask_mask = torch.zeros_like(seq, dtype=torch.bool)
unmask_mask.scatter_(1, topk_idx, True)
unmask_mask = unmask_mask & is_masked
seq = torch.where(unmask_mask, preds, seq)
# No remasking in greedy mode
# Force-commit any remaining masked tokens
still_masked = seq == mask_token
if still_masked.any():
t_zero = torch.zeros(B, dtype=torch.long, device=device)
out = model(local_obs, global_obs, seq, t_zero)
logits = out["actions"]
logits[:, :, action_dim:] = float("-inf")
preds = logits.argmax(dim=-1)
seq = torch.where(still_masked, preds, seq)
return seq
def select_action(
    model: torch.nn.Module,
    local_obs: Tensor,
    global_obs: Tensor,
    cfg: SimpleNamespace,
    device: torch.device | str,
    physics_aware: bool = True,
    blind_global: bool = False,
) -> int:
    """Plan with ``remdm_sample`` on a batch of one; return the first action.

    Args:
        model: Denoising model.
        local_obs: Shape ``[9, 9]`` or ``[1, 9, 9]``.
        global_obs: Shape ``[21, 79]`` or ``[1, 21, 79]``.
        cfg: Config namespace.
        device: Torch device.
        physics_aware: Forwarded to ``remdm_sample``.
        blind_global: Forwarded to ``remdm_sample``.

    Returns:
        The first action of the generated plan (int).
    """
    # Promote unbatched observations to a batch of size 1.
    if local_obs.ndim == 2:
        local_obs = local_obs.unsqueeze(0)
    if global_obs.ndim == 2:
        global_obs = global_obs.unsqueeze(0)
    plan = remdm_sample(
        model,
        local_obs,
        global_obs,
        cfg,
        device,
        physics_aware=physics_aware,
        blind_global=blind_global,
    )
    return plan[0, 0].item()