|
|
| """
|
| Main training entry point for TouchGrass models.
|
| Fine-tunes Qwen3.5 with LoRA and music modules.
|
| """
|
|
|
| import argparse
|
| import sys
|
| from pathlib import Path
|
|
|
| import torch
|
| from transformers import AutoModelForCausalLM, AutoTokenizer
|
| from peft import LoraConfig, get_peft_model, TaskType
|
|
|
| from configs.touchgrass_3b_config import TOUCHGRASS_3B_CONFIG
|
| from configs.touchgrass_7b_config import TOUCHGRASS_7B_CONFIG
|
| from configs.training_config import (
|
| TRAINING_CONFIG_3B_CUDA,
|
| TRAINING_CONFIG_7B_CUDA,
|
| TRAINING_CONFIG_MPS,
|
| )
|
| from data.dataset_loader import TouchGrassDataset
|
| from training.trainer import TouchGrassTrainer
|
| from tokenizer.music_token_extension import MusicTokenizerExtension
|
|
|
|
|
def parse_args():
    """Build the CLI for TouchGrass training and parse ``sys.argv``.

    Returns:
        argparse.Namespace with model/device selection, data and output
        directories, LoRA hyperparameters, and optional overrides.
    """
    parser = argparse.ArgumentParser(description="Train TouchGrass music assistant model")

    # Model / hardware selection.
    parser.add_argument("--model_size", type=str, choices=["3b", "7b"], default="3b",
                        help="Model size to train")
    parser.add_argument("--device", type=str, default="cuda", choices=["cuda", "mps", "cpu"],
                        help="Device to train on")
    parser.add_argument("--use_mps", action="store_true",
                        help="Use MPS backend (Apple Silicon)")

    # Paths.
    parser.add_argument("--data_dir", type=str, default="./data/processed",
                        help="Directory with processed data shards")
    parser.add_argument("--output_dir", type=str, default="./checkpoints",
                        help="Output directory for checkpoints")

    # Training-config overrides (None means "use the config's value").
    parser.add_argument("--max_steps", type=int, default=None,
                        help="Override max training steps")
    parser.add_argument("--micro_batch_size", type=int, default=None,
                        help="Override micro batch size")

    # LoRA hyperparameters.
    parser.add_argument("--lora_r", type=int, default=16, help="LoRA rank")
    parser.add_argument("--lora_alpha", type=int, default=32, help="LoRA alpha")

    # Checkpointing / data generation.
    parser.add_argument("--resume_from_checkpoint", type=str, default=None,
                        help="Resume training from checkpoint")
    parser.add_argument("--generate_data", action="store_true",
                        help="Generate synthetic training data before training")
    parser.add_argument("--num_train_samples", type=int, default=10000,
                        help="Number of training samples to generate")

    return parser.parse_args()
|
|
|
|
|
def load_tokenizer(config: dict, args):
    """Load the base tokenizer and extend it with music special tokens.

    Args:
        config: Model config dict; must contain "base_model" and may contain
            "special_tokens" (passed through to the extension; None if absent).
        args: Parsed CLI namespace (currently unused; kept so all loaders
            share the same call shape).

    Returns:
        Tuple of (tokenizer_ext, tokenizer): the MusicTokenizerExtension
        wrapper and the underlying tokenizer it manages.
    """
    base_model = config["base_model"]
    print(f"Loading base tokenizer: {base_model}")

    tokenizer_ext = MusicTokenizerExtension(
        base_tokenizer_name=base_model,
        special_tokens=config.get("special_tokens"),
    )
    tokenizer = tokenizer_ext.get_tokenizer()

    # BUG FIX: tokenizer.vocab_size excludes tokens added after loading, so it
    # under-reported the extended vocabulary; len(tokenizer) includes added
    # special tokens and matches the size the model embeddings must cover.
    print(f"Extended tokenizer vocab size: {len(tokenizer)}")

    return tokenizer_ext, tokenizer
|
|
|
|
|
def load_model(config: dict, args, tokenizer):
    """Load the base causal LM, resize embeddings, and attach LoRA adapters.

    Args:
        config: Model config dict; must contain "base_model".
        args: Parsed CLI namespace; reads .device, .lora_r, .lora_alpha.
        tokenizer: The extended tokenizer (its full length, including added
            music tokens, determines the embedding-matrix size).

    Returns:
        A PEFT-wrapped model ready for training.
    """
    base_model = config["base_model"]
    print(f"Loading base model: {base_model}")

    # bf16 on CUDA when the GPU supports it, fp16 otherwise; full precision
    # everywhere else (the original used float32 for both MPS and CPU, so the
    # two branches are merged).
    if args.device == "cuda" and torch.cuda.is_available():
        dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
    else:
        dtype = torch.float32

    model = AutoModelForCausalLM.from_pretrained(
        base_model,
        torch_dtype=dtype,
        trust_remote_code=True,
    )

    # BUG FIX: tokenizer.vocab_size does NOT count tokens added after loading
    # (the music special tokens), so resizing to it left the embedding matrix
    # too small and any added-token id would index out of range.
    # len(tokenizer) includes added tokens and is the documented way to size
    # embeddings after extending a tokenizer.
    model.resize_token_embeddings(len(tokenizer))

    print("Applying LoRA...")
    lora_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        r=args.lora_r,
        lora_alpha=args.lora_alpha,
        lora_dropout=0.1,
        # Attention projections only; MLP layers stay frozen.
        target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
        bias="none",
    )

    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()

    return model
