# train_qlora.py
# QLoRA fine-tuning on chat JSONL built from attack plans.
# Works well with deepseek-ai/deepseek-coder-6.7b-instruct on Colab Pro GPUs.

from __future__ import annotations

import argparse
from pathlib import Path
from typing import Dict, List, Union

import torch
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    Trainer,
    TrainingArguments,
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training


def parse_args():
    ap = argparse.ArgumentParser()
    ap.add_argument("--base", type=str, required=True, help="Base model id or path")
    ap.add_argument("--data", type=str, required=True, help="JSONL with chat messages")
    ap.add_argument("--out", type=str, required=True, help="Output dir for adapter")
    ap.add_argument("--epochs", type=int, default=2)
    ap.add_argument("--bsz", type=int, default=8)
    ap.add_argument("--grad_accum", type=int, default=1)
    ap.add_argument("--cutoff_len", type=int, default=2048)
    ap.add_argument("--lr", type=float, default=2e-4)
    ap.add_argument("--lora_r", type=int, default=16)
    ap.add_argument("--lora_alpha", type=int, default=32)
    ap.add_argument("--lora_dropout", type=float, default=0.05)
    ap.add_argument("--debug", action="store_true")
    return ap.parse_args()


def device_supports_bf16() -> bool:
    if not torch.cuda.is_available():
        return False
    major, _ = torch.cuda.get_device_capability(0)
    return major >= 8  # Ampere or newer


def build_tokenizer(base_id: str):
    tok = AutoTokenizer.from_pretrained(base_id, use_fast=True)
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
    tok.padding_side = "right"
    return tok


def _to_ids(x: Union[torch.Tensor, List[int], Dict[str, List[int]]]) -> List[int]:
    # Normalize the possible return types of apply_chat_template to a flat list of ids.
    if isinstance(x, torch.Tensor):
        return x.detach().cpu().tolist()[0] if x.ndim == 2 else x.detach().cpu().tolist()
    if isinstance(x, dict) and "input_ids" in x:
        return x["input_ids"]
    if isinstance(x, (list, tuple)):
        return list(x)
    raise TypeError(f"Unsupported chat template return type: {type(x)}")


def chat_to_ids(tokenizer: AutoTokenizer, messages: List[Dict], max_len: int):
    # Prefer the tokenizer's native chat template. In recent Transformers,
    # apply_chat_template returns a tensor when return_tensors="pt" is passed,
    # and a plain list of token ids when tokenize=True without return_tensors.
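    # Illustrative shape of `messages` (the role/content keys are what this
    # function reads below; the contents here are placeholders, not real data):
    #   [{"role": "system", "content": "..."},
    #    {"role": "user", "content": "..."},
    #    {"role": "assistant", "content": "..."}]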
    if hasattr(tokenizer, "apply_chat_template") and tokenizer.chat_template:
        out = tokenizer.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=False,
            return_tensors="pt",
            max_length=max_len,
            truncation=True,
        )
        ids = _to_ids(out)
        attn = [1] * len(ids)
        return {"input_ids": ids, "attention_mask": attn}

    # Fallback when no chat template is available: simple role-prefixed text
    lines = []
    for m in messages:
        role = m.get("role", "user")
        content = m.get("content", "")
        lines.append(f"{role}:\n{content}\n")
    text = "\n".join(lines)
    enc = tokenizer(text, max_length=max_len, truncation=True)
    return {"input_ids": enc["input_ids"], "attention_mask": enc["attention_mask"]}


def collate_pad(tokenizer: AutoTokenizer):
    pad_id = tokenizer.pad_token_id

    def _fn(batch: List[Dict[str, List[int]]]):
        # Right-pad every example to the longest sequence in the batch;
        # padded label positions get -100 so the loss ignores them.
        max_len = max(len(x["input_ids"]) for x in batch)
        input_ids, attn, labels = [], [], []
        for x in batch:
            ids = x["input_ids"]
            am = x["attention_mask"]
            pad_n = max_len - len(ids)
            input_ids.append(ids + [pad_id] * pad_n)
            attn.append(am + [0] * pad_n)
            labels.append(ids + [-100] * pad_n)
        return {
            "input_ids": torch.tensor(input_ids, dtype=torch.long),
            "attention_mask": torch.tensor(attn, dtype=torch.long),
            "labels": torch.tensor(labels, dtype=torch.long),
        }

    return _fn


def guess_lora_targets(model: torch.nn.Module) -> List[str]:
    # Collect common attention/MLP projection names present in the model.
    prefs = [
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
        "wi", "wo", "w1", "w2", "w3", "out_proj",
    ]
    found = set()
    for name, _ in model.named_modules():
        for p in prefs:
            if p in name:
                found.add(p)
    # Best-effort fallback; the prefs above already cover Llama-style models
    # such as deepseek-coder.
    return sorted(found) if found else ["Linear"]


def main():
    args = parse_args()
    base_id = args.base
    data_path = Path(args.data)
    out_dir = Path(args.out)
    out_dir.mkdir(parents=True, exist_ok=True)

    tokenizer = build_tokenizer(base_id)
    ds = load_dataset("json", data_files=str(data_path), split="train")

    def map_row(ex):
        return chat_to_ids(tokenizer, ex["messages"], args.cutoff_len)

    # Remove original columns after mapping so only model fields remain
    ds = ds.map(map_row, remove_columns=ds.column_names)

    collate = collate_pad(tokenizer)

    use_bf16 = device_supports_bf16()
    torch_dtype = torch.bfloat16 if use_bf16 else torch.float16
    torch.backends.cuda.matmul.allow_tf32 = True

    quant = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=torch_dtype,  # compute in bf16/fp16 instead of the fp32 default
    )

    model = AutoModelForCausalLM.from_pretrained(
        base_id,
        device_map="auto",
        quantization_config=quant,
        torch_dtype=torch_dtype,
    )
    model = prepare_model_for_kbit_training(model)

    lconf = LoraConfig(
        r=args.lora_r,
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=guess_lora_targets(model),
    )
    model = get_peft_model(model, lconf)

    train_args = TrainingArguments(
        output_dir=str(out_dir),
        num_train_epochs=args.epochs,
        per_device_train_batch_size=args.bsz,
        gradient_accumulation_steps=args.grad_accum,
        learning_rate=args.lr,
        lr_scheduler_type="cosine",
        warmup_ratio=0.03,
        logging_steps=5,
        save_steps=100,
        bf16=use_bf16,
        fp16=not use_bf16,
        optim="paged_adamw_8bit",
        remove_unused_columns=False,
        dataloader_num_workers=2,
        report_to=[],
    )

    tr = Trainer(
        model=model,
        args=train_args,
        train_dataset=ds,
        data_collator=collate,
        tokenizer=tokenizer,
    )

    if args.debug:
        # Inspect one collated batch before training starts.
        batch = next(iter(tr.get_train_dataloader()))
        print("[debug] batch keys:", list(batch.keys()))
        for k, v in batch.items():
            if isinstance(v, torch.Tensor):
                print(f"[debug] {k}: shape={tuple(v.shape)} dtype={v.dtype}")

    tr.train()
    model.save_pretrained(str(out_dir))
    tokenizer.save_pretrained(str(out_dir))
    print("[ok] saved adapter to", out_dir.resolve())
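
# Loading the adapter later for inference (a minimal sketch, not executed here;
# BASE_ID and OUT_DIR stand in for the --base and --out values used above):
#   from transformers import AutoModelForCausalLM, AutoTokenizer
#   from peft import PeftModel
#   base = AutoModelForCausalLM.from_pretrained(BASE_ID, device_map="auto")
#   model = PeftModel.from_pretrained(base, OUT_DIR)
#   tok = AutoTokenizer.from_pretrained(OUT_DIR)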

if __name__ == "__main__":
    main()
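
# Example invocation (a sketch: the model id matches the header comment, while
# the data file and output directory are placeholders to adapt to your setup):
#   python train_qlora.py \
#       --base deepseek-ai/deepseek-coder-6.7b-instruct \
#       --data attack_plans_chat.jsonl \
#       --out qlora_adapter \
#       --epochs 2 --cutoff_len 2048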