| """ |
| Knowledge distillation training script. |
| |
| Trains a compressed Q-TensorFormer student using a dense teacher model. |
| Matches the student's parameter budget to ~50% of the teacher. |
| |
| Usage: |
| python scripts/distill.py --teacher_config small --student_rank 4 |
| """ |
|
|
import sys
import argparse
from pathlib import Path
|
|
sys.path.insert(0, str(Path(__file__).parent.parent))
|
|
import torch

from src.config import ExperimentConfig, PRESETS
from src.models import create_model
from src.baselines import StandardTransformer
from src.data import load_wikitext2, load_synthetic_data
from src.training import DistillationTrainer
from src.metrics import evaluate_model
|
|
|
|
def main():
    parser = argparse.ArgumentParser(description="Knowledge distillation for Q-TensorFormer")
    parser.add_argument("--teacher_config", type=str, default="small",
                        help="Teacher preset name (a key in PRESETS)")
    parser.add_argument("--student_rank", type=int, default=4,
                        help="TT rank of the compressed student (sets model.tt_rank)")
    parser.add_argument("--alpha", type=float, default=0.5,
                        help="Distillation loss weight")
    parser.add_argument("--temperature", type=float, default=3.0,
                        help="Softmax temperature for distillation")
    parser.add_argument("--epochs", type=int, default=8)
    parser.add_argument("--batch_size", type=int, default=16)
    parser.add_argument("--device", type=str, default="cpu")
    parser.add_argument("--output", type=str, default="./outputs/distill/")
    parser.add_argument("--synthetic", action="store_true",
                        help="Use synthetic data instead of WikiText-2")
    args = parser.parse_args()
|
|
    torch.manual_seed(42)
|
|
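    # Teacher configuration from the named preset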
    teacher_config = PRESETS[args.teacher_config]()
    print(f"Teacher config: {teacher_config.experiment_name}")
|
|
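    # Data: a synthetic sanity-check loader, or WikiText-2 with its tokenizer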
    if args.synthetic:
        train_loader = load_synthetic_data(batch_size=args.batch_size)
        val_loader = None
        test_loader = train_loader
    else:
        train_loader, val_loader, test_loader, tokenizer = load_wikitext2(
            batch_size=args.batch_size
        )
        teacher_config.model.vocab_size = tokenizer.vocab_size
|
|
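    # Dense teacher model (uncompressed baseline)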
    teacher = StandardTransformer(
        vocab_size=teacher_config.model.vocab_size,
        d_model=teacher_config.model.d_model,
        n_heads=teacher_config.model.n_heads,
        n_layers=teacher_config.model.n_layers,
    )
    print(f"Teacher params: {teacher.total_params:,}")
|
|
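    # Student: clone the teacher's model/training configs, then lower the
    # TT rank and enable the quantum variant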
    student_config = ExperimentConfig(
        model=type(teacher_config.model)(**vars(teacher_config.model)),
        training=type(teacher_config.training)(**vars(teacher_config.training)),
    )
    student_config.model.tt_rank = args.student_rank
    student_config.model.use_quantum = True
    student_config.training.max_epochs = args.epochs
|
|
    student = create_model(student_config, "qtensor")
    print(f"Student params: {student.total_params:,}")
    print(f"Compression: {teacher.total_params / student.total_params:.1f}x")
|
|
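    # Distillation training. The exact loss lives in DistillationTrainer; a
    # common formulation (assumed here, not verified against that class) is
    #   L = alpha * T^2 * KL(softmax(z_s / T) || softmax(z_t / T))
    #       + (1 - alpha) * CE(z_s, y)
    # where z_s / z_t are student / teacher logits and T is --temperature.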
    trainer = DistillationTrainer(
        student=student,
        teacher=teacher,
        config=student_config,
        train_loader=train_loader,
        val_loader=val_loader,
        test_loader=test_loader,
        device=args.device,
        output_dir=args.output,
        alpha=args.alpha,
        temperature=args.temperature,
    )
    trainer.train()
|
|
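    # Final evaluation of the distilled student on the test split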
    print("\nEvaluating knowledge-distilled model...")
    results = evaluate_model(student, test_loader, args.device)
    print(f"Student PPL: {results['test_ppl']:.2f}")
    print(f"Student params: {results['total_params']:,}")
    print(f"Compression vs teacher: {teacher.total_params / results['total_params']:.1f}x")
|
|
|
|
if __name__ == "__main__":
    main()
|
|