| """ |
| ViL-DLM: Vision xLSTM + Diffusion Language Model |
| Architecture Configuration |
| """ |
|
|
| from dataclasses import dataclass, field |
| from typing import Optional, List |
|
|
|
|
@dataclass
class ViLEncoderConfig:
    """Configuration for the Vision xLSTM (ViL) image encoder backbone."""
    vision_backbone: str = "vil2-small"  # backbone identifier string
    pretrained: bool = True              # load pretrained backbone weights
    img_size: int = 224                  # square input resolution (pixels)
    patch_size: int = 16                 # square patch edge (pixels)
    in_channels: int = 3                 # input image channels (RGB)
    dim: int = 384                       # token embedding dimension
    depth: int = 12                      # number of mLSTM blocks
    mlstm_dim_mult: int = 2              # hidden-dim multiplier inside each block
    conv_kernel_size: int = 3            # kernel size of the per-block conv
    bidirectional: bool = True           # scan the patch sequence in both directions
    dropout: float = 0.0                 # dropout probability

    @property
    def num_patches(self):
        """Number of non-overlapping patches in the square patch grid."""
        grid_side = self.img_size // self.patch_size
        return grid_side * grid_side

    @property
    def num_params_approx(self):
        """Rough parameter-count estimate: depth * (gate/input projections + output projections)."""
        hidden = self.mlstm_dim_mult * self.dim
        per_block = 4 * self.dim * hidden + 4 * self.dim * self.dim
        return self.depth * per_block
|
|
|
|
@dataclass
class ProjectorConfig:
    """MLP projector: maps ViL features to LM embedding space"""
    vil_dim: int = 384        # input width: ViL encoder feature dimension
    lm_dim: int = 1024        # output width: language-model embedding dimension
    hidden_mult: int = 2      # hidden layer width multiplier (relative to lm_dim, presumably -- TODO confirm in model code)
    num_layers: int = 2       # number of linear layers in the MLP
    activation: str = "gelu"  # activation function between layers
    dropout: float = 0.0      # dropout probability inside the projector
|
|
|
|
@dataclass
class DiffusionConfig:
    """Masked diffusion (MDLM) training configuration"""
    noise_schedule: str = "cosine"       # masking-rate schedule over diffusion time
    mask_token_id: int = 151643          # token id used as [MASK]; NOTE(review): 151643 looks like a Qwen special-token id -- confirm against the LM's tokenizer
    num_diffusion_steps: int = 1000      # discretization of the forward process during training
    inference_steps: int = 128           # denoising steps used at generation time
    remasking: str = "low_confidence"    # strategy for re-masking tokens between inference steps
|
|
|
|
@dataclass
class DistillationConfig:
    """Knowledge distillation from Gemma 4 E2B teacher"""
    teacher_model_id: str = "google/gemma-4-E2B-it"  # HF hub id of the teacher model
    teacher_quantize: bool = True                    # load the teacher quantized to save memory
    temperature: float = 2.0                         # softmax temperature for KD logits
    alpha_kd: float = 0.5                            # weight of the KD loss term
    alpha_vision_kd: float = 0.3                     # weight of the vision-feature KD term
    kd_top_k: int = 8                                # number of teacher logits kept per position (top-k KD)
    kd_positions_per_sample: int = 16                # number of sequence positions distilled per sample
    teacher_cache_dir: str = "./vil-dlm-output/teacher-cache"  # on-disk cache for teacher outputs
|
|
|
|
@dataclass
class TrainingConfig:
    """Full training configuration"""

    # Sub-configurations; default_factory gives each TrainingConfig its own instances.
    vil_encoder: ViLEncoderConfig = field(default_factory=ViLEncoderConfig)
    projector: ProjectorConfig = field(default_factory=ProjectorConfig)
    diffusion: DiffusionConfig = field(default_factory=DiffusionConfig)
    distillation: DistillationConfig = field(default_factory=DistillationConfig)


    # HF hub id of the diffusion language model backbone.
    diffusion_lm_id: str = "dllm-hub/Qwen3-0.6B-diffusion-mdlm-v0.1"


    # Optimizer: separate learning rates per component (encoder / projector / rest).
    learning_rate: float = 1e-4
    vil_learning_rate: float = 2e-6
    projector_learning_rate: float = 1e-3
    weight_decay: float = 0.05
    warmup_ratio: float = 0.1
    lr_scheduler: str = "cosine"

    # Batching; effective batch = per_device_train_batch_size * gradient_accumulation_steps.
    max_seq_len: int = 1024
    per_device_train_batch_size: int = 4
    gradient_accumulation_steps: int = 8
    num_epochs: int = 3

    # Mixed precision / memory savings.
    bf16: bool = True
    gradient_checkpointing: bool = True


    # Datasets: pretrain on LLaVA image-caption pairs, finetune on Cauldron subsets.
    pretrain_dataset: str = "liuhaotian/LLaVA-Pretrain"
    finetune_dataset: str = "HuggingFaceM4/the_cauldron"
    finetune_dataset_configs: List[str] = field(default_factory=lambda: [
        "ai2d",
        "vqav2",
        "aokvqa",
        "textvqa",
        "docvqa",
        "chartqa",
        "textcaps",
        "screen2words",
    ])


    # Output / Hub publishing.
    output_dir: str = "./vil-dlm-output"
    hub_model_id: str = "omar-ah/ViL-DLM-0.6B"
    push_to_hub: bool = False


    # Training stage selector; see get_config() for the known values.
    stage: str = "1"
|
|
|
|
def get_config(stage: str = "1") -> TrainingConfig:
    """Build a TrainingConfig with stage-specific hyperparameter overrides.

    Stages (semantics inferred from the overrides -- confirm against the trainer):
        "1"        -- high LR, 1 epoch, larger batch: presumably projector warm-up.
        "2"        -- low LR across all components: presumably full finetuning.
        "3a"/"3b"  -- low LR, 2 epochs, KD weight pinned: distillation stages.

    Args:
        stage: One of "1", "2", "3a", "3b".

    Returns:
        A fresh TrainingConfig with ``config.stage`` set and overrides applied.

    Raises:
        ValueError: If *stage* is not a known stage (previously this silently
            returned the default config, masking typos like "3").
    """
    known_stages = {"1", "2", "3a", "3b"}
    if stage not in known_stages:
        raise ValueError(
            f"Unknown training stage {stage!r}; expected one of {sorted(known_stages)}"
        )

    config = TrainingConfig()
    config.stage = stage

    if stage == "1":
        # Short, high-LR run with a bigger effective batch.
        config.learning_rate = 1e-3
        config.num_epochs = 1
        config.per_device_train_batch_size = 8
        config.gradient_accumulation_steps = 4
    elif stage == "2":
        # Gentle LRs everywhere for end-to-end finetuning.
        config.learning_rate = 1e-5
        config.vil_learning_rate = 2e-6
        config.projector_learning_rate = 1e-5
        config.num_epochs = 3
    else:  # stage in {"3a", "3b"}
        # Distillation stages: keep LR low, balance task loss vs. KD loss equally.
        config.learning_rate = 1e-5
        config.num_epochs = 2
        config.distillation.alpha_kd = 0.5

    return config
|
|