| |
| |
| """DFlash LoRA Training Script with Direct Layer Injection. |
| |
| Trains Qwen3-8B with LoRA adapters using direct layer-by-layer injection from target model. |
| |
| Key features: |
| - Target model extracts hidden states at each layer |
| - Draft model (same structure + LoRA) receives these hidden states layer-by-layer |
| - No feature extraction layers (fc + hidden_norm) - direct injection |
| - Only LoRA parameters are trained; base model is frozen |
| - Saves LoRA adapter weights only |
| """ |
|
|
| import argparse |
| import json |
| import logging |
| import math |
| import os |
| import time |
| import warnings |
| from typing import Optional, Tuple |
|
|
| import torch |
| import torch.distributed as dist |
| import torch.multiprocessing as mp |
| from accelerate.utils import set_seed |
|
|
| from torch.nn.parallel import DistributedDataParallel as DDP |
| from torch.utils.data import DataLoader |
| from tqdm import tqdm |
| from transformers import AutoTokenizer |
|
|
| from datasets import load_dataset |
| from specforge.args import TrackerArgs |
| from specforge.core.dflash_lora_inject import OnlineDFlashLoRAInjectModel |
| from specforge.data import build_eagle3_dataset, prepare_dp_dataloaders |
| from specforge.distributed import destroy_distributed, get_dp_group, init_distributed |
| from specforge.modeling.draft.dflash_lora_inject import DFlashLoRAInjectDraftModel |
| from specforge.modeling.target.dflash_target_model import get_dflash_target_model |
| from specforge.optimizer import BF16Optimizer |
| from specforge.tracker import create_tracker |
| from specforge.utils import get_last_checkpoint, print_on_rank0, print_with_rank |
|
|
|
|
def parse_args():
    """Build the command-line interface for the training script and parse ``sys.argv``.

    Options are declared as data (one table per argument group) and registered
    in a single loop, so the group order and per-group option order shown in
    ``--help`` follow the tables below. The tracker group is populated by
    ``TrackerArgs.add_args``.

    Returns:
        argparse.Namespace: the parsed arguments.
    """
    flag = {"action": "store_true"}

    grouped_options = [
        ("model", [
            ("--target-model-path", dict(
                type=str, required=True,
                help="Path to target model (for extracting hidden states)")),
            ("--target-model-backend", dict(
                type=str, default="hf", choices=["hf", "sglang"],
                help="Backend for target model")),
            ("--draft-model-path", dict(
                type=str, default=None,
                help="Path to draft model base (default: same as target)")),
            ("--block-size", dict(type=int, default=16)),
            ("--mask-token-id", dict(
                type=int, default=None,
                help="MASK token ID. Auto-detected from tokenizer if not set.")),
            ("--context-len", dict(
                type=int, default=0,
                help="Fixed context length before blocks. 0 = treat whole seq as blocks.")),
            ("--trust-remote-code", dict(flag)),
            ("--attn-implementation", dict(
                type=str, default="sdpa", choices=["sdpa", "eager"],
                help="Attention backend for additive mask path.")),
            ("--attention-backend", dict(
                type=str, default="flex_attention",
                choices=["flex_attention", "additive"],
                help="flex_attention: use BlockMask. additive: use 4D mask.")),
        ]),
        ("lora", [
            ("--lora-rank", dict(type=int, default=16)),
            ("--lora-alpha", dict(type=int, default=32)),
            ("--lora-dropout", dict(type=float, default=0.05)),
            ("--lora-target-modules", dict(
                type=str, nargs="+",
                default=["q_proj", "k_proj", "v_proj", "o_proj"],
                help="Which modules to apply LoRA to")),
            ("--lora-config", dict(
                type=str, default=None,
                help="Path to JSON file with LoRA config")),
        ]),
        ("dataset", [
            ("--train-data-path", dict(type=str, required=True)),
            ("--eval-data-path", dict(type=str, default=None)),
            ("--chat-template", dict(type=str, default="qwen")),
            ("--is-preformatted", dict(flag)),
            ("--dataloader-num-workers", dict(type=int, default=8)),
            # The default is read from the environment each time parse_args() runs.
            ("--build-dataset-num-proc", dict(
                type=int,
                default=int(os.environ.get("SPECFORGE_DATA_NUM_PROC", 8)))),
        ]),
        ("training", [
            ("--num-epochs", dict(type=int, default=3)),
            ("--batch-size", dict(type=int, default=1)),
            ("--learning-rate", dict(type=float, default=2e-4)),
            ("--max-length", dict(type=int, default=2048)),
            ("--warmup-ratio", dict(type=float, default=0.04)),
            ("--max-grad-norm", dict(type=float, default=1.0)),
            ("--accumulation-steps", dict(type=int, default=1)),
            ("--loss-decay-gamma", dict(type=float, default=None)),
            ("--optimizer-type", dict(
                type=str, default="adamw", choices=["adamw", "adamw_8bit"])),
            ("--no-fp32-params", dict(flag)),
            ("--gradient-checkpointing", dict(flag)),
            ("--seed", dict(type=int, default=42)),
            ("--resume", dict(flag)),
            ("--ckpt-dir", dict(type=str, default=None)),
        ]),
        ("output", [
            ("--output-dir", dict(type=str, required=True)),
            ("--cache-dir", dict(type=str, default="./cache")),
            ("--log-interval", dict(type=int, default=50)),
            ("--eval-interval", dict(type=int, default=1000)),
            ("--save-interval", dict(type=int, default=1000)),
        ]),
        ("early stopping", [
            ("--early-stop", dict(
                flag, help="Enable early stopping based on training accuracy")),
            ("--early-stop-patience", dict(
                type=int, default=5,
                help="Stop after N consecutive log intervals without improvement (default: 5)")),
            ("--early-stop-min-delta", dict(
                type=float, default=0.005,
                help="Minimum accuracy improvement to reset patience (default: 0.005)")),
            ("--early-stop-acc-threshold", dict(
                type=float, default=None,
                help="Hard accuracy threshold — stop immediately when reached (default: None)")),
            ("--early-stop-warmup-steps", dict(
                type=int, default=0,
                help="Number of log intervals to skip before enabling early stopping (default: 0)")),
            ("--early-stop-relative-delta", dict(
                flag,
                help="Treat min_delta as a fraction of best_acc instead of absolute value")),
        ]),
        ("distributed", [
            ("--dist-timeout", dict(type=int, default=30)),
        ]),
    ]

    parser = argparse.ArgumentParser(
        description="Train DFlash LoRA with Direct Layer Injection"
    )
    for title, options in grouped_options:
        section = parser.add_argument_group(title)
        for option_name, spec in options:
            section.add_argument(option_name, **spec)

    # Tracker options are owned by TrackerArgs; it registers them on this group.
    TrackerArgs.add_args(parser.add_argument_group("tracker"))

    return parser.parse_args()
