"""
ViL-DLM: Vision xLSTM Diffusion Language Model
Architecture:
[Image] → ViL Encoder → MLP Projector → [Visual Tokens]
[Visual Tokens] + [Text Tokens (masked)] → Bidirectional Diffusion LM → Denoised Tokens
Components:
1. ViL (Vision xLSTM) - custom vision encoder with linear complexity
2. MLP Projector - maps ViL features to LM embedding space
3. Qwen3-0.6B Diffusion LM - bidirectional masked diffusion backbone (from dLLM)
Training:
Stage 1: Train projector only (ViL frozen, LM frozen) on LLaVA-Pretrain
Stage 2: Full finetune on multimodal instruction data
Stage 3: + Knowledge distillation from Gemma 4 E2B teacher
Diffusion Process (MDLM):
Forward: progressively mask tokens with [MASK] according to cosine schedule
Reverse: iteratively predict masked tokens using bidirectional attention
Loss: cross-entropy averaged over masked positions
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Dict, Any, Tuple
from transformers import AutoModelForImageTextToText, AutoModelForMaskedLM, AutoTokenizer
from model_config import ViLEncoderConfig, ProjectorConfig, TrainingConfig
from vision_xlstm import VisionXLSTM, VisionProjector
class MDLMScheduler:
"""
Masked Diffusion Language Model noise scheduler.
Cosine schedule for masking probability.
"""
def __init__(self, num_steps=1000, mask_token_id=151643):
self.num_steps = num_steps
self.mask_token_id = mask_token_id
    def get_keep_ratio(self, t):
        """Cosine schedule: fraction of tokens left *unmasked* at timestep t"""
        # t in [0, 1]: 0 = clean, 1 = fully masked
        return torch.cos(t * math.pi / 2)  # keep ratio: 1 at t=0, 0 at t=1
def add_noise(self, input_ids, t):
"""
Forward diffusion: mask tokens according to timestep t.
Args:
input_ids: [B, T] clean token ids
t: [B] timestep in [0, 1]
Returns:
noisy_ids: [B, T] with some tokens replaced by mask
mask: [B, T] boolean - True where tokens are masked
"""
B, T = input_ids.shape
device = input_ids.device
# Get mask ratio for each sample
        mask_ratio = 1.0 - self.get_keep_ratio(t)  # higher t → more masking
mask_ratio = mask_ratio.unsqueeze(1).expand(B, T) # [B, T]
# Sample mask: each token independently masked with probability mask_ratio
rand = torch.rand(B, T, device=device)
mask = rand < mask_ratio # True = masked
# Replace masked tokens
noisy_ids = input_ids.clone()
noisy_ids[mask] = self.mask_token_id
return noisy_ids, mask
def sample_timesteps(self, batch_size, device):
"""Sample random timesteps for training"""
return torch.rand(batch_size, device=device)
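
# A minimal, illustrative sketch (not part of the pipeline): how the cosine
# schedule controls masking density. Token ids, shapes, and timesteps here
# are hypothetical.
def _sketch_scheduler_usage():
    scheduler = MDLMScheduler(num_steps=1000, mask_token_id=151643)
    ids = torch.randint(0, 1000, (2, 16))  # [B=2, T=16] fake token ids
    t = torch.tensor([0.1, 0.9])           # near-clean vs. near-fully-masked
    noisy, mask = scheduler.add_noise(ids, t)
    # Expected mask density is 1 - cos(t * pi / 2): ~1% at t=0.1, ~84% at t=0.9
    print(mask.float().mean(dim=1))
    return noisy, mask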
class ViLDLM(nn.Module):
"""
Vision xLSTM Diffusion Language Model.
Combines:
- ViL encoder for image understanding
- MLP projector for modality alignment
- Qwen3-0.6B diffusion backbone for masked denoising
"""
def __init__(self, config: TrainingConfig):
super().__init__()
self.config = config
# 1. Vision Encoder (ViL)
self.vision_encoder = VisionXLSTM(config.vil_encoder)
# 2. MLP Projector
self.projector = VisionProjector(config.projector)
# 3. Diffusion LM backbone (loaded from pretrained)
self.lm = None # Will be loaded separately
self.tokenizer = None
# 4. Diffusion scheduler
self.scheduler = MDLMScheduler(
num_steps=config.diffusion.num_diffusion_steps,
mask_token_id=config.diffusion.mask_token_id
)
# 5. Special token embedding for image placeholder
# We'll use the LM's embedding layer directly
    def load_diffusion_lm(self, local_path: Optional[str] = None):
"""Load the pretrained diffusion LM backbone"""
model_path = local_path or self.config.diffusion_lm_id
print(f"Loading diffusion LM from {model_path}...")
self.lm = AutoModelForMaskedLM.from_pretrained(
model_path,
trust_remote_code=True,
torch_dtype=torch.bfloat16 if self.config.bf16 else torch.float32,
)
self.tokenizer = AutoTokenizer.from_pretrained(
model_path,
trust_remote_code=True,
)
print(f"Loaded diffusion LM: {sum(p.numel() for p in self.lm.parameters()) / 1e6:.1f}M params")
return self
    def get_input_embeddings(self):
        """Get the LM's input embedding layer"""
        return self.lm.get_input_embeddings()
def prepare_multimodal_inputs(
self,
pixel_values: torch.Tensor, # [B, C, H, W]
input_ids: torch.Tensor, # [B, T_text]
attention_mask: torch.Tensor, # [B, T_text]
        image_token_id: Optional[int] = None,  # marks where the image goes (unused: visual tokens are prepended)
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Prepare multimodal input embeddings by:
1. Encoding image with ViL
2. Projecting to LM space
3. Concatenating [visual_tokens, text_tokens]
Returns:
inputs_embeds: [B, T_vis + T_text, D]
full_attention_mask: [B, T_vis + T_text]
"""
B = pixel_values.shape[0]
# Encode image
with torch.set_grad_enabled(self.training):
vision_features = self.vision_encoder.forward_features(pixel_values)
# vision_features: [B, num_patches, vil_dim]
# Project to LM space
visual_tokens = self.projector(vision_features)
# visual_tokens: [B, num_patches, lm_dim]
# Get text embeddings
text_embeds = self.get_input_embeddings()(input_ids)
# text_embeds: [B, T_text, lm_dim]
# Ensure matching dtype (ViL may be float32, LM may be bfloat16)
target_dtype = text_embeds.dtype
visual_tokens = visual_tokens.to(dtype=target_dtype)
# Concatenate: [visual_tokens | text_tokens]
inputs_embeds = torch.cat([visual_tokens, text_embeds], dim=1)
# Build attention mask: all visual tokens are always visible
num_vis = visual_tokens.shape[1]
vis_mask = torch.ones(B, num_vis, device=attention_mask.device, dtype=attention_mask.dtype)
full_attention_mask = torch.cat([vis_mask, attention_mask], dim=1)
return inputs_embeds, full_attention_mask
def forward(
self,
pixel_values: torch.Tensor, # [B, C, H, W]
input_ids: torch.Tensor, # [B, T] clean text tokens
attention_mask: torch.Tensor, # [B, T]
labels: Optional[torch.Tensor] = None, # [B, T] for loss computation
) -> Dict[str, torch.Tensor]:
"""
Training forward pass with MDLM diffusion loss.
1. Sample random timestep t
2. Mask tokens according to t (forward diffusion)
3. Encode image + masked text through model
4. Compute cross-entropy loss on masked positions
"""
B, T = input_ids.shape
device = input_ids.device
if labels is None:
labels = input_ids.clone()
# Sample timesteps
t = self.scheduler.sample_timesteps(B, device)
# Forward diffusion: mask text tokens
noisy_ids, noise_mask = self.scheduler.add_noise(input_ids, t)
# Prepare multimodal inputs with noisy text
inputs_embeds, full_attention_mask = self.prepare_multimodal_inputs(
pixel_values=pixel_values,
input_ids=noisy_ids,
attention_mask=attention_mask,
)
# Forward through diffusion LM
outputs = self.lm(
inputs_embeds=inputs_embeds,
attention_mask=full_attention_mask,
)
# Get logits for text portion only (skip visual token positions)
num_vis = self.config.vil_encoder.num_patches
text_logits = outputs.logits[:, num_vis:, :] # [B, T, vocab_size]
        # Compute loss only on masked positions (MDLM objective); the loss is
        # a plain average over masked tokens (no per-timestep reweighting here)
        loss_mask = noise_mask.float()
        if loss_mask.sum() == 0:
            # Edge case: no masked tokens. Tie the zero loss to the graph so
            # backward (and DDP) still touch every parameter.
            loss = text_logits.sum() * 0.0
else:
# Cross-entropy on masked positions
logits_flat = text_logits.reshape(-1, text_logits.shape[-1])
labels_flat = labels.reshape(-1)
loss_flat = F.cross_entropy(logits_flat, labels_flat, reduction='none')
loss_flat = loss_flat.reshape(B, T)
# Apply mask: only count loss on masked tokens
loss = (loss_flat * loss_mask).sum() / loss_mask.sum()
return {
'loss': loss,
'logits': text_logits,
'noise_mask': noise_mask,
't': t,
}
def freeze_vision_encoder(self):
"""Freeze ViL encoder (Stage 1)"""
for param in self.vision_encoder.parameters():
param.requires_grad = False
def unfreeze_vision_encoder(self):
"""Unfreeze ViL encoder (Stage 2+)"""
for param in self.vision_encoder.parameters():
param.requires_grad = True
def freeze_lm(self):
"""Freeze diffusion LM backbone (Stage 1)"""
for param in self.lm.parameters():
param.requires_grad = False
def unfreeze_lm(self):
"""Unfreeze diffusion LM backbone (Stage 2+)"""
for param in self.lm.parameters():
param.requires_grad = True
def get_parameter_groups(self):
"""Get parameter groups with different learning rates"""
groups = [
{
'params': [p for p in self.vision_encoder.parameters() if p.requires_grad],
'lr': self.config.vil_learning_rate,
'name': 'vision_encoder'
},
{
'params': [p for p in self.projector.parameters() if p.requires_grad],
'lr': self.config.projector_learning_rate,
'name': 'projector'
},
{
'params': [p for p in self.lm.parameters() if p.requires_grad],
'lr': self.config.learning_rate,
'name': 'diffusion_lm'
},
]
return [g for g in groups if len(g['params']) > 0]
@torch.no_grad()
def generate(
self,
pixel_values: torch.Tensor,
prompt_ids: Optional[torch.Tensor] = None,
max_new_tokens: int = 128,
num_steps: int = 64,
temperature: float = 1.0,
) -> torch.Tensor:
"""
Generate text from image using iterative masked diffusion denoising.
Steps:
1. Start with all-masked output tokens
2. At each step, predict all tokens, unmask most confident ones
3. Repeat until all tokens are unmasked
"""
self.eval()
B = pixel_values.shape[0]
device = pixel_values.device
# Start with all masked tokens
output_ids = torch.full(
(B, max_new_tokens),
self.scheduler.mask_token_id,
device=device, dtype=torch.long
)
# If prompt provided, prepend it
if prompt_ids is not None:
full_ids = torch.cat([prompt_ids, output_ids], dim=1)
prompt_len = prompt_ids.shape[1]
else:
full_ids = output_ids
prompt_len = 0
T_total = full_ids.shape[1]
        attention_mask = torch.ones(B, T_total, device=device, dtype=torch.long)
# Iterative denoising
tokens_per_step = max(1, max_new_tokens // num_steps)
for step in range(num_steps):
# Get predictions
inputs_embeds, full_attn = self.prepare_multimodal_inputs(
pixel_values, full_ids, attention_mask
)
outputs = self.lm(inputs_embeds=inputs_embeds, attention_mask=full_attn)
num_vis = self.config.vil_encoder.num_patches
logits = outputs.logits[:, num_vis:, :] # text portion
# Only update masked positions in the generation part
gen_logits = logits[:, prompt_len:, :] # [B, max_new_tokens, vocab]
gen_ids = full_ids[:, prompt_len:]
# Find masked positions
is_masked = (gen_ids == self.scheduler.mask_token_id)
if not is_masked.any():
break
# Get probabilities
probs = F.softmax(gen_logits / temperature, dim=-1)
predicted = probs.argmax(dim=-1) # [B, max_new_tokens]
# Confidence = max probability
confidence = probs.max(dim=-1).values # [B, max_new_tokens]
            confidence[~is_masked] = -float('inf')  # never re-select already-unmasked positions
# Unmask top-k most confident tokens
num_to_unmask = min(tokens_per_step, is_masked.sum().item())
if num_to_unmask > 0:
# Get indices of most confident masked positions
_, topk_idx = confidence.topk(num_to_unmask, dim=-1, largest=True)
# Unmask these positions
for b in range(B):
for idx in topk_idx[b]:
if is_masked[b, idx]:
full_ids[b, prompt_len + idx] = predicted[b, idx]
return full_ids[:, prompt_len:] # Return generated tokens only
def count_parameters(self):
"""Count parameters by component"""
vil_params = sum(p.numel() for p in self.vision_encoder.parameters())
proj_params = sum(p.numel() for p in self.projector.parameters())
lm_params = sum(p.numel() for p in self.lm.parameters()) if self.lm else 0
total = vil_params + proj_params + lm_params
trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
return {
'vision_encoder': vil_params,
'projector': proj_params,
'diffusion_lm': lm_params,
'total': total,
'trainable': trainable,
}
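
# Illustrative Stage 1 wiring (projector-only training). The optimizer choice
# and default arguments are assumptions for this sketch, not the real pipeline.
def _sketch_stage1_setup(config: TrainingConfig, lm_path: Optional[str] = None):
    model = ViLDLM(config).load_diffusion_lm(lm_path)  # load_diffusion_lm returns self
    model.freeze_vision_encoder()
    model.freeze_lm()  # only the projector parameter group remains trainable
    optimizer = torch.optim.AdamW(model.get_parameter_groups())
    return model, optimizer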
class ViLDLMWithDistillation(ViLDLM):
"""
ViL-DLM with knowledge distillation from Gemma 4 E2B teacher.
Real Stage 3 uses sparse cross-tokenizer KD targets that are
prepared offline with the teacher and cached in the student's
token space.
"""
def __init__(self, config: TrainingConfig):
super().__init__(config)
self.teacher = None
self.teacher_processor = None
self.kd_config = config.distillation
def load_teacher(self):
"""Load Gemma 4 E2B as teacher (quantized for memory)"""
from transformers import AutoProcessor
print(f"Loading teacher: {self.kd_config.teacher_model_id}...")
if self.kd_config.teacher_quantize:
from transformers import BitsAndBytesConfig
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=torch.bfloat16,
bnb_4bit_quant_type="nf4",
)
self.teacher = AutoModelForImageTextToText.from_pretrained(
self.kd_config.teacher_model_id,
quantization_config=bnb_config,
device_map="auto",
)
else:
self.teacher = AutoModelForImageTextToText.from_pretrained(
self.kd_config.teacher_model_id,
torch_dtype=torch.bfloat16,
device_map="auto",
)
self.teacher_processor = AutoProcessor.from_pretrained(
self.kd_config.teacher_model_id
)
# Freeze teacher
for param in self.teacher.parameters():
param.requires_grad = False
self.teacher.eval()
print(f"Teacher loaded: {sum(p.numel() for p in self.teacher.parameters()) / 1e9:.1f}B params")
def compute_sparse_kd_loss(
self,
student_logits: torch.Tensor,
noise_mask: torch.Tensor,
kd_targets: Optional[list[dict[str, Any]]],
) -> torch.Tensor:
"""Compute sparse KL in the student's token space."""
if not kd_targets:
return torch.tensor(0.0, device=student_logits.device)
temperature = self.kd_config.temperature
losses = []
for entry in kd_targets:
batch_idx = int(entry["batch_idx"])
position = int(entry["position"])
if position >= student_logits.shape[1]:
continue
if not bool(noise_mask[batch_idx, position].item()):
continue
candidate_token_ids = torch.tensor(
entry["candidate_token_ids"],
device=student_logits.device,
dtype=torch.long,
)
teacher_probs = torch.tensor(
entry["teacher_probs"],
device=student_logits.device,
dtype=student_logits.dtype,
)
gathered = student_logits[batch_idx, position, candidate_token_ids]
student_log_probs = F.log_softmax(gathered / temperature, dim=-1)
            # KL(teacher ‖ student) over the candidate set; 'sum' over the 1-D
            # candidate distribution gives the exact divergence (teacher_probs
            # are expected to be normalized over the candidates)
            losses.append(
                F.kl_div(student_log_probs, teacher_probs, reduction="sum") * (temperature ** 2)
            )
if not losses:
return torch.tensor(0.0, device=student_logits.device)
return torch.stack(losses).mean()
def forward_with_distillation(
self,
pixel_values: torch.Tensor,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
labels: Optional[torch.Tensor] = None,
kd_targets: Optional[list[dict[str, Any]]] = None,
) -> Dict[str, torch.Tensor]:
"""Forward with diffusion loss plus sparse cached KD targets."""
# Student forward (diffusion loss)
student_outputs = self.forward(
pixel_values=pixel_values,
input_ids=input_ids,
attention_mask=attention_mask,
labels=labels,
)
diffusion_loss = student_outputs['loss']
kd_loss = self.compute_sparse_kd_loss(
student_logits=student_outputs["logits"],
noise_mask=student_outputs["noise_mask"],
kd_targets=kd_targets,
)
# Combined loss
alpha = self.kd_config.alpha_kd
total_loss = (1 - alpha) * diffusion_loss + alpha * kd_loss
return {
'loss': total_loss,
'diffusion_loss': diffusion_loss,
'kd_loss': kd_loss,
'logits': student_outputs['logits'],
'noise_mask': student_outputs['noise_mask'],
't': student_outputs['t'],
}
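
# Shape of one cached KD target consumed by compute_sparse_kd_loss above. The
# concrete values are illustrative; real entries are produced offline by the
# teacher and mapped into the student's token space.
def _sketch_kd_target_entry() -> dict:
    return {
        "batch_idx": 0,                         # sample index within the batch
        "position": 5,                          # text position (only masked ones count)
        "candidate_token_ids": [312, 88, 904],  # candidate ids in the student vocab
        "teacher_probs": [0.7, 0.2, 0.1],       # normalized over the candidates
    }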