| """ |
| Component 9: LoRA fine-tuning pipeline for custom prompt->code pairs. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import json |
| import math |
| import sys |
| import time |
| from pathlib import Path |
| from typing import Any, Dict, Tuple |
|
|
| import torch |
| import yaml |
| from torch.optim import AdamW |
| from torch.utils.data import DataLoader, random_split |
| from tqdm import tqdm |
|
|
| |
# Repository root (one directory above this file's folder). It is placed at the
# front of ``sys.path`` so the ``src.*`` imports below resolve when this file
# is executed directly as a script.
PROJECT_ROOT = Path(__file__).resolve().parents[1]
if sys.path.count(str(PROJECT_ROOT)) == 0:
    sys.path.insert(0, str(PROJECT_ROOT))
|
|
| from src.finetuning_system.custom_pair_dataset import CustomPairDataset |
| from src.finetuning_system.lora_adapter import LoRAConfig, apply_lora, load_lora_state_dict, lora_state_dict |
| from src.model_architecture.code_transformer import CodeTransformerLM, ModelConfig, get_model_presets |
| from src.training_pipeline.tokenized_dataset import CausalCollator |
| from src.tokenizer.code_tokenizer import CodeTokenizer |
|
|
|
|
def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the LoRA fine-tuning script.

    Returns:
        A namespace with one attribute, ``config``: the path of the YAML
        configuration file (defaults to
        ``configs/component9_lora_config.yaml``).
    """
    cli = argparse.ArgumentParser(description="Run Component 9 LoRA fine-tuning.")
    cli.add_argument("--config", default="configs/component9_lora_config.yaml")
    namespace = cli.parse_args()
    return namespace
|
|
|
|
def load_yaml(path: Path) -> Dict[str, Any]:
    """Read a YAML file and return its top-level mapping.

    Args:
        path: Location of the YAML file. The text is decoded as ``utf-8-sig``,
            so a leading byte-order mark is tolerated.

    Returns:
        The parsed document as a dictionary.

    Raises:
        FileNotFoundError: If ``path`` does not exist.
        ValueError: If the parsed document is not a dictionary.
    """
    if path.exists():
        parsed = yaml.safe_load(path.read_text(encoding="utf-8-sig"))
        if isinstance(parsed, dict):
            return parsed
        raise ValueError("Invalid YAML format.")
    raise FileNotFoundError(f"Config not found: {path}")
|
|
|
|
def build_model_config(path: Path) -> ModelConfig:
    """Build a ``ModelConfig`` from a YAML model-configuration file.

    The file may name a ``preset`` and/or provide a ``model`` mapping of
    explicit fields. When a preset is named, its instance attributes form the
    base and the ``model`` entries override them; otherwise the ``model``
    mapping is used on its own.

    Args:
        path: Location of the model-configuration YAML file.

    Returns:
        The resulting ``ModelConfig``.
    """
    raw = load_yaml(path)
    overrides = raw.get("model", {})
    preset_name = raw.get("preset")
    if not preset_name:
        return ModelConfig(**overrides)
    # Later entries win, so explicit ``model`` values override the preset's.
    fields = {**get_model_presets()[preset_name].__dict__, **overrides}
    return ModelConfig(**fields)
|
|
|
|
def get_vram_gb() -> float:
    """Return the CUDA memory reported by ``torch.cuda.memory_allocated()`` in GiB.

    Returns:
        The allocated byte count divided by ``1024**3``, or ``0.0`` when CUDA
        is not available.
    """
    if torch.cuda.is_available():
        allocated_bytes = torch.cuda.memory_allocated()
        return allocated_bytes / (1024**3)
    return 0.0
