# temp_ss/src/loratune_config.py
#!/usr/bin/env python3
"""Configuration helpers for centralized LoRA finetuning."""
from __future__ import annotations

from dataclasses import asdict, dataclass, field, fields
from types import SimpleNamespace
from typing import Any, Dict, List, Optional
@dataclass
class LoRATuneConfig:
    """Structured config matching the current loratune.py CLI surface.

    Groups the model/runtime, dataset, optimization, evaluation, and LoRA
    adapter settings used by the finetuning entry point.  Use
    :meth:`validate` after construction to enforce required fields, and
    :meth:`to_dict` / :meth:`from_dict` for (round-trippable) serialization.
    """

    # --- model / runtime ---
    base_model: str = ""              # required; checked in validate()
    output_dir: str = ""              # required; checked in validate()
    device: str = "cuda"
    dtype: str = "bfloat16"
    trust_remote_code: bool = False
    seed: int = 42

    # --- instruction dataset selection and field mapping ---
    instruction_dataset: str = "tatsu-lab/alpaca"
    instruction_config: Optional[str] = None
    instruction_split: str = "train"
    instruction_field_instruction: str = "instruction"
    instruction_field_input: str = "input"
    instruction_field_output: str = "output"
    max_samples: int = 0              # 0 presumably means "use all samples" — TODO confirm in loader

    # --- optimization ---
    seq_len: int = 1024
    batch_size: int = 64              # effective batch; must be a multiple of micro_batch_size
    micro_batch_size: int = 4
    epochs: float = 1.0
    learning_rate: float = 1e-4
    weight_decay: float = 0.0
    max_grad_norm: float = 1.0
    log_steps: int = 100

    # --- periodic perplexity evaluation on wikitext-2 ---
    wikitext2_ppl_on_log: bool = True
    wikitext2_ppl_seq_len: int = 128
    wikitext2_ppl_batch_size: int = 8
    wikitext2_ppl_max_batches: Optional[int] = None  # None presumably means "no cap" — TODO confirm in evaluator

    # --- LoRA adapter hyperparameters ---
    lora_rank: int = 8
    lora_alpha: float = 16.0
    lora_dropout: float = 0.0
    lora_target_modules: List[str] = field(
        default_factory=lambda: [
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "gate_proj",
            "down_proj",
            "up_proj",
        ]
    )

    @property
    def grad_accum_steps(self) -> int:
        """Gradient-accumulation steps implied by batch_size / micro_batch_size.

        Raises:
            ValueError: if either size is < 1, if batch_size < micro_batch_size,
                or if batch_size is not an exact multiple of micro_batch_size
                (floor division would otherwise silently shrink the effective
                batch size).
        """
        if self.batch_size < 1:
            raise ValueError("batch_size must be >= 1")
        if self.micro_batch_size < 1:
            raise ValueError("micro_batch_size must be >= 1")
        if self.batch_size < self.micro_batch_size:
            raise ValueError("batch_size must be >= micro_batch_size")
        # Fix: reject non-divisible combinations instead of silently
        # truncating (e.g. batch_size=10, micro_batch_size=4 used to yield
        # grad_accum_steps=2, an effective batch of 8, with no warning).
        if self.batch_size % self.micro_batch_size != 0:
            raise ValueError("batch_size must be a multiple of micro_batch_size")
        return self.batch_size // self.micro_batch_size

    def validate(self) -> "LoRATuneConfig":
        """Check required fields and derived values; return self for chaining.

        Raises:
            ValueError: if base_model / output_dir are unset or the batch
                configuration is inconsistent (see grad_accum_steps).
        """
        _ = self.grad_accum_steps  # triggers the batch-size consistency checks
        if not self.base_model:
            raise ValueError("base_model must be set")
        if not self.output_dir:
            raise ValueError("output_dir must be set")
        return self

    def to_dict(self) -> Dict[str, Any]:
        """Return all fields plus the derived ``grad_accum_steps``."""
        data = asdict(self)
        data["grad_accum_steps"] = self.grad_accum_steps
        return data

    def to_namespace(self) -> SimpleNamespace:
        """Return the config (incl. derived values) as a SimpleNamespace."""
        return SimpleNamespace(**self.to_dict())

    @classmethod
    def from_dict(cls, values: Dict[str, Any]) -> "LoRATuneConfig":
        """Build a config from a mapping, ignoring unknown keys.

        Fix: the previous ``cls(**values)`` crashed on the derived
        ``grad_accum_steps`` key that :meth:`to_dict` emits, breaking the
        natural ``from_dict(cfg.to_dict())`` round trip.  Unknown keys are
        now dropped so serialized configs load back cleanly.
        """
        known = {f.name for f in fields(cls)}
        return cls(**{k: v for k, v in values.items() if k in known})