|
|
|
|
def build_model(args) -> Tuple[DFlashLoRAInjectDraftModel, OnlineDFlashLoRAInjectModel, object]:
    """Load the target model, the LoRA draft model, and the online training wrapper.

    Args:
        args: Parsed command-line namespace. Reads the model paths/backend,
            ``cache_dir``, ``trust_remote_code``, the ``lora_*`` options
            (optionally overridden by the JSON file at ``args.lora_config``),
            ``block_size``, ``mask_token_id``, ``loss_decay_gamma``,
            ``attention_backend`` and ``attn_implementation``.

    Returns:
        A ``(draft_model, online_model, target_model)`` tuple.
    """
    print_on_rank0(f"Loading target model from {args.target_model_path}")

    target_model = get_dflash_target_model(
        pretrained_model_name_or_path=args.target_model_path,
        backend=args.target_model_backend,
        torch_dtype=torch.bfloat16,
        device="cuda",
        cache_dir=args.cache_dir,
        trust_remote_code=args.trust_remote_code,
    )

    # Not every target wrapper exposes capture-layer configuration.
    # NOTE(review): passing None presumably selects every layer — confirm
    # against the target model implementation.
    if hasattr(target_model, "set_capture_layers"):
        target_model.set_capture_layers(None)

    draft_path = args.draft_model_path or args.target_model_path
    print_on_rank0(f"Loading draft model from {draft_path}")

    # Command-line values are the baseline; a JSON config file, when given,
    # overrides only the keys it actually contains.
    lora_rank = args.lora_rank
    lora_alpha = args.lora_alpha
    lora_dropout = args.lora_dropout
    lora_target_modules = args.lora_target_modules

    if args.lora_config is not None:
        with open(args.lora_config, encoding="utf-8") as f:
            lora_cfg = json.load(f)
        lora_rank = lora_cfg.get("lora_rank", lora_rank)
        lora_alpha = lora_cfg.get("lora_alpha", lora_alpha)
        lora_dropout = lora_cfg.get("lora_dropout", lora_dropout)
        lora_target_modules = lora_cfg.get("lora_target_modules", lora_target_modules)
        print_on_rank0(f"Loaded LoRA config from {args.lora_config}")

    # The flex_attention backend dictates the attention implementation; the
    # additive-mask backend uses whatever --attn-implementation selected.
    if args.attention_backend == "flex_attention":
        attn_impl = "flex_attention"
    else:
        attn_impl = args.attn_implementation

    # Fall back to the built-in id only when no id was supplied at all.
    # An explicit `is not None` test keeps a legitimate token id of 0 intact,
    # which a plain `or` would silently replace.
    default_mask_token_id = 151669
    if args.mask_token_id is not None:
        mask_token_id = args.mask_token_id
    else:
        mask_token_id = default_mask_token_id

    draft_model = DFlashLoRAInjectDraftModel.from_pretrained(
        pretrained_model_name_or_path=draft_path,
        lora_rank=lora_rank,
        lora_alpha=lora_alpha,
        lora_dropout=lora_dropout,
        lora_target_modules=lora_target_modules,
        block_size=args.block_size,
        mask_token_id=mask_token_id,
        torch_dtype=torch.bfloat16,
        device_map="cuda",
        trust_remote_code=args.trust_remote_code,
        attn_implementation=attn_impl,
    )

    online_model = OnlineDFlashLoRAInjectModel(
        draft_model=draft_model,
        target_model=target_model,
        block_size=args.block_size,
        mask_token_id=mask_token_id,
        loss_decay_gamma=args.loss_decay_gamma,
        attention_backend=args.attention_backend,
        lm_head_chunk_size=0,
        random_anchor=False,
        num_anchors=512,
    )

    trainable = sum(p.numel() for p in draft_model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in draft_model.parameters())
    # Guard the percentage so a parameter-less model cannot raise ZeroDivisionError.
    percent = 100 * trainable / total if total else 0.0
    print_on_rank0(f"Trainable params: {trainable:,} / {total:,} ({percent:.2f}%)")

    return draft_model, online_model, target_model
