| """ |
| ViL-DLM: Vision xLSTM Diffusion Language Model |
| |
| Architecture: |
| [Image] → ViL Encoder → MLP Projector → [Visual Tokens] |
| [Visual Tokens] + [Text Tokens (masked)] → Bidirectional Diffusion LM → Denoised Tokens |
| |
| Components: |
| 1. ViL (Vision xLSTM) - custom vision encoder with linear complexity |
| 2. MLP Projector - maps ViL features to LM embedding space |
| 3. Qwen3-0.6B Diffusion LM - bidirectional masked diffusion backbone (from dLLM) |
| |
| Training: |
| Stage 1: Train projector only (ViL frozen, LM frozen) on LLaVA-Pretrain |
| Stage 2: Full finetune on multimodal instruction data |
| Stage 3: + Knowledge distillation from Gemma 4 E2B teacher |
| |
| Diffusion Process (MDLM): |
| Forward: progressively mask tokens with [MASK] according to cosine schedule |
| Reverse: iteratively predict masked tokens using bidirectional attention |
| Loss: weighted cross-entropy on masked positions |
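
Worked example (illustrative): at t = 1/3 the keep probability is
cos(pi/6) ≈ 0.866, so roughly 13% of a sample's text tokens are replaced
by [MASK] for that training step.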
| """ |
|
|
| import math |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from typing import Optional, Dict, Any, Tuple |
| from transformers import AutoModelForImageTextToText, AutoModelForMaskedLM, AutoTokenizer |
|
|
| from model_config import ViLEncoderConfig, ProjectorConfig, TrainingConfig |
| from vision_xlstm import VisionXLSTM, VisionProjector |
|
|
|
|
class MDLMScheduler:
    """
    Masked Diffusion Language Model noise scheduler.
    Cosine schedule for the masking probability.
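
    Example (minimal sketch; toy ids with the default Qwen-style mask id):
        scheduler = MDLMScheduler(num_steps=1000, mask_token_id=151643)
        ids = torch.randint(0, 1000, (2, 16))
        t = scheduler.sample_timesteps(2, ids.device)
        noisy_ids, mask = scheduler.add_noise(ids, t)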
| """ |
| def __init__(self, num_steps=1000, mask_token_id=151643): |
| self.num_steps = num_steps |
| self.mask_token_id = mask_token_id |
| |
    def get_mask_ratio(self, t):
        """Cosine masking schedule: fraction of tokens to mask at timestep t."""
        # The keep probability is cos(pi * t / 2), so the mask probability
        # below rises from 0 at t = 0 to 1 at t = 1.
        return 1.0 - torch.cos(t * math.pi / 2)

    def add_noise(self, input_ids, t):
        """
        Forward diffusion: mask tokens according to timestep t.

        Args:
            input_ids: [B, T] clean token ids
            t: [B] timesteps in [0, 1]
        Returns:
            noisy_ids: [B, T] with some tokens replaced by the mask token
            mask: [B, T] boolean - True where tokens are masked
        """
        B, T = input_ids.shape
        device = input_ids.device

        # Per-sample mask probability, broadcast across the sequence.
        mask_ratio = self.get_mask_ratio(t)
        mask_ratio = mask_ratio.unsqueeze(1).expand(B, T)

        # Each token is masked independently (Bernoulli with prob mask_ratio).
        rand = torch.rand(B, T, device=device)
        mask = rand < mask_ratio

        # Replace the selected positions with the [MASK] token id.
        noisy_ids = input_ids.clone()
        noisy_ids[mask] = self.mask_token_id

        return noisy_ids, mask

    def sample_timesteps(self, batch_size, device):
        """Sample random timesteps for training"""
        return torch.rand(batch_size, device=device)


class ViLDLM(nn.Module):
    """
    Vision xLSTM Diffusion Language Model.

    Combines:
    - ViL encoder for image understanding
    - MLP projector for modality alignment
    - Qwen3-0.6B diffusion backbone for masked denoising
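
    Usage (minimal sketch; assumes a populated TrainingConfig from model_config):
        model = ViLDLM(config).load_diffusion_lm()
        model.freeze_vision_encoder()   # Stage 1: train the projector only
        model.freeze_lm()
        out = model(pixel_values, input_ids, attention_mask)
        out['loss'].backward()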
| """ |
| |
| def __init__(self, config: TrainingConfig): |
| super().__init__() |
| self.config = config |
| |
| |
| self.vision_encoder = VisionXLSTM(config.vil_encoder) |
| |
| |
| self.projector = VisionProjector(config.projector) |
| |
| |
| self.lm = None |
| self.tokenizer = None |
| |
| |
| self.scheduler = MDLMScheduler( |
| num_steps=config.diffusion.num_diffusion_steps, |
| mask_token_id=config.diffusion.mask_token_id |
| ) |
| |
| |
| |
| |
    def load_diffusion_lm(self, local_path: Optional[str] = None):
        """Load the pretrained diffusion LM backbone and its tokenizer"""
        model_path = local_path or self.config.diffusion_lm_id
        print(f"Loading diffusion LM from {model_path}...")
        self.lm = AutoModelForMaskedLM.from_pretrained(
            model_path,
            trust_remote_code=True,
            torch_dtype=torch.bfloat16 if self.config.bf16 else torch.float32,
        )
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_path,
            trust_remote_code=True,
        )
        print(f"Loaded diffusion LM: {sum(p.numel() for p in self.lm.parameters()) / 1e6:.1f}M params")
        return self

    def get_input_embeddings(self):
        """Get the LM's input embedding layer"""
        # Use the standard HF accessor instead of reaching into
        # self.lm.model.embed_tokens, which depends on the backbone's layout.
        return self.lm.get_input_embeddings()

    def prepare_multimodal_inputs(
        self,
        pixel_values: torch.Tensor,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        image_token_id: Optional[int] = None,  # reserved; currently unused
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Prepare multimodal input embeddings by:
        1. Encoding the image with ViL
        2. Projecting to the LM embedding space
        3. Concatenating [visual_tokens, text_tokens]

        Returns:
            inputs_embeds: [B, T_vis + T_text, D]
            full_attention_mask: [B, T_vis + T_text]
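
        Example shapes (illustrative; assumes 196 visual tokens and LM hidden
        size D): pixel_values [2, 3, 224, 224] → visual_tokens [2, 196, D];
        with T_text = 32, inputs_embeds is [2, 228, D].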
| """ |
| B = pixel_values.shape[0] |
| |
| |
| with torch.set_grad_enabled(self.training): |
| vision_features = self.vision_encoder.forward_features(pixel_values) |
| |
| |
| |
| visual_tokens = self.projector(vision_features) |
| |
| |
| |
| text_embeds = self.get_input_embeddings()(input_ids) |
| |
| |
| |
| target_dtype = text_embeds.dtype |
| visual_tokens = visual_tokens.to(dtype=target_dtype) |
| |
| |
| inputs_embeds = torch.cat([visual_tokens, text_embeds], dim=1) |
| |
| |
| num_vis = visual_tokens.shape[1] |
| vis_mask = torch.ones(B, num_vis, device=attention_mask.device, dtype=attention_mask.dtype) |
| full_attention_mask = torch.cat([vis_mask, attention_mask], dim=1) |
| |
| return inputs_embeds, full_attention_mask |
| |
    def forward(
        self,
        pixel_values: torch.Tensor,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        labels: Optional[torch.Tensor] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        Training forward pass with MDLM diffusion loss.

        1. Sample random timestep t
        2. Mask tokens according to t (forward diffusion)
        3. Encode image + masked text through model
        4. Compute cross-entropy loss on masked positions
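
        Worked example (illustrative): if 64 of 256 text tokens are masked at
        the sampled t, the loss is the mean cross-entropy over exactly those
        64 predictions; unmasked and ignore-index positions contribute nothing.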
| """ |
| B, T = input_ids.shape |
| device = input_ids.device |
| |
| if labels is None: |
| labels = input_ids.clone() |
| |
| |
| t = self.scheduler.sample_timesteps(B, device) |
| |
| |
| noisy_ids, noise_mask = self.scheduler.add_noise(input_ids, t) |
| |
| |
| inputs_embeds, full_attention_mask = self.prepare_multimodal_inputs( |
| pixel_values=pixel_values, |
| input_ids=noisy_ids, |
| attention_mask=attention_mask, |
| ) |
| |
| |
| outputs = self.lm( |
| inputs_embeds=inputs_embeds, |
| attention_mask=full_attention_mask, |
| ) |
| |
| |
| num_vis = self.config.vil_encoder.num_patches |
| text_logits = outputs.logits[:, num_vis:, :] |
| |
| |
| |
| loss_mask = noise_mask.float() |
| |
| if loss_mask.sum() == 0: |
| |
| loss = torch.tensor(0.0, device=device, requires_grad=True) |
| else: |
| |
| logits_flat = text_logits.reshape(-1, text_logits.shape[-1]) |
| labels_flat = labels.reshape(-1) |
| loss_flat = F.cross_entropy(logits_flat, labels_flat, reduction='none') |
| loss_flat = loss_flat.reshape(B, T) |
| |
| |
| loss = (loss_flat * loss_mask).sum() / loss_mask.sum() |
| |
| return { |
| 'loss': loss, |
| 'logits': text_logits, |
| 'noise_mask': noise_mask, |
| 't': t, |
| } |
| |
    def freeze_vision_encoder(self):
        """Freeze ViL encoder (Stage 1)"""
        for param in self.vision_encoder.parameters():
            param.requires_grad = False

    def unfreeze_vision_encoder(self):
        """Unfreeze ViL encoder (Stage 2+)"""
        for param in self.vision_encoder.parameters():
            param.requires_grad = True

    def freeze_lm(self):
        """Freeze diffusion LM backbone (Stage 1)"""
        for param in self.lm.parameters():
            param.requires_grad = False

    def unfreeze_lm(self):
        """Unfreeze diffusion LM backbone (Stage 2+)"""
        for param in self.lm.parameters():
            param.requires_grad = True

    def get_parameter_groups(self):
        """Get parameter groups with different learning rates"""
        groups = [
            {
                'params': [p for p in self.vision_encoder.parameters() if p.requires_grad],
                'lr': self.config.vil_learning_rate,
                'name': 'vision_encoder'
            },
            {
                'params': [p for p in self.projector.parameters() if p.requires_grad],
                'lr': self.config.projector_learning_rate,
                'name': 'projector'
            },
            {
                'params': [p for p in self.lm.parameters() if p.requires_grad],
                'lr': self.config.learning_rate,
                'name': 'diffusion_lm'
            },
        ]
        # Drop empty groups (e.g. fully frozen components).
        return [g for g in groups if len(g['params']) > 0]

    @torch.no_grad()
    def generate(
        self,
        pixel_values: torch.Tensor,
        prompt_ids: Optional[torch.Tensor] = None,
        max_new_tokens: int = 128,
        num_steps: int = 64,
        temperature: float = 1.0,
    ) -> torch.Tensor:
        """
        Generate text from image using iterative masked diffusion denoising.

        Steps:
        1. Start with all-masked output tokens
        2. At each step, predict all tokens, unmask most confident ones
        3. Repeat until all tokens are unmasked
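
        Example (illustrative): with the defaults max_new_tokens=128 and
        num_steps=64, tokens_per_step = 128 // 64 = 2 positions are committed
        per denoising step.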
| """ |
| self.eval() |
| B = pixel_values.shape[0] |
| device = pixel_values.device |
| |
| |
| output_ids = torch.full( |
| (B, max_new_tokens), |
| self.scheduler.mask_token_id, |
| device=device, dtype=torch.long |
| ) |
| |
| |
| if prompt_ids is not None: |
| full_ids = torch.cat([prompt_ids, output_ids], dim=1) |
| prompt_len = prompt_ids.shape[1] |
| else: |
| full_ids = output_ids |
| prompt_len = 0 |
| |
| T_total = full_ids.shape[1] |
| attention_mask = torch.ones(B, T_total, device=device) |
| |
| |
| tokens_per_step = max(1, max_new_tokens // num_steps) |
| |
| for step in range(num_steps): |
| |
| inputs_embeds, full_attn = self.prepare_multimodal_inputs( |
| pixel_values, full_ids, attention_mask |
| ) |
| outputs = self.lm(inputs_embeds=inputs_embeds, attention_mask=full_attn) |
| |
| num_vis = self.config.vil_encoder.num_patches |
| logits = outputs.logits[:, num_vis:, :] |
| |
| |
| gen_logits = logits[:, prompt_len:, :] |
| gen_ids = full_ids[:, prompt_len:] |
| |
| |
| is_masked = (gen_ids == self.scheduler.mask_token_id) |
| |
| if not is_masked.any(): |
| break |
| |
| |
| probs = F.softmax(gen_logits / temperature, dim=-1) |
| predicted = probs.argmax(dim=-1) |
| |
| |
| confidence = probs.max(dim=-1).values |
| confidence[~is_masked] = float('inf') |
| |
| |
| num_to_unmask = min(tokens_per_step, is_masked.sum().item()) |
| if num_to_unmask > 0: |
| |
| _, topk_idx = confidence.topk(num_to_unmask, dim=-1, largest=True) |
| |
| |
| for b in range(B): |
| for idx in topk_idx[b]: |
| if is_masked[b, idx]: |
| full_ids[b, prompt_len + idx] = predicted[b, idx] |
| |
| return full_ids[:, prompt_len:] |
| |
    def count_parameters(self):
        """Count parameters by component"""
        vil_params = sum(p.numel() for p in self.vision_encoder.parameters())
        proj_params = sum(p.numel() for p in self.projector.parameters())
        lm_params = sum(p.numel() for p in self.lm.parameters()) if self.lm is not None else 0

        total = vil_params + proj_params + lm_params
        trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)

        return {
            'vision_encoder': vil_params,
            'projector': proj_params,
            'diffusion_lm': lm_params,
            'total': total,
            'trainable': trainable,
        }


class ViLDLMWithDistillation(ViLDLM):
    """
    ViL-DLM with knowledge distillation from a Gemma 4 E2B teacher.

    Stage 3 uses sparse cross-tokenizer KD targets that are prepared offline
    with the teacher and cached in the student's token space.
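
    Each cached kd_targets entry is expected to look like (sketch):
        {
            "batch_idx": 0,
            "position": 17,                      # text position in the student sequence
            "candidate_token_ids": [312, 1049],  # student-vocab candidate ids
            "teacher_probs": [0.8, 0.2],         # teacher mass over the candidates
        }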
| """ |
| |
| def __init__(self, config: TrainingConfig): |
| super().__init__(config) |
| self.teacher = None |
| self.teacher_processor = None |
| self.kd_config = config.distillation |
| |
    def load_teacher(self):
        """Load Gemma 4 E2B as teacher (optionally 4-bit quantized for memory)"""
        from transformers import AutoProcessor

        print(f"Loading teacher: {self.kd_config.teacher_model_id}...")

        if self.kd_config.teacher_quantize:
            from transformers import BitsAndBytesConfig
            bnb_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.bfloat16,
                bnb_4bit_quant_type="nf4",
            )
            self.teacher = AutoModelForImageTextToText.from_pretrained(
                self.kd_config.teacher_model_id,
                quantization_config=bnb_config,
                device_map="auto",
            )
        else:
            self.teacher = AutoModelForImageTextToText.from_pretrained(
                self.kd_config.teacher_model_id,
                torch_dtype=torch.bfloat16,
                device_map="auto",
            )

        self.teacher_processor = AutoProcessor.from_pretrained(
            self.kd_config.teacher_model_id
        )

        # The teacher is inference-only: freeze it and switch to eval mode.
        for param in self.teacher.parameters():
            param.requires_grad = False
        self.teacher.eval()

        print(f"Teacher loaded: {sum(p.numel() for p in self.teacher.parameters()) / 1e9:.1f}B params")

    def compute_sparse_kd_loss(
        self,
        student_logits: torch.Tensor,
        noise_mask: torch.Tensor,
        kd_targets: Optional[list[dict[str, Any]]],
    ) -> torch.Tensor:
        """
        Compute sparse KL divergence in the student's token space.

        The student distribution is renormalized over each entry's candidate
        set, so the KL is taken between the teacher's cached candidate
        probabilities and the student's restricted softmax.
        """
        if not kd_targets:
            return torch.tensor(0.0, device=student_logits.device)

        temperature = self.kd_config.temperature
        losses = []
        for entry in kd_targets:
            batch_idx = int(entry["batch_idx"])
            position = int(entry["position"])
            if position >= student_logits.shape[1]:
                continue
            # Distill only on positions that were masked this step.
            if not bool(noise_mask[batch_idx, position].item()):
                continue
            candidate_token_ids = torch.tensor(
                entry["candidate_token_ids"],
                device=student_logits.device,
                dtype=torch.long,
            )
            teacher_probs = torch.tensor(
                entry["teacher_probs"],
                device=student_logits.device,
                dtype=student_logits.dtype,
            )
            gathered = student_logits[batch_idx, position, candidate_token_ids]
            student_log_probs = F.log_softmax(gathered / temperature, dim=-1)
            # Sum (not batchmean) over the candidate set: each entry is a
            # single distribution, so KL = sum p_t * (log p_t - log q_s).
            losses.append(
                F.kl_div(student_log_probs, teacher_probs, reduction="sum") * (temperature ** 2)
            )

        if not losses:
            return torch.tensor(0.0, device=student_logits.device)
        return torch.stack(losses).mean()

    def forward_with_distillation(
        self,
        pixel_values: torch.Tensor,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        labels: Optional[torch.Tensor] = None,
        kd_targets: Optional[list[dict[str, Any]]] = None,
    ) -> Dict[str, torch.Tensor]:
        """Forward with diffusion loss plus sparse cached KD targets."""
        # Standard masked-diffusion forward pass for the student.
        student_outputs = self.forward(
            pixel_values=pixel_values,
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
        )

        diffusion_loss = student_outputs['loss']
        kd_loss = self.compute_sparse_kd_loss(
            student_logits=student_outputs["logits"],
            noise_mask=student_outputs["noise_mask"],
            kd_targets=kd_targets,
        )

        # Blend the two objectives: alpha_kd weights the distillation term.
        alpha = self.kd_config.alpha_kd
        total_loss = (1 - alpha) * diffusion_loss + alpha * kd_loss

        return {
            'loss': total_loss,
            'diffusion_loss': diffusion_loss,
            'kd_loss': kd_loss,
            'logits': student_outputs['logits'],
            'noise_mask': student_outputs['noise_mask'],
            't': student_outputs['t'],
        }