"""Configuration helpers for centralized LoRA finetuning."""

from __future__ import annotations

from dataclasses import asdict, dataclass, field, fields
from types import SimpleNamespace
from typing import Any, Dict, List, Optional


@dataclass
class LoRATuneConfig:
    """Structured config matching the current loratune.py CLI surface."""

    # Model and runtime settings.
    base_model: str = ""
    output_dir: str = ""
    device: str = "cuda"
    dtype: str = "bfloat16"
    trust_remote_code: bool = False
    seed: int = 42

    # Instruction-tuning data and optimization settings.
    instruction_dataset: str = "tatsu-lab/alpaca"
    instruction_config: Optional[str] = None
    instruction_split: str = "train"
    instruction_field_instruction: str = "instruction"
    instruction_field_input: str = "input"
    instruction_field_output: str = "output"
    max_samples: int = 0
    seq_len: int = 1024
    batch_size: int = 64
    micro_batch_size: int = 4
    epochs: float = 1.0
    learning_rate: float = 1e-4
    weight_decay: float = 0.0
    max_grad_norm: float = 1.0
    log_steps: int = 100

    # WikiText-2 perplexity evaluation, run when training metrics are logged.
    wikitext2_ppl_on_log: bool = True
    wikitext2_ppl_seq_len: int = 128
    wikitext2_ppl_batch_size: int = 8
    wikitext2_ppl_max_batches: Optional[int] = None  # None means no cap

    # LoRA adapter hyperparameters.
    lora_rank: int = 8
    lora_alpha: float = 16.0
    lora_dropout: float = 0.0
    lora_target_modules: List[str] = field(
        default_factory=lambda: [
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "gate_proj",
            "down_proj",
            "up_proj",
        ]
    )
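    # Note: the default target names above match Llama-family checkpoints.
    # Other architectures expose different module names (GPT-2, for example,
    # fuses attention into "c_attn"), so this list often needs overriding.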

    @property
    def grad_accum_steps(self) -> int:
        """Number of micro-batches accumulated per optimizer step."""
        if self.batch_size < 1:
            raise ValueError("batch_size must be >= 1")
        if self.micro_batch_size < 1:
            raise ValueError("micro_batch_size must be >= 1")
        if self.batch_size < self.micro_batch_size:
            raise ValueError("batch_size must be >= micro_batch_size")
        # Guard against silent truncation: floor division alone would let
        # e.g. batch_size=65, micro_batch_size=4 train with an effective
        # batch of 64 without any warning.
        if self.batch_size % self.micro_batch_size != 0:
            raise ValueError("batch_size must be a multiple of micro_batch_size")
        return self.batch_size // self.micro_batch_size
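
    # Worked example with the defaults above: batch_size=64 with
    # micro_batch_size=4 gives grad_accum_steps == 16, i.e. gradients from
    # 16 micro-batches are accumulated before each optimizer step.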

    def validate(self) -> "LoRATuneConfig":
        """Check required fields and derived values; returns self for chaining."""
        _ = self.grad_accum_steps  # re-raises if the batch settings are invalid
        if not self.base_model:
            raise ValueError("base_model must be set")
        if not self.output_dir:
            raise ValueError("output_dir must be set")
        return self

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a plain dict, including the derived grad_accum_steps."""
        data = asdict(self)
        data["grad_accum_steps"] = self.grad_accum_steps
        return data

    def to_namespace(self) -> SimpleNamespace:
        """Expose the config with argparse-style attribute access."""
        return SimpleNamespace(**self.to_dict())

    @classmethod
    def from_dict(cls, values: Dict[str, Any]) -> "LoRATuneConfig":
        """Build a config from a plain dict, ignoring non-field keys."""
        # to_dict() adds the derived "grad_accum_steps" key, which is a
        # property rather than an init field; drop unknown keys so that
        # from_dict(cfg.to_dict()) round-trips do not raise TypeError.
        known = {f.name for f in fields(cls)}
        return cls(**{k: v for k, v in values.items() if k in known})
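
if __name__ == "__main__":
    # Minimal usage sketch. The model id and output path are illustrative
    # placeholders, not values taken from loratune.py.
    cfg = LoRATuneConfig(
        base_model="meta-llama/Llama-2-7b-hf",  # hypothetical checkpoint id
        output_dir="runs/lora-demo",  # hypothetical output path
    ).validate()

    print(cfg.grad_accum_steps)  # 16: batch_size=64 over micro_batch_size=4

    # Round-trip through plain dicts; from_dict() ignores the derived
    # grad_accum_steps key that to_dict() adds.
    restored = LoRATuneConfig.from_dict(cfg.to_dict())
    assert restored == cfg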