|
|
|
|
def build_dataloader(args, tokenizer) -> Tuple[DataLoader, Optional[DataLoader]]:
    """Create the training dataloader and, when requested, the eval dataloader.

    The processed training set is built on rank 0 first; the remaining ranks
    wait on a barrier and then issue the same build call with the same cache
    directory and cache key. Training samples whose loss mask sums to fewer
    than ``2 * args.block_size`` are dropped.

    Args:
        args: Parsed command-line namespace.
        tokenizer: Tokenizer handed to ``build_eagle3_dataset``.

    Returns:
        ``(train_dataloader, eval_dataloader)``; the second item is ``None``
        when ``args.eval_data_path`` is empty or unset.
    """
    import hashlib

    def _load_raw(path, **directory_kwargs):
        # A directory is loaded as a dataset (its "train" split); anything
        # else is treated as a JSON data file.
        if os.path.isdir(path):
            return load_dataset(path, split="train", **directory_kwargs)
        return load_dataset("json", data_files=path)["train"]

    def _as_loader(dataset, shuffle):
        return prepare_dp_dataloaders(
            dataset,
            args.batch_size,
            num_workers=args.dataloader_num_workers,
            shuffle=shuffle,
            process_group=get_dp_group(),
        )

    # The cache key is derived from the inputs listed here.
    key_fields = (
        args.train_data_path,
        args.max_length,
        args.chat_template,
        args.target_model_path,
    )
    cache_key = hashlib.md5("-".join(map(str, key_fields)).encode()).hexdigest()

    is_first_rank = dist.get_rank() == 0

    raw_train = _load_raw(args.train_data_path, verification_mode="no_checks")

    build_kwargs = {
        "dataset": raw_train,
        "tokenizer": tokenizer,
        "chat_template": args.chat_template,
        "max_length": args.max_length,
        "is_preformatted": args.is_preformatted,
        "cache_dir": os.path.join(args.cache_dir, "processed_dataset"),
        "cache_key": cache_key,
        "num_proc": args.build_dataset_num_proc,
    }

    # Rank 0 builds first; every other rank builds only after the barrier.
    if is_first_rank:
        train_eagle3_dataset = build_eagle3_dataset(**build_kwargs)
    dist.barrier()
    if not is_first_rank:
        train_eagle3_dataset = build_eagle3_dataset(**build_kwargs)

    min_loss_tokens = 2 * args.block_size

    def _has_enough_loss_tokens(sample):
        return sample["loss_mask"].sum() >= min_loss_tokens

    original_size = len(train_eagle3_dataset)
    train_eagle3_dataset = train_eagle3_dataset.filter(_has_enough_loss_tokens)
    print_on_rank0(f"Filtered train dataset: {original_size} -> {len(train_eagle3_dataset)} samples")

    train_dataloader = _as_loader(train_eagle3_dataset, shuffle=True)

    if not args.eval_data_path:
        return train_dataloader, None

    # The eval set is built without the cache arguments and is not filtered.
    eval_eagle3_dataset = build_eagle3_dataset(
        dataset=_load_raw(args.eval_data_path),
        tokenizer=tokenizer,
        chat_template=args.chat_template,
        max_length=args.max_length,
        is_preformatted=args.is_preformatted,
    )
    eval_dataloader = _as_loader(eval_eagle3_dataset, shuffle=False)

    return train_dataloader, eval_dataloader
