# TouchGrass-7b / train.py
# (Hugging Face page residue kept as comments: uploaded by Zandy-Wandy,
#  "Upload 39 files", commit 4f0238f verified)
#!/usr/bin/env python3
"""
Main training entry point for TouchGrass models.
Fine-tunes Qwen3.5 with LoRA and music modules.
"""
import argparse
import sys
from pathlib import Path
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig, get_peft_model, TaskType
from configs.touchgrass_3b_config import TOUCHGRASS_3B_CONFIG
from configs.touchgrass_7b_config import TOUCHGRASS_7B_CONFIG
from configs.training_config import (
TRAINING_CONFIG_3B_CUDA,
TRAINING_CONFIG_7B_CUDA,
TRAINING_CONFIG_MPS,
)
from data.dataset_loader import TouchGrassDataset
from training.trainer import TouchGrassTrainer
from tokenizer.music_token_extension import MusicTokenizerExtension
def parse_args(argv=None):
    """Parse command-line options for TouchGrass training.

    Args:
        argv: Optional list of argument strings. Defaults to ``None``,
            in which case argparse reads ``sys.argv`` — identical to the
            previous zero-argument behavior. Passing a list makes the
            function testable and scriptable.

    Returns:
        argparse.Namespace with all training options.
    """
    parser = argparse.ArgumentParser(description="Train TouchGrass music assistant model")
    parser.add_argument(
        "--model_size",
        type=str,
        choices=["3b", "7b"],
        default="3b",
        help="Model size to train",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda",
        choices=["cuda", "mps", "cpu"],
        help="Device to train on",
    )
    parser.add_argument(
        "--use_mps",
        action="store_true",
        help="Use MPS backend (Apple Silicon)",
    )
    parser.add_argument(
        "--data_dir",
        type=str,
        default="./data/processed",
        help="Directory with processed data shards",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="./checkpoints",
        help="Output directory for checkpoints",
    )
    parser.add_argument(
        "--max_steps",
        type=int,
        default=None,
        help="Override max training steps",
    )
    parser.add_argument(
        "--micro_batch_size",
        type=int,
        default=None,
        help="Override micro batch size",
    )
    parser.add_argument(
        "--lora_r",
        type=int,
        default=16,
        help="LoRA rank",
    )
    parser.add_argument(
        "--lora_alpha",
        type=int,
        default=32,
        help="LoRA alpha",
    )
    parser.add_argument(
        "--resume_from_checkpoint",
        type=str,
        default=None,
        help="Resume training from checkpoint",
    )
    parser.add_argument(
        "--generate_data",
        action="store_true",
        help="Generate synthetic training data before training",
    )
    parser.add_argument(
        "--num_train_samples",
        type=int,
        default=10000,
        help="Number of training samples to generate",
    )
    return parser.parse_args(argv)
def load_tokenizer(config: dict, args):
    """Load the base tokenizer and extend it with music tokens.

    Args:
        config: Model config dict; must contain "base_model" and may
            contain "special_tokens" for the music-token extension.
        args: Parsed CLI args (unused here; kept so all loaders share
            a uniform ``(config, args, ...)`` signature).

    Returns:
        Tuple of (MusicTokenizerExtension, extended HF tokenizer).
    """
    base_model = config["base_model"]
    print(f"Loading base tokenizer: {base_model}")
    # Extend tokenizer with music tokens
    tokenizer_ext = MusicTokenizerExtension(
        base_tokenizer_name=base_model,
        special_tokens=config.get("special_tokens"),
    )
    tokenizer = tokenizer_ext.get_tokenizer()
    # BUG FIX: HF's .vocab_size excludes tokens added on top of the base
    # vocab, so it under-reports the extended size; len(tokenizer) counts
    # the added music tokens too.
    print(f"Extended tokenizer vocab size: {len(tokenizer)}")
    return tokenizer_ext, tokenizer
def load_model(config: dict, args, tokenizer):
    """Load the base model, resize embeddings, and apply LoRA adapters.

    Args:
        config: Model config dict containing "base_model".
        args: Parsed CLI args; reads ``device``, ``lora_r``, ``lora_alpha``.
        tokenizer: Extended tokenizer (base vocab + added music tokens).

    Returns:
        PEFT-wrapped causal LM with LoRA on the attention projections.
    """
    base_model = config["base_model"]
    print(f"Loading base model: {base_model}")

    # Pick the widest dtype the backend handles well.
    if args.device == "cuda" and torch.cuda.is_available():
        dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
    elif args.device == "mps":
        dtype = torch.float32  # MPS doesn't support bf16 well
    else:
        dtype = torch.float32

    # Load model
    model = AutoModelForCausalLM.from_pretrained(
        base_model,
        torch_dtype=dtype,
        trust_remote_code=True,
    )

    # BUG FIX: tokenizer.vocab_size excludes tokens added on top of the
    # base vocab, so resizing to it would fail to allocate embedding rows
    # for the new music tokens. len(tokenizer) includes added tokens.
    model.resize_token_embeddings(len(tokenizer))

    # Apply LoRA to the attention projections only.
    print("Applying LoRA...")
    lora_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        r=args.lora_r,
        lora_alpha=args.lora_alpha,
        lora_dropout=0.1,
        target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
        bias="none",
    )
    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()
    return model
