# Copyright (c) 2024-present, BAAI. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------
"""Utility functions for constructing URSA model inputs during distillation.
This module mirrors the token-splicing and RoPE-position logic from
URSAPipeline.__call__ and URSATrainPipeline.process_inputs so that
student/aux/teacher always see the exact same input distribution.
Key design facts (verified from source):
- transformer.config.lm_head_size = 64000 -> logit output dim (codebook_size)
- transformer.config.lm_vocab_size = 151669 -> text-vocab offset for visual tokens
- transformer.config.bov_token_id = 151652 -> beginning-of-video sentinel
- Input visual token IDs are shifted: stored as (raw_code + lm_vocab_size)
- BOV sentinel is prepended to the visual token block
- Causal slice to recover visual logits: logits[:, -(N+1):-1] where N = T*H*W
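
Sequence layout (illustrative sketch following the facts above):

    [ txt_0 ... txt_{L-1} | BOV | vis_0 ... vis_{N-1} ]
      position: 0 .. L-1     L    L+1 .. L+N

Under next-token prediction, the logit emitted at the BOV position predicts
vis_0 and the logit at position L+N-1 predicts vis_{N-1}; this is why
logits[:, -(N+1):-1] recovers exactly the N visual-token predictions.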
"""
from typing import Optional, Tuple
import torch
import torch.nn.functional as F


# ---------------------------------------------------------------------------
# Latent shape helpers
# ---------------------------------------------------------------------------
def compute_latents_shape(
num_frames: int,
height: int,
width: int,
temporal_stride: int = 4,
spatial_stride: int = 8,
) -> Tuple[int, int, int]:
"""Return the VQ-token grid (T, H, W) matching URSAPipeline's convention.
Matches the formula in URSAPipeline.__call__:
T = (num_frames - 1) // temporal_stride + 1
H = height // spatial_stride
W = width // spatial_stride
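
    Example (17 frames at 256x256 with the default strides):
        >>> compute_latents_shape(17, 256, 256)
        (5, 32, 32)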
"""
T = (num_frames - 1) // temporal_stride + 1
H = height // spatial_stride
W = width // spatial_stride
    return T, H, W


# ---------------------------------------------------------------------------
# Core input builder
# ---------------------------------------------------------------------------
def build_ursa_inputs(
transformer,
txt_ids: torch.Tensor,
visual_tokens: torch.Tensor,
latents_shape: Tuple[int, int, int],
device: torch.device,
) -> Tuple[torch.Tensor, torch.Tensor, int]:
"""Construct (input_ids, rope_pos, N) exactly as URSAPipeline does.
This is the single source-of-truth for all three models
(teacher / aux / student) so their input distributions match.
Args:
transformer: The URSATransformer3DModel (read config from it).
txt_ids: Tokenised prompts, shape [B, L].
visual_tokens: Raw codebook indices, shape [B, T, H, W] or [B, N], dtype long.
latents_shape: (T, H, W) tuple – shape of one video's latent grid.
device: Target device.
Returns:
input_ids: [B, L + N + 1], long (N = T*H*W)
rope_pos: [B, L + N + 1, 3], int32
N: number of visual tokens per sample (T*H*W)
Notes:
- BOV token is inserted at position L (just before the visual tokens).
- Visual token IDs are shifted by lm_vocab_size before being concatenated.
- rope_pos is batched (training convention), not the 2-D inference convention.
"""
B, L = txt_ids.shape
# -- Config values ---------------------------------------------------
# PITFALL 1: always read from config, never hard-code.
bov_token_id = transformer.config.bov_token_id
# lm_vocab_size == len(tokenizer): the visual-token vocab offset.
latent_shift = transformer.config.lm_vocab_size
T, H, W = latents_shape
N = T * H * W
# -- Input validation ------------------------------------------------
assert visual_tokens.dtype == torch.long, \
f"build_ursa_inputs: visual_tokens must be long, got {visual_tokens.dtype}"
assert visual_tokens.numel() == B * N, (
f"build_ursa_inputs: visual_tokens has {visual_tokens.numel()} elements, "
f"expected B*N = {B}*{N} = {B*N}"
)
# -- Visual token block ----------------------------------------------
    # Flatten to [B, N] so pad/cat are straightforward. reshape (rather than
    # view) also handles non-contiguous inputs such as sliced tensors.
    latents_flat = visual_tokens.reshape(B, N).to(device)  # [B, N], long
# Shift raw codebook indices into the visual-vocab region and prepend BOV.
# Mirrors: img_ids = pad(latents_flat + latent_shift, (1,0), value=bov_token_id)
img_ids = F.pad(latents_flat + latent_shift, (1, 0), value=bov_token_id) # [B, N+1]
# -- Full input sequence: [txt | bov | vis_0 ... vis_{N-1}] ----------
input_ids = torch.cat([txt_ids.to(device), img_ids], dim=1) # [B, L+N+1]
# -- RoPE positions --------------------------------------------------
# Mirrors URSAPipeline:
# txt_pos = arange(L).view(-1,1).expand(-1,3) -> [L, 3]
# blk_pos = flex_rope.get_pos(latents_shape, L) -> [1, N+1, 3]
# rope_pos = cat([txt_pos, blk_pos[0]]) -> [L+N+1, 3]
# Then batch-expand (training convention):
# rope_pos = rope_pos.unsqueeze(0).expand(B,-1,-1).contiguous() -> [B, L+N+1, 3]
txt_pos = torch.arange(L, device=device).view(-1, 1).expand(-1, 3) # [L, 3]
blk_pos = transformer.model.flex_rope.get_pos(latents_shape, txt_pos.size(0)) # [1, N+1, 3]
    # torch.cat type-promotes mixed integer dtypes (txt_pos is int64 from
    # arange); cast explicitly so the documented int32 contract holds.
    rope_pos_1d = torch.cat([txt_pos, blk_pos[0].to(device)], dim=0).to(torch.int32)  # [L+N+1, 3]
rope_pos = rope_pos_1d.unsqueeze(0).expand(B, -1, -1).contiguous() # [B, L+N+1, 3]
# -- Output shape assertions -----------------------------------------
expected_seq_len = L + N + 1
assert input_ids.shape == (B, expected_seq_len), (
f"build_ursa_inputs: input_ids shape={input_ids.shape} "
f"expected ({B},{expected_seq_len}). "
"txt_ids length or latents_shape may be wrong."
)
assert rope_pos.shape == (B, expected_seq_len, 3), (
f"build_ursa_inputs: rope_pos shape={rope_pos.shape} "
f"expected ({B},{expected_seq_len},3). "
"BOV/blk_pos alignment is off — check flex_rope.get_pos return shape."
)
return input_ids, rope_pos, N
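

# Usage sketch (illustrative only; `tokenizer`, the prompt list, and the
# transformer forward signature are assumptions about the caller's pipeline,
# not part of this module):
#
#   txt_ids = tokenizer(prompts, return_tensors="pt", padding=True).input_ids
#   T, H, W = compute_latents_shape(num_frames=17, height=256, width=256)
#   input_ids, rope_pos, N = build_ursa_inputs(
#       transformer, txt_ids, visual_tokens, (T, H, W), device
#   )
#   logits = transformer(input_ids=input_ids, rope_pos=rope_pos)  # assumed call
#   z = extract_visual_logits(logits, N, codebook_size=64000)     # [B, N, 64000]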
# ---------------------------------------------------------------------------
# Visual logit extractor
# ---------------------------------------------------------------------------
def extract_visual_logits(
logits: torch.Tensor,
N: int,
codebook_size: int,
    lm_head_size: Optional[int] = None,
) -> torch.Tensor:
"""Slice and (if needed) project the transformer logits to [B, N, K].
PITFALL 2: The lm_head projects hidden states to lm_head_size (=64000),
NOT to the full vocab_size. We must never confuse text-vocab indices with
codebook indices. This function is the single gate that converts raw
transformer output to visual-codebook logits.
Slicing convention (mirrors URSAPipeline):
z = logits[:, -(N+1) : -1] # causal shift: BOV at -(N+1), last is EOS
If the last dimension already equals codebook_size, return z directly.
If the last dimension is larger (e.g. full vocab), slice the visual region.
Otherwise raise a descriptive error so the caller can fix the config.
Args:
logits: Raw transformer output, shape [B, L+N+1, D].
N: Number of visual tokens (T*H*W).
codebook_size: Expected number of codebook entries (scheduler.codebook_size).
lm_head_size: Deprecated alias for codebook_size; ignored if None.
Returns:
Tensor of shape [B, N, codebook_size].
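
    Example (toy shapes: B=2, L=3, N=4, codebook_size=5):
        >>> logits = torch.zeros(2, 3 + 4 + 1, 5)
        >>> extract_visual_logits(logits, N=4, codebook_size=5).shape
        torch.Size([2, 4, 5])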
"""
B_in = logits.size(0)
    # PITFALL 2: causal slice – exactly as URSAPipeline uses it.
    # logits[:, -(N+1):-1] takes the N logits that predict vis_0..vis_{N-1},
    # i.e. the positions from the BOV token up to (but excluding) the last one.
    z = logits[:, -(N + 1) : -1]  # [B, N, D]
# Verify sliced sequence length matches N.
assert z.size(1) == N, (
f"extract_visual_logits: slice produced seq_len={z.size(1)}, expected N={N}. "
"Logit sequence length may be shorter than N+1. "
"Check that input_ids was built with the correct latents_shape."
)
D = z.size(-1)
if D == codebook_size:
# Happy path: lm_head_size == codebook_size (default URSA config).
assert z.shape == (B_in, N, codebook_size), \
f"extract_visual_logits: z.shape={z.shape} expected ({B_in},{N},{codebook_size})"
return z
# If the head includes a text prefix (shouldn't happen with default config,
# but guard anyway).
if D > codebook_size:
lm_vocab_size = D - codebook_size
z_vis = z[..., lm_vocab_size:]
assert z_vis.shape == (B_in, N, codebook_size), \
f"extract_visual_logits (sliced): z_vis.shape={z_vis.shape}"
return z_vis
raise ValueError(
f"extract_visual_logits: unexpected logit last-dim={D} < codebook_size={codebook_size}. "
"Check transformer.config.lm_head_size and scheduler.codebook_size. "
f"logits.shape={logits.shape}"
    )


# ---------------------------------------------------------------------------
# Corrupt helper (for p_init mixing)
# ---------------------------------------------------------------------------
def corrupt_tokens(tokens: torch.Tensor, r: float, K: int) -> torch.Tensor:
"""Replace a random fraction r of tokens with uniform random codes.
Used for the 20% p_init mixing strategy:
mask = Bernoulli(r)
corrupted = mask * randint(K) + (1-mask) * tokens
Args:
tokens: Long tensor of codebook indices, any shape.
r: Fraction of tokens to corrupt (0 < r < 1).
K: Codebook size.
Returns:
Corrupted token tensor, same shape and dtype as ``tokens``.
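
    Example (shape and dtype are preserved; which positions flip is random):
        >>> toks = torch.arange(8).view(2, 4)
        >>> out = corrupt_tokens(toks, r=0.5, K=16)
        >>> out.shape == toks.shape and out.dtype == toks.dtype
        True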
"""
mask = torch.bernoulli(torch.full_like(tokens, r, dtype=torch.float)).bool()
rand_codes = torch.randint(0, K, tokens.shape, device=tokens.device, dtype=tokens.dtype)
    return torch.where(mask, rand_codes, tokens)


# ---------------------------------------------------------------------------
# KL / Jeffrey divergence helpers
# ---------------------------------------------------------------------------
def kl_divergence(p: torch.Tensor, q: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
"""KL(p || q) summed over last dimension, per-sample mean over tokens.
Args:
p: [B, N, K] probability tensor.
q: [B, N, K] probability tensor.
Returns:
[B] per-sample KL divergence (mean over N tokens).
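
    Example (KL of a distribution with itself is zero):
        >>> p = torch.softmax(torch.randn(2, 3, 5), dim=-1)
        >>> bool(kl_divergence(p, p).abs().max() < 1e-6)
        True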
"""
p = p.clamp(min=eps)
q = q.clamp(min=eps)
    return (p * (p.log() - q.log())).sum(-1).mean(-1)  # [B]


def jeffrey_divergence(p: torch.Tensor, q: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
"""Symmetric KL (Jeffrey): KL(p||q) + KL(q||p), per-sample mean over tokens.
Returns:
[B] per-sample Jeffrey divergence.
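
    Example (symmetric in its arguments, unlike plain KL):
        >>> p = torch.softmax(torch.randn(2, 3, 5), dim=-1)
        >>> q = torch.softmax(torch.randn(2, 3, 5), dim=-1)
        >>> bool(torch.allclose(jeffrey_divergence(p, q), jeffrey_divergence(q, p)))
        True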
"""
    return kl_divergence(p, q, eps) + kl_divergence(q, p, eps)


# ---------------------------------------------------------------------------
# Timestep curriculum
# ---------------------------------------------------------------------------
def sample_t_curriculum(
B: int,
device: torch.device,
step: int,
warmup_steps: int = 10_000,
) -> torch.Tensor:
"""Sample training timesteps with a curriculum biased toward large t early on.
- For the first ``warmup_steps`` steps, use t = 1 - (1-u)^2 (biased high).
- After warmup, fall back to a near-uniform u sampled straight from [0, 1).
- t is clamped to [0.05, 0.995] to avoid degenerate paths.
Returns:
[B] float tensor of continuous timesteps.
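
    Example (warmup draws sit higher on average; a quick sanity check):
        >>> _ = torch.manual_seed(0)
        >>> t_warm = sample_t_curriculum(10_000, torch.device("cpu"), step=0)
        >>> t_late = sample_t_curriculum(10_000, torch.device("cpu"), step=20_000)
        >>> bool(t_warm.mean() > t_late.mean())
        True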
"""
u = torch.rand(B, device=device)
if step < warmup_steps:
t = 1.0 - (1.0 - u) ** 2 # squish toward 1 (data end)
else:
t = u
return t.clamp(0.05, 0.995)