|
|
|
|
def _write_lora_weights(lora_state_dict, save_dir):
    """Write LoRA tensors as safetensors, falling back to ``adapter_model.bin``."""
    safetensors_path = os.path.join(save_dir, "adapter_model.safetensors")

    try:
        from safetensors.torch import save_file as safetensors_save
    except ImportError:
        safetensors_save = None

    if safetensors_save is not None:
        try:
            safetensors_save(lora_state_dict, safetensors_path)
            return
        except Exception as exc:
            # Deliberately best-effort: a serialization problem must not kill
            # the run, but it must be visible rather than silently swallowed.
            print_on_rank0(
                f"safetensors save failed ({exc!r}); falling back to adapter_model.bin"
            )
            # Do not leave a partial file next to the fallback checkpoint.
            if os.path.exists(safetensors_path):
                os.remove(safetensors_path)

    torch.save(lora_state_dict, os.path.join(save_dir, "adapter_model.bin"))


def save_checkpoint(args, epoch, step, online_model, draft_model, optimizer):
    """Save LoRA adapter weights + training state.

    Rank 0 writes everything into ``<output_dir>/epoch_<epoch>_step_<step>``:
    the LoRA tensors, the PEFT config of the ``"default"`` adapter, and a
    ``training_state.pt`` holding the epoch, step, parsed args and the
    optimizer's ``state_dict()`` entries. All ranks synchronize on a barrier
    after the directory is created and again after writing completes.

    Args:
        args: Parsed command-line namespace (``output_dir`` is read; the whole
            namespace is stored in the training state).
        epoch: Epoch index used in the directory name and training state.
        step: Global step used in the directory name and training state.
        online_model: Training wrapper, optionally wrapped in DDP.
        draft_model: Draft model whose PEFT config is saved.
        optimizer: Object whose ``state_dict()`` is merged into the training state.
    """
    save_dir = os.path.join(args.output_dir, f"epoch_{epoch}_step_{step}")
    is_writer = dist.get_rank() == 0

    if is_writer:
        os.makedirs(save_dir, exist_ok=True)
    dist.barrier()

    if is_writer:
        # Unwrap DDP so state-dict keys carry no "module." prefix.
        module = online_model.module if isinstance(online_model, DDP) else online_model
        # NOTE(review): keys are taken verbatim from state_dict(); confirm this
        # layout is what the adapter loader used on resume expects.
        lora_state_dict = {
            k: v for k, v in module.draft_model.model.state_dict().items()
            if "lora_" in k
        }

        _write_lora_weights(lora_state_dict, save_dir)

        draft_model.model.peft_config["default"].save_pretrained(save_dir)

        torch.save(
            {
                "epoch": epoch,
                "global_step": step,
                "args": args,
                **optimizer.state_dict(),
            },
            os.path.join(save_dir, "training_state.pt"),
        )
        print_on_rank0(f"Saved LoRA checkpoint to {save_dir}")

    dist.barrier()
