"""
ViL-DLM: Vision xLSTM + Diffusion Language Model
Architecture Configuration
"""
from dataclasses import dataclass, field
from typing import List


@dataclass
class ViLEncoderConfig:
"""Vision xLSTM (ViL) encoder configuration"""
vision_backbone: str = "vil2-small"
pretrained: bool = True
img_size: int = 224
patch_size: int = 16
in_channels: int = 3
dim: int = 384 # patch feature dim for vil-small / vil2-small
depth: int = 12 # VisionLSTM2 block-pairs; v1 vil-small internally uses 24
mlstm_dim_mult: int = 2 # mLSTM internal dim = 2 * dim
conv_kernel_size: int = 3 # QK Conv2D kernel
bidirectional: bool = True # alternating scan directions
    dropout: float = 0.0

    @property
    def num_patches(self):
        return (self.img_size // self.patch_size) ** 2  # 196 for 224/16

    @property
    def num_params_approx(self):
        # Very rough estimate: Q/K/V/output projections into the expanded
        # mLSTM dim, plus ~4 * dim^2 for gates and misc; ignores norms,
        # the QK conv, and the patch embedding.
        per_block = 4 * self.dim * (self.mlstm_dim_mult * self.dim) + 4 * self.dim * self.dim
        return self.depth * per_block
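
# Illustrative check of the derived properties (kept as a comment-style
# doctest so importing this config stays side-effect free):
#   >>> ViLEncoderConfig().num_patches   # (224 // 16) ** 2
#   196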
@dataclass
class ProjectorConfig:
"""MLP projector: maps ViL features to LM embedding space"""
vil_dim: int = 384 # ViL-S output dim
lm_dim: int = 1024 # Qwen3-0.6B hidden_size
hidden_mult: int = 2 # projector hidden = lm_dim * hidden_mult
num_layers: int = 2 # 2-layer MLP (LaViDa/LLaDA-V standard)
activation: str = "gelu"
dropout: float = 0.0
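

# A minimal sketch of the 2-layer MLP projector this config describes
# (LaViDa/LLaDA-V-style). Assumes torch and cfg.activation == "gelu";
# the function name and exact layer layout are illustrative, not
# necessarily the repo's actual module.
def build_projector_sketch(cfg: ProjectorConfig):
    import torch.nn as nn  # local import keeps this config importable without torch

    hidden = cfg.lm_dim * cfg.hidden_mult
    layers, in_dim = [], cfg.vil_dim
    for _ in range(cfg.num_layers - 1):
        layers += [nn.Linear(in_dim, hidden), nn.GELU(), nn.Dropout(cfg.dropout)]
        in_dim = hidden
    layers.append(nn.Linear(in_dim, cfg.lm_dim))  # into the LM embedding space
    return nn.Sequential(*layers)
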
@dataclass
class DiffusionConfig:
"""Masked diffusion (MDLM) training configuration"""
    noise_schedule: str = "cosine"  # masking-rate schedule over t in [0, 1]
    mask_token_id: int = 151643  # Qwen3 pad token, reused as the [MASK] token
    num_diffusion_steps: int = 1000  # timestep discretization during training
    inference_steps: int = 128  # denoising steps at sampling time
remasking: str = "low_confidence" # remasking strategy
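

# How noise_schedule is typically realized in masked-diffusion training:
# a masking rate that rises from 0 at t=0 to 1 at t=1. This cosine
# parameterization is an assumption for illustration, not necessarily
# the exact schedule this repo trains with.
def mask_rate_sketch(t: float, schedule: str = "cosine") -> float:
    import math

    if schedule == "cosine":
        return 1.0 - math.cos(0.5 * math.pi * t)  # 0 at t=0, 1 at t=1
    if schedule == "linear":
        return t
    raise ValueError(f"unknown schedule: {schedule}")
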
@dataclass
class DistillationConfig:
"""Knowledge distillation from Gemma 4 E2B teacher"""
teacher_model_id: str = "google/gemma-4-E2B-it"
teacher_quantize: bool = True # 4-bit quantization for memory
temperature: float = 2.0 # KD temperature
alpha_kd: float = 0.5 # weight for KD loss vs diffusion loss
alpha_vision_kd: float = 0.3 # weight for vision feature distillation
kd_top_k: int = 8 # sparse cross-tokenizer candidate set size
kd_positions_per_sample: int = 16
teacher_cache_dir: str = "./vil-dlm-output/teacher-cache"
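

# A hedged sketch of the temperature-scaled KD term the weights above
# control. It assumes torch and assumes student/teacher logits have
# already been gathered over the same kd_top_k candidate set; the
# cross-tokenizer alignment itself is out of scope here.
def kd_loss_sketch(student_logits, teacher_logits, cfg: DistillationConfig):
    import torch.nn.functional as F

    T = cfg.temperature
    log_p_student = F.log_softmax(student_logits / T, dim=-1)
    p_teacher = F.softmax(teacher_logits / T, dim=-1)
    # the usual T^2 factor keeps gradient scale comparable across temperatures
    return cfg.alpha_kd * (T * T) * F.kl_div(log_p_student, p_teacher, reduction="batchmean")
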
@dataclass
class TrainingConfig:
"""Full training configuration"""
# Model
vil_encoder: ViLEncoderConfig = field(default_factory=ViLEncoderConfig)
projector: ProjectorConfig = field(default_factory=ProjectorConfig)
diffusion: DiffusionConfig = field(default_factory=DiffusionConfig)
distillation: DistillationConfig = field(default_factory=DistillationConfig)
# Backbone
diffusion_lm_id: str = "dllm-hub/Qwen3-0.6B-diffusion-mdlm-v0.1"
# Training hyperparams (from dLLM + LLaDA-V + LFM2 recipes)
learning_rate: float = 1e-4
vil_learning_rate: float = 2e-6 # lower LR for vision encoder (LLaDA-V)
projector_learning_rate: float = 1e-3 # higher LR for projector (LLaDA-V Stage 1)
weight_decay: float = 0.05
warmup_ratio: float = 0.1
lr_scheduler: str = "cosine"
max_seq_len: int = 1024
per_device_train_batch_size: int = 4
gradient_accumulation_steps: int = 8 # effective batch = 32
num_epochs: int = 3
bf16: bool = True
gradient_checkpointing: bool = True
# Data
    pretrain_dataset: str = "liuhaotian/LLaVA-Pretrain"  # Stage 1: 558K image-caption pairs
finetune_dataset: str = "HuggingFaceM4/the_cauldron" # Stage 2: rich multimodal
finetune_dataset_configs: List[str] = field(default_factory=lambda: [
"ai2d",
"vqav2",
"aokvqa",
"textvqa",
"docvqa",
"chartqa",
"textcaps",
"screen2words",
])
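    # Each entry above names a subset of The Cauldron; loading is per-config,
    # e.g. (sketch using the standard datasets API):
    #   from datasets import load_dataset
    #   ds = load_dataset("HuggingFaceM4/the_cauldron", "aokvqa", split="train")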
# Output
output_dir: str = "./vil-dlm-output"
hub_model_id: str = "omar-ah/ViL-DLM-0.6B"
push_to_hub: bool = False
# Stages
    stage: str = "1"  # one of "1", "2", "3a", "3b"


def get_config(stage: str = "1") -> TrainingConfig:
config = TrainingConfig()
config.stage = stage
if stage == "1":
# Stage 1: Train projector only (ViL frozen, LM frozen)
config.learning_rate = 1e-3
config.num_epochs = 1
config.per_device_train_batch_size = 8
config.gradient_accumulation_steps = 4
elif stage == "2":
# Stage 2: Full model finetune (ViL + projector + LM)
config.learning_rate = 1e-5
config.vil_learning_rate = 2e-6
config.projector_learning_rate = 1e-5
config.num_epochs = 3
    elif stage in {"3a", "3b"}:
        # Stage 3: sparse cross-tokenizer distillation from the Gemma 3n teacher
        config.learning_rate = 1e-5
        config.num_epochs = 2
        config.distillation.alpha_kd = 0.5
    else:
        raise ValueError(f"unknown stage: {stage!r}; expected '1', '2', '3a', or '3b'")
    return config
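

if __name__ == "__main__":
    # Smoke test: print the per-stage learning-rate/epoch overrides.
    for s in ("1", "2", "3a"):
        cfg = get_config(s)
        print(f"stage {s}: lr={cfg.learning_rate}, epochs={cfg.num_epochs}")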