|
|
|
|
def save_lora_ckpt(path: Path, step: int, lora_state: dict, optim_state: dict, best_val: float, no_improve: int) -> None:
    """Persist a LoRA training checkpoint to ``path``.

    The payload is first written to a sibling ``<name>.tmp`` file and then
    moved over ``path``. An interrupted write therefore never leaves a
    truncated checkpoint at ``path``: the previous file (if any) stays intact
    until the new one is complete.

    Args:
        path: Destination file; missing parent directories are created.
        step: Optimizer step count stored under ``"step"``.
        lora_state: LoRA parameters stored under ``"lora_state"``.
        optim_state: Optimizer state dict stored under ``"optimizer_state"``.
        best_val: Best validation loss so far, stored under ``"best_val"``.
        no_improve: Count of consecutive non-improving evaluations, stored
            under ``"no_improve"``.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    payload = {
        "step": step,
        "lora_state": lora_state,
        "optimizer_state": optim_state,
        "best_val": best_val,
        "no_improve": no_improve,
    }
    # Same directory as the target so the final rename stays on one filesystem.
    tmp_path = path.with_name(path.name + ".tmp")
    try:
        torch.save(payload, tmp_path)
        tmp_path.replace(path)
    except BaseException:
        # Do not leave a partial temp file behind, even on Ctrl-C.
        tmp_path.unlink(missing_ok=True)
        raise
|
|
|
|
@torch.no_grad()
def eval_loss(model: CodeTransformerLM, loader: DataLoader, device: torch.device, use_fp16: bool) -> float:
    """Compute the mean per-batch loss of ``model`` over ``loader``.

    The model is switched to eval mode for the duration of the call and is
    then returned to whichever mode (train or eval) it was in on entry, even
    if a batch raises.

    Args:
        model: Called as ``model(input_ids=..., labels=...)``; the result must
            expose the batch loss under the ``"loss"`` key.
        loader: Iterable yielding ``(input_ids, labels)`` tensor pairs.
        device: Device the batches are moved to before the forward pass.
        use_fp16: Enables float16 autocast; only takes effect when ``device``
            is a CUDA device.

    Returns:
        The unweighted average of the per-batch losses, or ``1e9`` when
        ``loader`` yields no batches.
    """
    was_training = model.training
    autocast_on = use_fp16 and device.type == "cuda"
    vals = []
    model.eval()
    try:
        for input_ids, labels in loader:
            input_ids = input_ids.to(device)
            labels = labels.to(device)
            with torch.amp.autocast("cuda", enabled=autocast_on, dtype=torch.float16):
                out = model(input_ids=input_ids, labels=labels)
            vals.append(float(out["loss"].item()))
    finally:
        # Restore the caller's mode instead of unconditionally forcing train mode.
        model.train(was_training)
    if not vals:
        # Sentinel "very bad" loss so an empty validation set never counts as an improvement.
        return 1e9
    return sum(vals) / len(vals)
|
|
|
|
def main() -> None:
    """Run LoRA fine-tuning as described by the YAML file given via ``--config``.

    Steps, in order:

    1. Load the config (path resolved relative to ``PROJECT_ROOT``) and require
       a CUDA device.
    2. Build the base model, load its checkpoint and wrap the modules matching
       ``lora.target_keywords`` with LoRA adapters.
    3. Load the custom pair dataset and split off roughly 10% (at least one
       example) for validation, using a fixed seed.
    4. Train the parameters that still require gradients with AdamW, gradient
       accumulation and optional fp16 autocast. Every ``save_every`` optimizer
       steps a checkpoint is written; every ``eval_every`` steps validation
       loss is computed, ``best.pt`` is updated on improvement, and training
       stops early after ``early_stopping_patience_evals`` consecutive
       evaluations without improvement.
    5. Write ``latest.pt`` and ``adapter_meta.json`` to the output directory.

    Raises:
        SystemExit: With exit code 1 if any step fails. The underlying
            exception is printed and attached as ``__cause__``.
    """
    args = parse_args()
    try:
        cfg = load_yaml(PROJECT_ROOT / args.config)
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        if device.type != "cuda":
            raise RuntimeError("CUDA GPU is required for LoRA fine-tuning.")

        # --- Base model -----------------------------------------------------
        model_cfg = build_model_config(PROJECT_ROOT / cfg["model"]["model_config_path"])
        model = CodeTransformerLM(model_cfg).to(device)

        base_ckpt = torch.load(PROJECT_ROOT / cfg["model"]["base_checkpoint_path"], map_location=device)
        model.load_state_dict(base_ckpt["model_state"])

        # --- LoRA wrapping --------------------------------------------------
        lcfg = LoRAConfig(
            r=int(cfg["lora"].get("r", 8)),
            alpha=int(cfg["lora"].get("alpha", 16)),
            dropout=float(cfg["lora"].get("dropout", 0.05)),
            target_keywords=list(cfg["lora"].get("target_keywords", ["q_proj", "k_proj", "v_proj", "o_proj", "fc1", "fc2"])),
        )
        replaced = apply_lora(model, lcfg)
        if not replaced:
            raise RuntimeError("No modules were LoRA-wrapped. Check target_keywords.")

        # Move again so any parameters added by the LoRA wrapping land on the GPU.
        model = model.to(device)

        # --- Data -----------------------------------------------------------
        max_seq_len = int(cfg["finetune"].get("max_seq_len", 512))
        tokenizer = CodeTokenizer.load(str(PROJECT_ROOT / cfg["model"]["tokenizer_dir"]))
        ds = CustomPairDataset(
            path=str(PROJECT_ROOT / cfg["finetune"]["custom_data_path"]),
            tokenizer=tokenizer,
            max_seq_len=max_seq_len,
        )

        n_val = max(1, int(0.1 * len(ds)))
        n_train = len(ds) - n_val
        if n_train < 1:
            # An empty training split would make the training loop below spin
            # forever (no batches, so ``step`` never advances).
            raise ValueError(
                f"Custom dataset has {len(ds)} example(s); at least 2 are required for a train/validation split."
            )
        # Fixed seed keeps the train/validation split identical across runs and resumes.
        train_ds, val_ds = random_split(ds, [n_train, n_val], generator=torch.Generator().manual_seed(17))

        # NOTE(review): pad_token_id is hard-coded to 0 - confirm it matches the tokenizer's pad id.
        collator = CausalCollator(pad_token_id=0, max_seq_len=max_seq_len)
        train_loader = DataLoader(train_ds, batch_size=int(cfg["finetune"].get("micro_batch_size", 1)), shuffle=True, collate_fn=collator)
        val_loader = DataLoader(val_ds, batch_size=1, shuffle=False, collate_fn=collator)

        # --- Optimizer and mixed precision -----------------------------------
        # Only parameters that still require gradients are optimized.
        trainable = [p for p in model.parameters() if p.requires_grad]
        optimizer = AdamW(trainable, lr=float(cfg["finetune"].get("learning_rate", 3e-4)), weight_decay=float(cfg["finetune"].get("weight_decay", 0.0)))

        use_fp16 = bool(cfg["finetune"].get("use_fp16", True))
        scaler = torch.amp.GradScaler("cuda", enabled=use_fp16)

        out_dir = PROJECT_ROOT / cfg["finetune"]["output_dir"]
        out_dir.mkdir(parents=True, exist_ok=True)

        # --- Schedule settings ------------------------------------------------
        max_steps = int(cfg["finetune"].get("max_steps", 1200))
        save_every = int(cfg["finetune"].get("save_every", 100))
        eval_every = int(cfg["finetune"].get("eval_every", 100))
        grad_accum = int(cfg["finetune"].get("grad_accum_steps", 16))
        max_vram = float(cfg["finetune"].get("max_vram_gb", 7.0))
        patience = int(cfg["finetune"].get("early_stopping_patience_evals", 6))
        min_delta = float(cfg["finetune"].get("early_stopping_min_delta", 5e-4))

        # These values are used as modulo divisors below; reject zero/negative up front.
        for key, value in (("save_every", save_every), ("eval_every", eval_every), ("grad_accum_steps", grad_accum)):
            if value < 1:
                raise ValueError(f"finetune.{key} must be >= 1, got {value}.")

        step = 0
        best_val = 1e9
        no_improve = 0

        # --- Optional resume --------------------------------------------------
        resume_from = str(cfg.get("resume", {}).get("resume_from", "none"))
        if resume_from != "none":
            ckpt = out_dir / "latest.pt" if resume_from == "latest" else Path(resume_from)
            if ckpt.exists():
                payload = torch.load(ckpt, map_location=device)
                load_lora_state_dict(model, payload["lora_state"])
                optimizer.load_state_dict(payload["optimizer_state"])
                step = int(payload.get("step", 0))
                best_val = float(payload.get("best_val", 1e9))
                no_improve = int(payload.get("no_improve", 0))
                print(f"[resume] loaded {ckpt} at step {step}")
            else:
                print(f"[resume] checkpoint not found: {ckpt}; starting from step 0")

        # --- Training loop ----------------------------------------------------
        model.train()
        pbar = tqdm(total=max_steps, initial=step, desc="lora_finetune", dynamic_ncols=True)
        running = 0  # micro-batches processed so far; persists across epochs
        stop_early = False

        try:
            while step < max_steps and not stop_early:
                for input_ids, labels in train_loader:
                    if step >= max_steps:
                        break
                    input_ids = input_ids.to(device)
                    labels = labels.to(device)

                    with torch.amp.autocast("cuda", enabled=use_fp16, dtype=torch.float16):
                        out = model(input_ids=input_ids, labels=labels)
                        # Scale so the accumulated gradient averages over the micro-batches.
                        loss = out["loss"] / grad_accum

                    scaler.scale(loss).backward()
                    running += 1

                    if running % grad_accum != 0:
                        continue

                    # One optimizer step per ``grad_accum`` micro-batches.
                    scaler.step(optimizer)
                    scaler.update()
                    optimizer.zero_grad(set_to_none=True)
                    step += 1
                    pbar.update(1)
                    vram = get_vram_gb()
                    pbar.set_postfix({"loss": f"{float(loss.item())*grad_accum:.4f}", "vram": f"{vram:.2f}GB"})

                    if vram > max_vram:
                        raise RuntimeError(f"VRAM threshold exceeded: {vram:.2f}GB > {max_vram:.2f}GB")

                    if step % save_every == 0:
                        ck = out_dir / f"step_{step}.pt"
                        lora_state = lora_state_dict(model)
                        optim_state = optimizer.state_dict()
                        save_lora_ckpt(ck, step, lora_state, optim_state, best_val, no_improve)
                        save_lora_ckpt(out_dir / "latest.pt", step, lora_state, optim_state, best_val, no_improve)
                        print(f"\n[checkpoint] saved {ck}")

                    if step % eval_every == 0:
                        val = eval_loss(model, val_loader, device, use_fp16=use_fp16)
                        print(f"\n[eval] step={step} val_loss={val:.4f} best={best_val:.4f}")
                        if val < (best_val - min_delta):
                            best_val = val
                            no_improve = 0
                            save_lora_ckpt(out_dir / "best.pt", step, lora_state_dict(model), optimizer.state_dict(), best_val, no_improve)
                        else:
                            no_improve += 1
                            if no_improve >= patience:
                                print("\n[early_stop] no improvement, stopping.")
                                # Use a flag rather than overwriting ``step`` so the
                                # final checkpoint and metadata record the real step.
                                stop_early = True
                                break
        finally:
            pbar.close()

        save_lora_ckpt(out_dir / "latest.pt", step, lora_state_dict(model), optimizer.state_dict(), best_val, no_improve)

        # --- Adapter metadata -------------------------------------------------
        meta = {
            "step": step,
            "best_val": best_val,
            "lora_config": {
                "r": lcfg.r,
                "alpha": lcfg.alpha,
                "dropout": lcfg.dropout,
                "target_keywords": lcfg.target_keywords,
            },
            "base_checkpoint_path": cfg["model"]["base_checkpoint_path"],
        }
        # Plain UTF-8: a byte-order mark at the start of a JSON document is
        # rejected by ``json.loads`` when the file is read back as UTF-8.
        (out_dir / "adapter_meta.json").write_text(json.dumps(meta, indent=2), encoding="utf-8")

        print("Component 9 LoRA fine-tuning completed.")
        print(f"LoRA adapters saved in: {out_dir}")

    except Exception as exc:
        print("Component 9 LoRA fine-tuning failed.")
        print(f"What went wrong: {exc}")
        print("Fix suggestion: verify custom data file format and checkpoint paths.")
        raise SystemExit(1) from exc
|
|
|
|
# Run the fine-tuning job only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|
|
|
|
|
|