|
|
|
|
|
def generate_synthetic_data(config: dict, args, tokenizer):
    """Generate a synthetic music-QA dataset and write train/val splits.

    Args:
        config: Model config dict (unused here; kept for a uniform signature).
        args: Parsed CLI namespace; reads .data_dir and .num_train_samples.
        tokenizer: Tokenizer handed to the chat formatter.

    Returns:
        The split descriptor returned by ChatFormatter.create_pretraining_dataset
        (exposes at least 'train' and 'val' entries).
    """
    # Imported lazily so the generators are only required when --generate_data is used.
    from data.music_qa_generator import MusicQAGenerator
    from data.chat_formatter import ChatFormatter

    print("Generating synthetic training data...")

    out_dir = Path(args.data_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    # Fixed seed keeps the synthetic corpus reproducible across runs.
    qa_generator = MusicQAGenerator(seed=42)
    raw_dataset = qa_generator.generate_dataset(
        num_samples=args.num_train_samples,
        output_path=out_dir / "synthetic_music_qa.jsonl",
    )

    formatter = ChatFormatter(tokenizer=tokenizer)
    # messages[1]/messages[2] are treated as the user question and assistant
    # answer (presumably index 0 is a system turn — verify against the generator).
    formatted_samples = [
        formatter.format_qa_pair(
            question=sample["messages"][1]["content"],
            answer=sample["messages"][2]["content"],
            context=None,
        )
        for sample in raw_dataset
    ]

    splits = formatter.create_pretraining_dataset(
        formatted_samples,
        output_dir=out_dir,
        train_split=0.9,
    )

    print(f"Data generation complete. Train: {splits['train']}, Val: {splits['val']}")

    return splits
|
|
|
|
|
def load_datasets(args, tokenizer):
    """Load the train/val JSONL shards as TouchGrassDataset instances.

    Args:
        args: Parsed CLI namespace; reads .data_dir.
        tokenizer: Tokenizer handed to both datasets.

    Returns:
        (train_dataset, val_dataset) tuple.

    Exits:
        Calls sys.exit(1) when either shard is missing.
    """
    data_dir = Path(args.data_dir)
    train_path = data_dir / "train.jsonl"
    val_path = data_dir / "val.jsonl"

    # Both shards must exist; a missing one means the user skipped generation.
    if not (train_path.exists() and val_path.exists()):
        print(f"Data not found in {data_dir}. Generate with --generate_data")
        sys.exit(1)

    print(f"Loading datasets from {data_dir}")

    def _make_dataset(path, mode):
        # The two splits differ only in source path and mode flag.
        return TouchGrassDataset(
            data_path=str(path),
            tokenizer=tokenizer,
            max_seq_length=4096,
            mode=mode,
        )

    return _make_dataset(train_path, "train"), _make_dataset(val_path, "eval")
|
|
|
|
|
def main():
    """CLI entry point: resolve configs, load data and model, train, save.

    Side effects: reads CLI args, may generate data on disk, trains the
    model, and writes the final checkpoint under --output_dir.
    """
    args = parse_args()

    # Pick model + training configs by size. .copy() keeps CLI overrides
    # from mutating the shared module-level config dicts.
    if args.model_size == "3b":
        model_config = TOUCHGRASS_3B_CONFIG.copy()
        train_config = TRAINING_CONFIG_3B_CUDA.copy()
    else:
        model_config = TOUCHGRASS_7B_CONFIG.copy()
        train_config = TRAINING_CONFIG_7B_CUDA.copy()

    # Apple Silicon path replaces the CUDA training config entirely.
    if args.use_mps or args.device == "mps":
        train_config = TRAINING_CONFIG_MPS.copy()
        train_config["use_mps"] = True
        # BUG FIX: with --use_mps alone, args.device stayed at its "cuda"
        # default, so torch.device below disagreed with the MPS config.
        args.device = "mps"

    # BUG FIX: compare to None, not truthiness, so an explicit `--max_steps 0`
    # or `--micro_batch_size 0` override is not silently ignored.
    if args.max_steps is not None:
        train_config["max_steps"] = args.max_steps
    if args.micro_batch_size is not None:
        train_config["micro_batch_size"] = args.micro_batch_size

    device = torch.device(args.device)
    train_config["device"] = args.device

    print(f"Training TouchGrass-{args.model_size.upper()}")
    print(f"Device: {device}")
    print(f"Max steps: {train_config['max_steps']}")
    print(f"Micro batch size: {train_config['micro_batch_size']}")
    print(f"LoRA: r={args.lora_r}, alpha={args.lora_alpha}")

    tokenizer_ext, tokenizer = load_tokenizer(model_config, args)

    if args.generate_data:
        generate_synthetic_data(model_config, args, tokenizer)

    train_dataset, val_dataset = load_datasets(args, tokenizer)
    print(f"Training samples: {len(train_dataset)}")
    print(f"Validation samples: {len(val_dataset)}")

    model = load_model(model_config, args, tokenizer)

    trainer = TouchGrassTrainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=train_dataset,
        config=train_config,
        eval_dataset=val_dataset,
    )

    if args.resume_from_checkpoint:
        trainer.load_checkpoint(args.resume_from_checkpoint)

    trainer.train()

    # BUG FIX: model_size is already "3b"/"7b", so the old
    # f"touchgrass-{args.model_size}b-final" produced "touchgrass-3bb-final".
    output_dir = Path(args.output_dir) / f"touchgrass-{args.model_size}-final"
    output_dir.mkdir(parents=True, exist_ok=True)

    print(f"\nSaving final model to {output_dir}")
    model.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)
    tokenizer_ext.save_pretrained(output_dir)

    print("Training complete! Model saved.")


if __name__ == "__main__":
    main()