|
|
|
|
def record_metrics(args, loss, accuracy, global_step, tracker, optimizer, mode="train"):
    """Print a one-line loss/accuracy summary and forward the metrics to the tracker.

    Metrics are keyed as ``"<mode>/loss"`` and ``"<mode>/accuracy"``. In train
    mode with an optimizer available, ``"train/lr"`` is logged as well.
    ``args`` is accepted but not used by this function.
    """
    metrics = {}

    wants_lr = mode == "train" and optimizer is not None
    if wants_lr:
        metrics["train/lr"] = optimizer.get_learning_rate()

    metrics.update({
        f"{mode}/loss": loss,
        f"{mode}/accuracy": accuracy,
    })

    summary = f"{mode.capitalize()} - Step {global_step}, Loss: {loss:.4f}, Acc: {accuracy:.4f}"
    print_on_rank0(summary)

    tracker.log(metrics, step=global_step)
|
|
|
|
class EarlyStopping:
    """Track an accuracy series and decide when training should end.

    Each call to :meth:`should_stop` counts as one observation. The first
    ``warmup_steps`` observations only update the running best and never stop
    training. Afterwards training stops as soon as ``acc_threshold`` (if set)
    is reached, or once ``patience`` consecutive observations fail to beat the
    best accuracy by the required margin.
    """

    def __init__(self, patience: int, min_delta: float, acc_threshold: float = None,
                 warmup_steps: int = 0, relative_delta: bool = False):
        # Configuration.
        self.patience = patience
        self.min_delta = min_delta
        self.acc_threshold = acc_threshold
        self.warmup_steps = warmup_steps
        self.relative_delta = relative_delta
        # Running state.
        self.best_acc = -1.0
        self.counter = 0
        self.num_calls = 0

    def should_stop(self, acc: float) -> bool:
        """Record one accuracy observation and report whether to stop."""
        self.num_calls += 1

        # Warmup: remember the best value seen but never count against patience.
        still_warming_up = self.num_calls <= self.warmup_steps
        if still_warming_up:
            self.best_acc = max(self.best_acc, acc)
            return False

        # Hard threshold wins over the patience logic.
        if self.acc_threshold is not None and acc >= self.acc_threshold:
            return True

        # Required improvement: absolute, or a fraction of the current best
        # when relative mode is on and the best is positive.
        margin = self.min_delta
        if self.relative_delta and self.best_acc > 0:
            margin = self.min_delta * self.best_acc

        improved = acc > self.best_acc + margin
        if improved:
            self.best_acc, self.counter = acc, 0
        else:
            self.counter += 1

        return self.counter >= self.patience
|
|
|
|
def _load_lora_adapter(draft_model, adapter_dir):
    """Swap the draft model's PEFT wrapper for one loaded from ``adapter_dir``."""
    from peft import PeftModel

    # is_trainable=True is required to continue training: by default PEFT
    # loads adapters frozen for inference.
    draft_model.model = PeftModel.from_pretrained(
        draft_model.model.base_model.model,
        adapter_dir,
        is_trainable=True,
    )


