"""
Train a custom DFlash drafter for any MLX-converted model.

This example shows how to:
1. Create a generic DFlash drafter for your model
2. Generate training data using the target model
3. Train the drafter with the DFlash training recipe
4. Save and use the trained drafter

Usage:
    python train_custom_drafter.py \
        --model mlx-community/Llama-3.1-8B-Instruct-4bit \
        --output ./my-dflash-drafter \
        --dataset open-web-math \
        --samples 10000
"""

import argparse
import json
from pathlib import Path

from mlx_lm import load

from dflash_mlx.data import generate_training_data, create_mixed_training_data
from dflash_mlx.universal import UniversalDFlashDecoder


def _positive_int(text):
    """Parse ``text`` as an integer >= 1 (argparse ``type`` callback).

    Raises:
        argparse.ArgumentTypeError: if ``text`` is not an integer or is < 1.
    """
    try:
        number = int(text)
    except ValueError:
        raise argparse.ArgumentTypeError(f"invalid integer value: {text!r}") from None
    if number < 1:
        raise argparse.ArgumentTypeError(f"must be >= 1, got {number}")
    return number


def _positive_float(text):
    """Parse ``text`` as a finite float > 0 (argparse ``type`` callback).

    Raises:
        argparse.ArgumentTypeError: if ``text`` is not a number, is <= 0,
            or is NaN/infinite.
    """
    try:
        number = float(text)
    except ValueError:
        raise argparse.ArgumentTypeError(f"invalid float value: {text!r}") from None
    # The chained comparison is False for NaN as well as for <= 0 and +inf.
    if not 0 < number < float("inf"):
        raise argparse.ArgumentTypeError(f"must be a finite number > 0, got {text!r}")
    return number


def _build_parser():
    """Return the command-line parser for the drafter training script."""
    parser = argparse.ArgumentParser(description="Train custom DFlash drafter")
    parser.add_argument(
        "--model",
        type=str,
        required=True,
        help="MLX target model ID (e.g., mlx-community/Llama-3.1-8B-Instruct-4bit)",
    )
    parser.add_argument(
        "--output",
        type=str,
        required=True,
        help="Output directory for trained drafter",
    )
    parser.add_argument(
        "--dataset",
        type=str,
        default="open-web-math",
        help="Dataset name or path for training data",
    )
    parser.add_argument(
        "--samples",
        type=_positive_int,
        default=10000,
        help="Number of training samples to generate",
    )
    parser.add_argument(
        "--epochs",
        type=_positive_int,
        default=6,
        help="Training epochs",
    )
    parser.add_argument(
        "--batch-size",
        type=_positive_int,
        default=8,
        help="Training batch size",
    )
    parser.add_argument(
        "--lr",
        type=_positive_float,
        default=6e-4,
        help="Learning rate",
    )
    parser.add_argument(
        "--draft-layers",
        type=_positive_int,
        default=5,
        help="Number of draft model layers",
    )
    parser.add_argument(
        "--draft-hidden-size",
        type=_positive_int,
        default=1024,
        help="Draft model hidden size",
    )
    parser.add_argument(
        "--block-size",
        type=_positive_int,
        default=16,
        help="DFlash block size",
    )
    parser.add_argument(
        "--generate-data",
        action="store_true",
        help="Generate training data with target model first",
    )
    return parser


def _prepare_training_data(args, model, tokenizer, data_path):
    """Create the training-data file at ``data_path`` unless it can be reused.

    Data is (re)generated when ``--generate-data`` was passed or the file does
    not exist yet. The special dataset name ``"mixed"`` is routed to
    ``create_mixed_training_data``; any other name goes to
    ``generate_training_data`` together with the target model.
    """
    if not args.generate_data and data_path.exists():
        print(f"\n[3] Using existing training data: {data_path}")
        return

    print("\n[3] Generating training data...")
    if args.dataset == "mixed":
        create_mixed_training_data(
            output_path=str(data_path),
            total_samples=args.samples,
        )
    else:
        generate_training_data(
            target_model=model,
            tokenizer=tokenizer,
            prompts_dataset=args.dataset,
            output_path=str(data_path),
            num_samples=args.samples,
            temperature=0.0,
        )


