| """ |
| GIS-Coder 7B: Production QLoRA SFT Training Script |
| ==================================================== |
| Fine-tunes Qwen2.5-Coder-7B-Instruct for GIS code generation. |
| |
| Hardware requirements: |
| - Minimum: 1x A10G (24GB) or 1x RTX 4090 (24GB) |
| - Recommended: 1x A100 (80GB) for faster training + larger batch |
| - Also works on: H100, L40S, RTX 3090 |
| |
| Training recipe based on: |
| - CFD fine-tuning (arxiv:2504.09602): QLoRA, r=16, 88.7% accuracy on domain tasks |
| - MapCoder-Lite (arxiv:2509.17489): Qwen2.5-Coder-7B as best backbone for code LoRA |
| - LoRA Without Regret: target all-linear layers, lr=2e-4 for LoRA |
| |
| Usage: |
| # Single GPU |
| python train_7b.py |
| |
| # Multi-GPU with accelerate |
| accelerate launch --num_processes 2 train_7b.py |
| |
| # With custom settings |
| python train_7b.py --epochs 5 --lr 1e-4 --lora_r 32 --max_length 4096 |
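
Expected dataset format (an assumption from the defaults above; adjust to your
data): one JSON object per line in data/train.jsonl, chat-style:

    {"messages": [{"role": "user", "content": "..."},
                  {"role": "assistant", "content": "..."}]}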
| """ |
|
|
| import os |
| import argparse |
| import torch |
| from datasets import load_dataset |
| from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig |
| from peft import LoraConfig, prepare_model_for_kbit_training |
| from trl import SFTConfig, SFTTrainer |
|
|
def parse_args():
    parser = argparse.ArgumentParser(description="Train GIS-Coder 7B")
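
    # Model, data, and Hub destination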
| parser.add_argument("--model_id", type=str, default="Qwen/Qwen2.5-Coder-7B-Instruct") |
| parser.add_argument("--dataset_id", type=str, default="RhodWeo/gis-code-instructions") |
| parser.add_argument("--hub_model_id", type=str, default="RhodWeo/GIS-Coder-7B") |
| parser.add_argument("--output_dir", type=str, default="./gis-coder-7b-output") |
| |
| |
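    # Optimization hyperparameters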
| parser.add_argument("--epochs", type=int, default=3) |
| parser.add_argument("--lr", type=float, default=2e-4, help="Learning rate (2e-4 for LoRA)") |
| parser.add_argument("--batch_size", type=int, default=2, help="Per-device batch size") |
| parser.add_argument("--grad_accum", type=int, default=8, help="Gradient accumulation steps") |
| parser.add_argument("--max_length", type=int, default=4096, help="Max sequence length") |
| parser.add_argument("--warmup_ratio", type=float, default=0.1) |
| parser.add_argument("--weight_decay", type=float, default=0.01) |
| parser.add_argument("--scheduler", type=str, default="cosine") |
| |
| |
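    # LoRA adapter settings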
| parser.add_argument("--lora_r", type=int, default=32, help="LoRA rank") |
| parser.add_argument("--lora_alpha", type=int, default=16, help="LoRA alpha") |
| parser.add_argument("--lora_dropout", type=float, default=0.05) |
| parser.add_argument("--target_modules", type=str, default="all-linear", |
| help="Target modules (all-linear or comma-separated list)") |
| |
| |
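    # Precision and attention backend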
| parser.add_argument("--no_quantize", action="store_true", help="Disable 4-bit quantization (full fp16)") |
| parser.add_argument("--use_flash_attn", action="store_true", help="Use Flash Attention 2") |
| |
| |
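    # Experiment tracking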
| parser.add_argument("--use_trackio", action="store_true", help="Enable Trackio monitoring") |
| parser.add_argument("--trackio_project", type=str, default="gis-coder-7b") |
| |
| return parser.parse_args() |
|
|
|
|
def main():
    args = parse_args()

    if args.use_trackio:
        import trackio
        trackio.init(
            project=args.trackio_project,
            config=vars(args),
        )

| print(f"Loading dataset: {args.dataset_id}") |
| dataset = load_dataset(args.dataset_id, data_files="data/train.jsonl", split="train") |
| print(f" {len(dataset)} examples, columns: {dataset.column_names}") |

    print(f"Loading model: {args.model_id}")

    model_kwargs = {
        "trust_remote_code": True,
        "attn_implementation": "flash_attention_2" if args.use_flash_attn else "eager",
    }

    # bf16 weights/compute in both paths; add 4-bit NF4 quantization (QLoRA) unless disabled
    model_kwargs["dtype"] = torch.bfloat16
    if not args.no_quantize:
        model_kwargs["quantization_config"] = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",  # NormalFloat4, the QLoRA default
            bnb_4bit_use_double_quant=True,  # also quantize the quantization constants
            bnb_4bit_compute_dtype=torch.bfloat16,
        )

    model = AutoModelForCausalLM.from_pretrained(
        args.model_id,
        device_map="auto",
        **model_kwargs,
    )

    tokenizer = AutoTokenizer.from_pretrained(args.model_id, trust_remote_code=True)
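    # Fall back to EOS as the pad token if the tokenizer does not define one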
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        model.config.pad_token_id = tokenizer.eos_token_id

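    # prepare_model_for_kbit_training casts norm layers to fp32 and enables
    # input gradients, which stabilizes training on a quantized base model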
    if not args.no_quantize:
        model = prepare_model_for_kbit_training(model)

    print(f" Parameters: {model.num_parameters()/1e9:.2f}B")

    # "all-linear" is passed through as-is; anything else is a comma-separated module list
    target = args.target_modules
    if target != "all-linear":
        target = target.split(",")

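    # alpha/r = 0.5 scaling with the defaults; "all-linear" targets every linear
    # layer (attention and MLP), following the LoRA Without Regret recipe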
    peft_config = LoraConfig(
        r=args.lora_r,
        lora_alpha=args.lora_alpha,
        target_modules=target,
        lora_dropout=args.lora_dropout,
        bias="none",
        task_type="CAUSAL_LM",
    )
    print(f" LoRA: r={args.lora_r}, alpha={args.lora_alpha}, targets={target}")

    training_args = SFTConfig(
        output_dir=args.output_dir,
        num_train_epochs=args.epochs,
        per_device_train_batch_size=args.batch_size,
        gradient_accumulation_steps=args.grad_accum,
        learning_rate=args.lr,
        lr_scheduler_type=args.scheduler,
        warmup_ratio=args.warmup_ratio,
        weight_decay=args.weight_decay,

        # memory: checkpointing trades recompute for activation memory
        gradient_checkpointing=True,
        gradient_checkpointing_kwargs={"use_reentrant": False},  # plays nicely with frozen PEFT base weights
        bf16=True,
        max_length=args.max_length,

        # logging
        logging_steps=1,
        logging_first_step=True,
        logging_strategy="steps",
        disable_tqdm=True,
        report_to="trackio" if args.use_trackio else "none",

        # checkpointing
        save_strategy="epoch",
        save_total_limit=3,

        # Hub
        push_to_hub=True,
        hub_model_id=args.hub_model_id,
        hub_strategy="every_save",

        dataloader_num_workers=4,
        seed=42,
    )
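
    # Passing peft_config lets SFTTrainer wrap the base model with the LoRA adapters itself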
    trainer = SFTTrainer(
        model=model,
        processing_class=tokenizer,
        args=training_args,
        train_dataset=dataset,
        peft_config=peft_config,
    )

    # Count on trainer.model: the PEFT-wrapped model that actually trains
    trainable = sum(p.numel() for p in trainer.model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in trainer.model.parameters())
    print(f" Trainable: {trainable:,} ({trainable/total*100:.2f}%)")

    # Effective batch = per-device batch * grad accumulation * number of processes
    eff_bs = args.batch_size * args.grad_accum * training_args.world_size
    print(f"\n{'='*60}")
    print(f"TRAINING: {args.model_id}")
    print(f" Dataset: {len(dataset)} examples")
    print(f" Method: {'QLoRA' if not args.no_quantize else 'LoRA'} (r={args.lora_r})")
    print(f" LR: {args.lr}, Epochs: {args.epochs}, Eff. batch: {eff_bs}")
    print(f" Max length: {args.max_length}")
    print(f" Push to: {args.hub_model_id}")
    print(f"{'='*60}\n")

    result = trainer.train()

| print("\nSaving final model...") |
| trainer.save_model(os.path.join(args.output_dir, "final")) |
| trainer.push_to_hub(commit_message="GIS-Coder 7B β final after training") |
| |
| m = result.metrics |
| print(f"\nDone! Loss: {m.get('train_loss','?')}, Time: {m.get('train_runtime',0):.0f}s") |
| print(f"Model: https://huggingface.co/{args.hub_model_id}") |

    if args.use_trackio:
        import trackio
        trackio.log({"final_loss": m.get("train_loss", 0), "runtime": m.get("train_runtime", 0)})
        trackio.finish()


if __name__ == "__main__":
    main()