| | from __future__ import annotations |
| | from typing import Optional |
| |
|
| | from transformers import PretrainedConfig |
| |
|
| |
|
class RecursiveMLMConfig(PretrainedConfig):
    """
    Configuration for RecursiveMaskedLM.

    Stores the base MLM config plus recursive refinement parameters.

    Core Parameters
    ---------------
    - base_model_config: Serialized (dict) config of the wrapped base MLM, or None
    - num_recursions: Number of refinement iterations per forward pass
    - normalization: Normalization mode applied during refinement
    - loss_weight: Per-iteration loss weighting scheme (e.g. "linear", "last_1")
    - mask_token_id: Token id used for masked positions (None = take from tokenizer)
    - temperature: Base softmax temperature for logit-to-probability conversion
    - gradient_steps: If set, number of final iterations that receive gradients

    Convergence Schedule System
    ---------------------------
    The convergence schedule controls WHEN each position is allowed to converge
    to a confident prediction during iterative refinement.

    Schedule types:
    - "linear": All positions converge at the same rate (iteration-based only)
    - "causal": Early positions converge first, late positions last
      (causal_strength scales how strongly position order is enforced)

    Effects (mechanisms to enforce the schedule):
    - temperature_max: Raise temperature for positions not yet allowed to converge
    - entropy_target_max: Force exact entropy via bisection search (two-sided, recommended)
    - entropy_floor_max: Force minimum entropy (one-sided, only raises)
    - smear_sigma_max: Spread probability across neighboring positions
    - noise_std_max: Add Gaussian noise to logits
    - iteration_rope_dim_fraction: Apply rotary embedding based on iteration progress

    Soft Embedding Methods
    ----------------------
    Controls how logits are converted to soft embeddings for the next iteration:
    - "softmax": Standard softmax normalization (default). Creates sparse, probabilistic
      mixing but can cause gradient bottlenecks through the softmax Jacobian.
    - "l2_normalize": L2 normalize logits before mixing with embeddings. Removes the
      softmax bottleneck for smoother gradients through long recursion chains.
    - "none": No normalization - use raw logits directly. Warning: this can cause
      scale explosion without additional mechanisms like EMA accumulation.

    - soft_embedding_ema_step: Controls EMA blending with previous soft embeddings.
      1.0 (default) = full update (no EMA), 0.1 = slow update (90% previous + 10% new).
      Formula: new = (1 - ema_step) * prev + ema_step * current

    Recursion Checkpointing
    -----------------------
    Controls gradient flow through the entire recursion chain for memory-efficient training.

    Parameters:
    - use_recursion_checkpointing: Enable gradient checkpointing for iterations
    - loss_weight: Use "last_1" for final-iteration-only loss (learns convergence behavior)

    Flow Matching (CFM-inspired)
    ----------------------------
    Replaces the old temperature-based self-distillation with a Continuous Flow Matching
    framework. Training inputs are interpolated on the probability simplex between random
    noise and the target one-hot, distillation gives the student a noisier (earlier-time)
    version of the same interpolation path, and inference uses a flow map update rule.

    Parameters:
    - flow_matching_enabled: Enable the flow matching framework
    - flow_matching_lambda: Weight of distillation KL loss relative to CE loss
    - flow_matching_t_distribution: How to sample time t ("logit_normal" or "uniform")
    - flow_matching_t_logit_mean: Mean of logit-normal distribution (-0.4 biases toward noisy)
    - flow_matching_t_logit_std: Std of logit-normal distribution
    - flow_matching_t_min: Minimum time value (clamp)
    - flow_matching_t_max: Maximum time value (clamp)
    - flow_matching_noise_scale: Magnitude of the random-noise endpoint of the
      interpolation path (NOTE(review): semantics inferred from the name and the
      surrounding CFM description — confirm against the model implementation)
    - flow_matching_mask_scale: If True, scale mask_emb by (1-t); if False, binary mask signal

    Time levels are sampled independently per masked token. At t=0 the input is pure noise,
    at t=1 it is the clean target embedding.

    Self-Distillation (legacy, temperature-based)
    ----------------------------------------------
    Kept for backward compatibility. Ignored when flow_matching_enabled=True.

    Parameters:
    - self_distillation_enabled: Enable the self-distillation KL loss
    - self_distillation_lambda: Weight of distillation loss relative to CE loss
    - self_distillation_temperature_min: Minimum degradation temperature
    - self_distillation_temperature_max: Maximum degradation temperature
    - self_distillation_temperature_distribution: How to sample temperature
    - self_distillation_teacher: Which logits to use as teacher ("first" or "last")
    """
    model_type = "recursive-mlm"

    def __init__(
        self,
        base_model_config: Optional[dict] = None,
        num_recursions: int = 8,
        normalization: str = "softmax",
        loss_weight: str = "linear",
        mask_token_id: Optional[int] = None,
        temperature: float = 1.0,
        gradient_steps: Optional[int] = None,
        # Convergence schedule
        schedule: str = "linear",
        causal_strength: float = 1.0,
        # Schedule-enforcement effects (0.0 disables each effect)
        temperature_max: float = 0.0,
        entropy_target_max: float = 0.0,
        entropy_floor_max: float = 0.0,
        smear_sigma_max: float = 0.0,
        noise_std_max: float = 0.0,
        iteration_rope_dim_fraction: float = 0.0,
        use_recursion_checkpointing: bool = True,
        # Soft embedding feedback
        soft_embedding_method: str = "softmax",
        soft_embedding_ema_step: float = 1.0,
        # Flow matching (CFM-inspired)
        flow_matching_enabled: bool = False,
        flow_matching_lambda: float = 0.5,
        flow_matching_t_distribution: str = "logit_normal",
        flow_matching_t_logit_mean: float = -0.4,
        flow_matching_t_logit_std: float = 1.0,
        flow_matching_t_min: float = 0.01,
        flow_matching_t_max: float = 0.99,
        flow_matching_noise_scale: float = 2.0,
        flow_matching_mask_scale: bool = False,
        # Legacy temperature-based self-distillation
        self_distillation_enabled: bool = False,
        self_distillation_lambda: float = 0.5,
        self_distillation_temperature_min: float = 1.5,
        self_distillation_temperature_max: float = 10.0,
        self_distillation_temperature_distribution: str = "log_uniform",
        self_distillation_teacher: str = "first",
        **kwargs,
    ):
        # Forward unrecognized kwargs to PretrainedConfig (handles standard HF
        # fields such as vocab_size overrides, name_or_path, etc.).
        super().__init__(**kwargs)
        self.base_model_config = base_model_config
        self.num_recursions = num_recursions
        self.normalization = normalization
        self.loss_weight = loss_weight
        self.mask_token_id = mask_token_id
        self.temperature = temperature
        self.gradient_steps = gradient_steps

        self.schedule = schedule
        self.causal_strength = causal_strength

        self.temperature_max = temperature_max
        self.entropy_target_max = entropy_target_max
        self.entropy_floor_max = entropy_floor_max
        self.smear_sigma_max = smear_sigma_max
        self.noise_std_max = noise_std_max
        self.iteration_rope_dim_fraction = iteration_rope_dim_fraction

        self.use_recursion_checkpointing = use_recursion_checkpointing

        self.soft_embedding_method = soft_embedding_method
        self.soft_embedding_ema_step = soft_embedding_ema_step

        self.flow_matching_enabled = flow_matching_enabled
        self.flow_matching_lambda = flow_matching_lambda
        self.flow_matching_t_distribution = flow_matching_t_distribution
        self.flow_matching_t_logit_mean = flow_matching_t_logit_mean
        self.flow_matching_t_logit_std = flow_matching_t_logit_std
        self.flow_matching_t_min = flow_matching_t_min
        self.flow_matching_t_max = flow_matching_t_max
        self.flow_matching_noise_scale = flow_matching_noise_scale
        self.flow_matching_mask_scale = flow_matching_mask_scale

        self.self_distillation_enabled = self_distillation_enabled
        self.self_distillation_lambda = self_distillation_lambda
        self.self_distillation_temperature_min = self_distillation_temperature_min
        self.self_distillation_temperature_max = self_distillation_temperature_max
        self.self_distillation_temperature_distribution = self_distillation_temperature_distribution
        self.self_distillation_teacher = self_distillation_teacher

    @classmethod
    def from_base_model_config(
        cls,
        base_config: PretrainedConfig,
        **kwargs,
    ) -> "RecursiveMLMConfig":
        """Create a RecursiveMLMConfig from a base MLM's config.

        Args:
            base_config: The base model's PretrainedConfig; serialized via
                ``to_dict()`` and stored as ``base_model_config``.
            **kwargs: Any RecursiveMLMConfig constructor argument.

        Returns:
            A new RecursiveMLMConfig wrapping the base config.
        """
        return cls(
            base_model_config=base_config.to_dict(),
            **kwargs,
        )
| |
|