def _write_metadata(args, output_path):
    """Write ``metadata.json`` describing this run into ``output_path``."""
    # NOTE(review): "training_samples" records the --samples flag, which may
    # differ from the actual size of a reused data file.
    metadata = {
        "target_model": args.model,
        "draft_layers": args.draft_layers,
        "draft_hidden_size": args.draft_hidden_size,
        "block_size": args.block_size,
        "training_epochs": args.epochs,
        "training_samples": args.samples,
        "learning_rate": args.lr,
    }
    # Explicit encoding keeps the file independent of the platform default.
    with open(output_path / "metadata.json", "w", encoding="utf-8") as f:
        json.dump(metadata, f, indent=2)


def _print_usage_hint(model_id, drafter_path):
    """Print a copy-pasteable snippet showing how to load the saved drafter."""
    banner = "=" * 60
    print(f"\n{banner}")
    print("Training complete!")
    print(banner)
    print("\nTo use your trained drafter:")
    print(" from dflash_mlx.universal import UniversalDFlashDecoder")
    print(" from mlx_lm import load")
    # !r emits a valid Python string literal even when the value contains
    # quotes or backslashes; for ordinary values the output is unchanged.
    print(f" model, tokenizer = load({model_id!r})")
    print(" decoder = UniversalDFlashDecoder(")
    print(" target_model=model,")
    print(" tokenizer=tokenizer,")
    print(f" draft_model_path={str(drafter_path)!r},")
    print(" )")
    print(" output = decoder.generate('Your prompt here')")


def main():
    """Run the drafter training workflow from the command line.

    Steps: parse and validate arguments, load the target model, build a
    ``UniversalDFlashDecoder``, generate or reuse training data, train and
    save the drafter, then write ``metadata.json`` and print a usage hint.

    Invalid numeric arguments (non-positive sizes/counts, non-positive or
    non-finite learning rate) are rejected by argparse with exit status 2
    before the target model is loaded.
    """
    args = _build_parser().parse_args()

    output_path = Path(args.output)
    output_path.mkdir(parents=True, exist_ok=True)
    drafter_path = output_path / "drafter"

    print(f"\n[1] Loading target model: {args.model}")
    model, tokenizer = load(args.model)
    print(" ✓ Target model loaded")

    print("\n[2] Creating DFlash decoder with generic drafter")
    print(f" Draft layers: {args.draft_layers}, Hidden size: {args.draft_hidden_size}")
    decoder = UniversalDFlashDecoder(
        target_model=model,
        tokenizer=tokenizer,
        draft_layers=args.draft_layers,
        draft_hidden_size=args.draft_hidden_size,
        block_size=args.block_size,
    )
    print(" ✓ Decoder initialized")

    data_path = output_path / "training_data.jsonl"
    _prepare_training_data(args, model, tokenizer, data_path)

    print("\n[4] Training DFlash drafter...")
    print(f" Epochs: {args.epochs}, Batch size: {args.batch_size}, LR: {args.lr}")
    # The return value is not needed here: the drafter is persisted through
    # the decoder in the next step.
    decoder.train_drafter(
        dataset=str(data_path),
        epochs=args.epochs,
        batch_size=args.batch_size,
        lr=args.lr,
        output_path=str(drafter_path),
    )

    print("\n[5] Saving trained drafter...")
    decoder.save_drafter(str(drafter_path))
    _write_metadata(args, output_path)

    _print_usage_hint(args.model, drafter_path)


# Run the training workflow only when this file is executed as a script,
# so importing the module has no side effects.
if __name__ == "__main__":
    main()