| """ |
| Training Configuration Classes |
| |
| Contains dataclasses for LoRA and training configurations. |
| """ |
|
|
from dataclasses import asdict, dataclass, field
from typing import List, Optional
|
|
|
|
| @dataclass |
| class LoRAConfig: |
| """Configuration for LoRA (Low-Rank Adaptation) training. |
| |
| Attributes: |
| r: LoRA rank (dimension of low-rank matrices) |
| alpha: LoRA scaling factor (alpha/r determines the scaling) |
| dropout: Dropout probability for LoRA layers |
| target_modules: List of module names to apply LoRA to |
| bias: Whether to train bias parameters ("none", "all", or "lora_only") |
| """ |
| r: int = 8 |
| alpha: int = 16 |
| dropout: float = 0.1 |
| target_modules: List[str] = field(default_factory=lambda: [ |
| "q_proj", "k_proj", "v_proj", "o_proj" |
| ]) |
| bias: str = "none" |
| |
| def to_dict(self): |
| """Convert to dictionary for PEFT config.""" |
| return { |
| "r": self.r, |
| "lora_alpha": self.alpha, |
| "lora_dropout": self.dropout, |
| "target_modules": self.target_modules, |
| "bias": self.bias, |
| } |
|
|
|
|
| @dataclass |
| class TrainingConfig: |
| """Configuration for LoRA training process. |
| |
| Training uses: |
| - BFloat16 precision (only supported precision) |
| - Discrete timesteps from turbo shift=3.0 schedule (8 steps) |
| - Randomly samples one of 8 timesteps per training step: |
| [1.0, 0.9545, 0.9, 0.8333, 0.75, 0.6429, 0.5, 0.3] |
| |
| Attributes: |
| shift: Timestep shift factor (fixed at 3.0 for turbo model) |
| num_inference_steps: Number of inference steps (fixed at 8 for turbo) |
| learning_rate: Initial learning rate |
| batch_size: Training batch size |
| gradient_accumulation_steps: Number of gradient accumulation steps |
| max_epochs: Maximum number of training epochs |
| save_every_n_epochs: Save checkpoint every N epochs |
| warmup_steps: Number of warmup steps for learning rate scheduler |
| weight_decay: Weight decay for optimizer |
| max_grad_norm: Maximum gradient norm for clipping |
| mixed_precision: Always "bf16" (only supported precision) |
| seed: Random seed for reproducibility |
| output_dir: Directory to save checkpoints and logs |
| """ |
| |
| shift: float = 3.0 |
| num_inference_steps: int = 8 |
| learning_rate: float = 1e-4 |
| batch_size: int = 1 |
| gradient_accumulation_steps: int = 4 |
| max_epochs: int = 100 |
| save_every_n_epochs: int = 10 |
| warmup_steps: int = 100 |
| weight_decay: float = 0.01 |
| max_grad_norm: float = 1.0 |
| mixed_precision: str = "bf16" |
| seed: int = 42 |
| output_dir: str = "./lora_output" |
| |
| |
| num_workers: int = 4 |
| pin_memory: bool = True |
| |
| |
| log_every_n_steps: int = 10 |
| |
| def to_dict(self): |
| """Convert to dictionary.""" |
| return { |
| "shift": self.shift, |
| "num_inference_steps": self.num_inference_steps, |
| "learning_rate": self.learning_rate, |
| "batch_size": self.batch_size, |
| "gradient_accumulation_steps": self.gradient_accumulation_steps, |
| "max_epochs": self.max_epochs, |
| "save_every_n_epochs": self.save_every_n_epochs, |
| "warmup_steps": self.warmup_steps, |
| "weight_decay": self.weight_decay, |
| "max_grad_norm": self.max_grad_norm, |
| "mixed_precision": self.mixed_precision, |
| "seed": self.seed, |
| "output_dir": self.output_dir, |
| "num_workers": self.num_workers, |
| "pin_memory": self.pin_memory, |
| "log_every_n_steps": self.log_every_n_steps, |
| } |
|
|