remdm-minihack / src /diffusion /schedules.py
MathisW78's picture
Demo notebook payload (source + checkpoint + assets)
f748552 verified
"""Noise schedule functions for MDLM diffusion.
Ported from the Craftax JAX implementation (src/diffusion/schedules.py).
All functions operate on PyTorch tensors and are pure (no global state).
Convention: alpha(t) is the fraction of tokens that remain *unmasked*.
- alpha(0) = 1.0 (fully clean)
- alpha(1) = 0.0 (fully masked)
"""
from __future__ import annotations
import math
from typing import Callable
import torch
from torch import Tensor
def linear_schedule(t: Tensor) -> Tensor:
    """Return the linear retention schedule ``alpha(t) = 1 - t``.

    The unmasked fraction decays at a constant rate from 1 at ``t = 0``
    to 0 at ``t = 1``.

    Args:
        t: Diffusion time in [0, 1]. Any shape.

    Returns:
        Probability that a token is still unmasked at time *t*, with the
        same shape as *t*.
    """
    retention = 1.0 - t
    return retention
def cosine_schedule(t: Tensor) -> Tensor:
    """Return the cosine retention schedule ``alpha(t) = cos(pi/2 * t)^2``.

    Compared with the linear schedule, this keeps more tokens unmasked at
    small *t* and masks them faster towards ``t = 1``.

    Args:
        t: Diffusion time in [0, 1]. Any shape.

    Returns:
        Probability that a token is still unmasked at time *t*, with the
        same shape as *t*.
    """
    half_pi = math.pi / 2.0
    phase = t * half_pi
    return torch.cos(phase).pow(2)
# Registry of available schedules, keyed by the name accepted by get_schedule.
_SCHEDULE_MAP: dict[str, Callable[[Tensor], Tensor]] = dict(
    linear=linear_schedule,
    cosine=cosine_schedule,
)
def get_schedule(name: str) -> Callable[[Tensor], Tensor]:
    """Resolve a registered noise schedule from its string name.

    Args:
        name: Registry key, currently ``"linear"`` or ``"cosine"``.

    Returns:
        The matching schedule callable ``alpha(t)``.

    Raises:
        KeyError: When *name* has no registered schedule. The message lists
            the valid names.
    """
    schedule = _SCHEDULE_MAP.get(name)
    if schedule is None:
        raise KeyError(
            f"Unknown schedule '{name}'. "
            f"Available: {list(_SCHEDULE_MAP.keys())}"
        )
    return schedule
def alpha_prime(
    t: Tensor,
    schedule_fn: Callable[[Tensor], Tensor],
    eps: float = 1e-5,
) -> Tensor:
    """Numerical derivative d(alpha)/dt via central difference.

    *t* is first clamped to ``[eps, 1 - eps]`` so the stencil
    ``[t - eps, t + eps]`` stays inside ``[0, 1]``. Near the endpoints the
    derivative is therefore evaluated at the clamped point, not at *t*.

    Half-precision (float16 / bfloat16) and integer inputs are evaluated in
    float32. At bfloat16 resolution a step of ``1e-5`` vanishes entirely
    (``0.5 + 1e-5 == 0.5``), which would silently produce a zero derivative.

    Args:
        t: Diffusion time in [0, 1]. Any shape.
        schedule_fn: Noise schedule returning alpha(t).
        eps: Half-width of the finite-difference stencil. Must satisfy
            ``0 < eps < 0.5``.

    Returns:
        Approximate derivative, same shape as *t*. Floating-point inputs keep
        their dtype; non-floating inputs yield the promoted float result.

    Raises:
        ValueError: If *eps* is not in ``(0, 0.5)``.
    """
    # eps <= 0 divides by zero; eps >= 0.5 makes the clamp bounds cross.
    if not 0.0 < eps < 0.5:
        raise ValueError(f"eps must be in (0, 0.5), got {eps}")

    # Compute in at least float32 (float64 stays float64).
    compute_dtype = torch.promote_types(t.dtype, torch.float32)
    out_dtype = t.dtype if t.is_floating_point() else compute_dtype

    t_work = t.to(compute_dtype).clamp(eps, 1.0 - eps)
    t_hi = t_work + eps
    t_lo = t_work - eps

    # Divide by the step actually realised in floating point rather than the
    # nominal 2*eps: t +/- eps is rounded, and that rounding error would
    # otherwise be amplified by 1/(2*eps). If the step collapses to zero
    # (eps below the dtype's resolution) the numerator is zero as well, so
    # clamping the denominator returns 0 instead of NaN.
    step = (t_hi - t_lo).clamp_min(torch.finfo(compute_dtype).tiny)
    deriv = (schedule_fn(t_hi) - schedule_fn(t_lo)) / step
    return deriv.to(out_dtype)