def generate_synthetic_data(config: dict, args, tokenizer):
    """Generate synthetic music-QA training data and write train/val shards.

    Returns the split metadata produced by the chat formatter.
    """
    # Local imports: data generation is optional, so keep these off the
    # module's critical import path.
    from data.music_qa_generator import MusicQAGenerator
    from data.chat_formatter import ChatFormatter

    print("Generating synthetic training data...")

    # Fixed seed keeps the generated corpus reproducible across runs.
    qa_generator = MusicQAGenerator(seed=42)

    data_root = Path(args.data_dir)
    data_root.mkdir(parents=True, exist_ok=True)

    raw_dataset = qa_generator.generate_dataset(
        num_samples=args.num_train_samples,
        output_path=data_root / "synthetic_music_qa.jsonl",
    )

    # Re-shape each generated record into the chat training format.
    # messages[1] is the user turn, messages[2] the assistant answer.
    formatter = ChatFormatter(tokenizer=tokenizer)
    formatted_samples = [
        formatter.format_qa_pair(
            question=record["messages"][1]["content"],
            answer=record["messages"][2]["content"],
            context=None,  # Context already in question
        )
        for record in raw_dataset
    ]

    # Write train/val shards to disk with a 90/10 split.
    splits = formatter.create_pretraining_dataset(
        formatted_samples,
        output_dir=data_root,
        train_split=0.9,
    )
    print(f"Data generation complete. Train: {splits['train']}, Val: {splits['val']}")
    return splits
def load_datasets(args, tokenizer):
    """Load the train/val JSONL shards as TouchGrass datasets.

    Exits the process with status 1 when either shard is missing,
    pointing the user at the --generate_data flag.
    """
    data_dir = Path(args.data_dir)
    train_path = data_dir / "train.jsonl"
    val_path = data_dir / "val.jsonl"

    # Both shards must exist; data generation is a separate explicit step.
    if not (train_path.exists() and val_path.exists()):
        print(f"Data not found in {data_dir}. Generate with --generate_data")
        sys.exit(1)

    print(f"Loading datasets from {data_dir}")
    train_dataset, val_dataset = (
        TouchGrassDataset(
            data_path=str(shard),
            tokenizer=tokenizer,
            max_seq_length=4096,
            mode=mode,
        )
        for shard, mode in ((train_path, "train"), (val_path, "eval"))
    )
    return train_dataset, val_dataset
def main():
    """CLI entry point: configure, load, train, and save a TouchGrass model."""
    args = parse_args()

    # Select model/training configs by size; copy so CLI overrides below
    # don't mutate the shared module-level config dicts.
    if args.model_size == "3b":
        model_config = TOUCHGRASS_3B_CONFIG.copy()
        train_config = TRAINING_CONFIG_3B_CUDA.copy()
    else:
        model_config = TOUCHGRASS_7B_CONFIG.copy()
        train_config = TRAINING_CONFIG_7B_CUDA.copy()

    # MPS (Apple Silicon) replaces the CUDA training config entirely.
    if args.use_mps or args.device == "mps":
        train_config = TRAINING_CONFIG_MPS.copy()
        train_config["use_mps"] = True

    # Apply CLI overrides. BUG FIX: compare against None rather than
    # truthiness, so an explicit 0 override is not silently ignored.
    if args.max_steps is not None:
        train_config["max_steps"] = args.max_steps
    if args.micro_batch_size is not None:
        train_config["micro_batch_size"] = args.micro_batch_size

    # Set device
    device = torch.device(args.device)
    train_config["device"] = args.device

    print(f"Training TouchGrass-{args.model_size.upper()}")
    print(f"Device: {device}")
    print(f"Max steps: {train_config['max_steps']}")
    print(f"Micro batch size: {train_config['micro_batch_size']}")
    print(f"LoRA: r={args.lora_r}, alpha={args.lora_alpha}")

    # Load tokenizer (base + music-token extension).
    tokenizer_ext, tokenizer = load_tokenizer(model_config, args)

    # Generate data if requested
    if args.generate_data:
        generate_synthetic_data(model_config, args, tokenizer)

    # Load datasets
    train_dataset, val_dataset = load_datasets(args, tokenizer)
    print(f"Training samples: {len(train_dataset)}")
    print(f"Validation samples: {len(val_dataset)}")

    # Load model with LoRA
    model = load_model(model_config, args, tokenizer)

    # Create trainer
    trainer = TouchGrassTrainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=train_dataset,
        config=train_config,
        eval_dataset=val_dataset,
    )

    # Resume from checkpoint if specified
    if args.resume_from_checkpoint:
        trainer.load_checkpoint(args.resume_from_checkpoint)

    # Train
    trainer.train()

    # BUG FIX: args.model_size is already "3b"/"7b"; the old
    # f"...{args.model_size}b-final" produced names like
    # "touchgrass-3bb-final".
    output_dir = Path(args.output_dir) / f"touchgrass-{args.model_size}-final"
    output_dir.mkdir(parents=True, exist_ok=True)
    print(f"\nSaving final model to {output_dir}")
    model.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)
    # Persist music-token extension metadata alongside model/tokenizer
    # so inference can rebuild the extended vocabulary.
    tokenizer_ext.save_pretrained(output_dir)
    print("Training complete! Model saved.")
# Standard script entry point: run training only when executed directly,
# not when imported as a module.
if __name__ == "__main__":
    main()