| """Noise schedule functions for MDLM diffusion. |
| |
| Ported from the Craftax JAX implementation (src/diffusion/schedules.py). |
| All functions operate on PyTorch tensors and are pure (no global state). |
| |
| Convention: alpha(t) is the fraction of tokens that remain *unmasked*. |
| - alpha(0) = 1.0 (fully clean) |
| - alpha(1) = 0.0 (fully masked) |
| """ |
|
|
| from __future__ import annotations |
|
|
| import math |
| from typing import Callable |
|
|
| import torch |
| from torch import Tensor |
|
|
|
|
def linear_schedule(t: Tensor) -> Tensor:
    """Evaluate the linear schedule alpha(t) = 1 - t.

    The retained (unmasked) fraction falls at a constant rate, from 1 at
    ``t = 0`` down to 0 at ``t = 1``.

    Args:
        t: Diffusion time, expected in [0, 1]. Any shape.

    Returns:
        Retention probability alpha_t with the same shape as *t*.
    """
    retained = 1.0 - t
    return retained
|
|
|
|
def cosine_schedule(t: Tensor) -> Tensor:
    """Evaluate the cosine schedule alpha(t) = cos(pi/2 * t)^2.

    Args:
        t: Diffusion time, expected in [0, 1]. Any shape.

    Returns:
        Retention probability alpha_t with the same shape as *t*.
    """
    half_pi = math.pi / 2.0
    # Map t in [0, 1] onto a quarter period, so the cosine runs from 1 to ~0.
    cos_term = torch.cos(t * half_pi)
    return cos_term ** 2
|
|
|
|
# Registry mapping each schedule name to the function that computes alpha(t).
_SCHEDULE_MAP: dict[str, Callable[[Tensor], Tensor]] = {}
_SCHEDULE_MAP["linear"] = linear_schedule
_SCHEDULE_MAP["cosine"] = cosine_schedule
|
|
|
|
def get_schedule(name: str) -> Callable[[Tensor], Tensor]:
    """Resolve a schedule name to its ``alpha(t)`` function.

    Args:
        name: Key of a registered schedule, e.g. ``"linear"`` or
            ``"cosine"``.

    Returns:
        The callable registered under *name*.

    Raises:
        KeyError: If *name* is not a registered schedule; the message lists
            the available names.
    """
    if name in _SCHEDULE_MAP:
        return _SCHEDULE_MAP[name]

    message = (
        f"Unknown schedule '{name}'. "
        f"Available: {list(_SCHEDULE_MAP.keys())}"
    )
    raise KeyError(message)
|
|
|
|
def alpha_prime(
    t: Tensor,
    schedule_fn: Callable[[Tensor], Tensor],
    eps: float = 1e-5,
) -> Tensor:
    """Estimate d(alpha)/dt with a symmetric finite difference.

    *t* is first clamped to ``[eps, 1 - eps]`` so that both evaluation
    points, ``t + eps`` and ``t - eps``, stay inside ``[0, 1]``. As a
    consequence, inputs at or beyond the endpoints report the slope at
    ``eps`` or ``1 - eps`` respectively.

    Args:
        t: Diffusion time, expected in [0, 1]. Any shape.
        schedule_fn: Noise schedule returning alpha(t).
        eps: Half-width of the finite-difference stencil.

    Returns:
        Approximate derivative with the same shape as *t*.
    """
    # NOTE(review): for float32 inputs the default eps is only a few hundred
    # ulps wide near t ~ 1, so the estimate carries visible rounding noise.
    # Confirm this accuracy is acceptable for callers.
    center = t.clamp(min=eps, max=1.0 - eps)
    alpha_ahead = schedule_fn(center + eps)
    alpha_behind = schedule_fn(center - eps)
    stencil_width = 2.0 * eps
    return (alpha_ahead - alpha_behind) / stencil_width
|
|