def main():
    """Run DFlash LoRA training end to end.

    Steps: parse arguments, initialize distributed state, resolve the mask
    token id, build the models, optionally load adapter weights
    (``--ckpt-dir``) or resume from the latest checkpoint in the output
    directory (``--resume``), build the dataloaders, wrap the online model in
    DDP, then train with gradient accumulation, periodic logging, optional
    early stopping and periodic checkpointing. A final checkpoint is always
    written before shutdown.
    """
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    warnings.filterwarnings(
        "ignore",
        "The .grad attribute of a Tensor that is not a leaf Tensor is being accessed",
    )

    mp.set_sharing_strategy("file_system")

    args = parse_args()
    set_seed(args.seed)

    init_distributed(timeout=args.dist_timeout, tp_size=1)
    print_with_rank("Initialized distributed")

    # Honor --trust-remote-code for the tokenizer too, matching the model loads.
    tokenizer = AutoTokenizer.from_pretrained(
        args.target_model_path,
        trust_remote_code=args.trust_remote_code,
    )

    # Mask token precedence: explicit flag, then the tokenizer's own mask
    # token, then a newly registered "<|MASK|>" special token.
    if args.mask_token_id is not None:
        mask_token_id = args.mask_token_id
    elif tokenizer.mask_token_id is not None:
        mask_token_id = tokenizer.mask_token_id
    else:
        tokenizer.add_special_tokens({"mask_token": "<|MASK|>"})
        mask_token_id = tokenizer.mask_token_id
    print_on_rank0(f"Using mask_token_id: {mask_token_id}")
    args.mask_token_id = mask_token_id

    draft_model, online_model, target_model = build_model(args)

    # Make sure both models use the id resolved above.
    draft_model.mask_token_id = mask_token_id
    online_model.mask_token_id = mask_token_id

    if args.gradient_checkpointing:
        draft_model.gradient_checkpointing_enable(
            gradient_checkpointing_kwargs={"use_reentrant": False}
        )
        print_on_rank0("Gradient checkpointing enabled")

    resume_state = None
    if args.ckpt_dir is not None:
        if not os.path.isdir(args.ckpt_dir):
            raise ValueError(f"ckpt_dir {args.ckpt_dir} is not a valid directory")
        print_on_rank0(f"Loading LoRA weights from {args.ckpt_dir}")
        _load_lora_adapter(draft_model, args.ckpt_dir)

    if args.resume and os.path.isdir(args.output_dir):
        last_ckpt = get_last_checkpoint(args.output_dir, prefix=r"epoch_\d+_step")
        if last_ckpt:
            print_on_rank0(f"Resuming from {last_ckpt}")
            _load_lora_adapter(draft_model, last_ckpt)
            training_state_path = os.path.join(last_ckpt, "training_state.pt")
            if os.path.exists(training_state_path):
                # SECURITY NOTE: weights_only=False unpickles arbitrary objects
                # (the saved state contains the argparse namespace). Only
                # resume from checkpoints this script produced.
                resume_state = torch.load(
                    training_state_path, map_location="cpu", weights_only=False
                )
                print_on_rank0(
                    f"Will resume from epoch {resume_state['epoch']}, step {resume_state['global_step']}"
                )

    # NOTE(review): eval_dataloader is built but no evaluation loop consumes
    # it (nor --eval-interval) in this function.
    train_dataloader, eval_dataloader = build_dataloader(args, tokenizer)

    steps_per_epoch = math.ceil(len(train_dataloader) / args.accumulation_steps)
    total_steps = args.num_epochs * steps_per_epoch
    print_on_rank0(f"Total training steps: {total_steps}")

    local_rank = int(os.environ.get("LOCAL_RANK", dist.get_rank() % torch.cuda.device_count()))
    online_model = DDP(online_model, device_ids=[local_rank], find_unused_parameters=False)
    print_with_rank("Initialized DDP")

    optimizer = BF16Optimizer(
        draft_model,
        lr=args.learning_rate,
        max_grad_norm=args.max_grad_norm,
        warmup_ratio=args.warmup_ratio,
        total_steps=total_steps,
        use_fp32_params=not args.no_fp32_params,
        optimizer_type=args.optimizer_type,
    )

    start_epoch = 0
    global_step = 0
    if resume_state is not None:
        # NOTE(review): only the scheduler state is restored; the optimizer
        # state saved alongside it is not loaded — confirm this is intended.
        optimizer.scheduler.load_state_dict(resume_state["scheduler_state_dict"])
        start_epoch = resume_state["epoch"]
        global_step = resume_state["global_step"]
        del resume_state
        print_on_rank0(f"Restored scheduler, lr={optimizer.get_learning_rate():.6f}")

    # Number of batches of the resumed epoch that were already consumed.
    skip_steps = global_step - start_epoch * len(train_dataloader)

    tracker = create_tracker(args, args.output_dir)
    last_time = time.time()
    print_on_rank0(f"Starting training from epoch {start_epoch}, step {global_step}")

    early_stopper = None
    if args.early_stop:
        early_stopper = EarlyStopping(
            patience=args.early_stop_patience,
            min_delta=args.early_stop_min_delta,
            acc_threshold=args.early_stop_acc_threshold,
            warmup_steps=args.early_stop_warmup_steps,
            relative_delta=args.early_stop_relative_delta,
        )
        print_on_rank0(
            f"Early stopping enabled: patience={args.early_stop_patience}, "
            f"min_delta={args.early_stop_min_delta}, "
            f"relative_delta={args.early_stop_relative_delta}, "
            f"warmup_steps={args.early_stop_warmup_steps}, "
            f"acc_threshold={args.early_stop_acc_threshold}"
        )
    should_stop = False

    for epoch in range(start_epoch, args.num_epochs):
        train_dataloader.sampler.set_epoch(epoch)
        draft_model.train()

        if dist.get_rank() == 0:
            progress_bar = tqdm(train_dataloader, desc=f"Epoch {epoch}", leave=True)
        else:
            progress_bar = train_dataloader

        for step_in_epoch, data in enumerate(progress_bar):
            # Fast-forward over batches already consumed before the resume.
            # The restored global_step already includes them, so they must
            # not be counted again.
            if epoch == start_epoch and step_in_epoch < skip_steps:
                continue
            global_step += 1

            input_ids = data["input_ids"].cuda()
            attention_mask = data["attention_mask"].cuda()
            loss_mask = data["loss_mask"].cuda()

            loss, accuracy = online_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                loss_mask=loss_mask,
                context_len=args.context_len,
            )
            # Scale so the accumulated gradient is the mean over micro-batches.
            (loss / args.accumulation_steps).backward()

            if global_step % args.accumulation_steps == 0:
                optimizer.step()

            if global_step % args.log_interval == 0:
                loss_val = loss.item()
                acc_val = accuracy.item()
                loss_t = torch.tensor(loss_val, device="cuda")
                acc_t = torch.tensor(acc_val, device="cuda")
                # Sum across ranks, then divide by world size for the mean.
                dist.all_reduce(loss_t)
                dist.all_reduce(acc_t)
                avg_acc = acc_t.item() / dist.get_world_size()
                record_metrics(args, loss_t.item() / dist.get_world_size(),
                               avg_acc, global_step,
                               tracker, optimizer, mode="train")

                if early_stopper is not None:
                    # Rank 0 decides; the decision is broadcast so every rank
                    # leaves the loop at the same step.
                    stop_flag = torch.tensor(0, device="cuda")
                    if dist.get_rank() == 0:
                        if early_stopper.should_stop(avg_acc):
                            stop_flag.fill_(1)
                            print_on_rank0(
                                f"Early stopping triggered at step {global_step}, "
                                f"best_acc={early_stopper.best_acc:.4f}, "
                                f"patience={early_stopper.counter}/{early_stopper.patience}"
                            )
                    dist.broadcast(stop_flag, src=0)
                    if stop_flag.item() == 1:
                        save_checkpoint(args, epoch, global_step, online_model, draft_model, optimizer)
                        should_stop = True
                        break

            if dist.get_rank() == 0:
                elapsed = time.time() - last_time
                last_time = time.time()
                progress_bar.set_postfix({
                    "loss": f"{loss.item():.4f}",
                    "acc": f"{accuracy.item():.4f}",
                    "iter_time": f"{elapsed:.2f}s",
                })

            if global_step % args.save_interval == 0:
                save_checkpoint(args, epoch, global_step, online_model, draft_model, optimizer)

        if should_stop:
            break

    # Final checkpoint is labelled with num_epochs as its epoch.
    save_checkpoint(args, args.num_epochs, global_step, online_model, draft_model, optimizer)
    tracker.close()
    destroy_distributed()


if __name__ == "__main__":
    main()
|
|