#!/usr/bin/env python3
"""
Training script for Zenith-7B model.
Fine-tunes on OpenThoughts3-1.2M plus optional custom datasets for code generation and EQ.
"""
import argparse
import logging
import os
import sys
from pathlib import Path
from types import SimpleNamespace
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
# Add current directory to path for imports
sys.path.append(str(Path(__file__).parent))
from configs.zenith_config import get_7b_config, DataConfig, TrainingConfig, TrainerConfig
from data.openthoughts_processor import OpenThoughtsConfig, OpenThoughtsProcessor, QualityFilter, CurriculumSampler
# NOTE: apply_lora is assumed to be exported by models.zenith_model alongside LoRAAdapter.
from models.zenith_model import ZenithForCausalLM, LoRAAdapter, QLoRAAdapter, apply_lora
from training.trainer import train_zenith_model, Trainer
from utils.checkpoint import setup_logging
logger = logging.getLogger(__name__)
def parse_args():
parser = argparse.ArgumentParser(description="Train Zenith-7B model")
parser.add_argument("--output_dir", type=str, default="./outputs/zenith-7b", help="Output directory")
parser.add_argument("--data_dir", type=str, default="./data", help="Data directory")
parser.add_argument("--cache_dir", type=str, default="./cache", help="Cache directory")
parser.add_argument("--log_dir", type=str, default="./logs", help="Log directory")
# Model
parser.add_argument("--base_model", type=str, default="meta-llama/Llama-2-7b-hf", help="Base model to fine-tune")
parser.add_argument("--use_lora", action="store_true", help="Use LoRA for efficient fine-tuning")
parser.add_argument("--lora_rank", type=int, default=16, help="LoRA rank")
parser.add_argument("--lora_alpha", type=int, default=32, help="LoRA alpha")
parser.add_argument("--use_qlora", action="store_true", help="Use QLoRA (4-bit quantization)")
# Data
parser.add_argument("--openthoughts_dataset", type=str, default="open-thoughts/OpenThoughts3-1.2M", help="OpenThoughts dataset")
parser.add_argument("--custom_datasets", type=str, nargs="+", default=[], help="Custom dataset paths")
parser.add_argument("--max_seq_length", type=int, default=8192, help="Maximum sequence length")
parser.add_argument("--train_batch_size", type=int, default=4, help="Training batch size")
parser.add_argument("--gradient_accumulation_steps", type=int, default=8, help="Gradient accumulation steps")
parser.add_argument("--effective_batch_size", type=int, default=32, help="Effective batch size (overrides gradient_accumulation if set)")
# Training
parser.add_argument("--learning_rate", type=float, default=2e-4, help="Learning rate")
parser.add_argument("--num_train_epochs", type=int, default=3, help="Number of training epochs")
parser.add_argument("--max_steps", type=int, default=-1, help="Maximum training steps (-1 for epochs)")
parser.add_argument("--warmup_steps", type=int, default=1000, help="Warmup steps")
parser.add_argument("--weight_decay", type=float, default=0.1, help="Weight decay")
parser.add_argument("--clip_grad_norm", type=float, default=1.0, help="Gradient clipping norm")
# Advanced
parser.add_argument("--use_curriculum", action="store_true", help="Enable curriculum learning")
parser.add_argument("--use_quality_filter", action="store_true", help="Enable quality filtering")
parser.add_argument("--use_augmentation", action="store_true", help="Enable data augmentation")
parser.add_argument("--mixed_precision", type=str, default="bf16", choices=["no", "fp16", "bf16"], help="Mixed precision")
parser.add_argument("--seed", type=int, default=42, help="Random seed")
# Logging
parser.add_argument("--logging_steps", type=int, default=10, help="Logging steps")
parser.add_argument("--eval_steps", type=int, default=500, help="Evaluation steps")
parser.add_argument("--save_steps", type=int, default=1000, help="Save checkpoint steps")
parser.add_argument("--report_to", type=str, nargs="+", default=["tensorboard", "wandb"], help="Reporting platforms")
# Resume
parser.add_argument("--resume_from_checkpoint", type=str, default=None, help="Resume from checkpoint")
return parser.parse_args()
def main():
args = parse_args()
# Setup logging
setup_logging(log_dir=args.log_dir)
logger.info("Starting Zenith-7B training")
logger.info(f"Arguments: {args}")
# Set seed
torch.manual_seed(args.seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(args.seed)
# Create output directories
os.makedirs(args.output_dir, exist_ok=True)
os.makedirs(args.cache_dir, exist_ok=True)
# Load tokenizer
logger.info(f"Loading tokenizer: {args.base_model}")
tokenizer = AutoTokenizer.from_pretrained(
args.base_model,
cache_dir=args.cache_dir,
use_fast=True,
)
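    # Llama tokenizers ship without a pad token; reusing EOS is the common
    # workaround (padding positions should then be masked out of the loss).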
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
# Load base model
logger.info(f"Loading base model: {args.base_model}")
    dtype_by_precision = {"bf16": torch.bfloat16, "fp16": torch.float16, "no": torch.float32}
    model_kwargs = {
        "cache_dir": args.cache_dir,
        "torch_dtype": dtype_by_precision[args.mixed_precision],
        "device_map": "auto" if torch.cuda.is_available() else None,
    }
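    # With device_map="auto" above, accelerate places (and shards, if needed)
    # the model across available GPUs; CPU-only runs use default placement.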
    if args.use_qlora:
        model_kwargs["quantization_config"] = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_use_double_quant=True,
        )
base_model = AutoModelForCausalLM.from_pretrained(args.base_model, **model_kwargs)
# Apply LoRA if requested
if args.use_lora or args.use_qlora:
logger.info("Applying LoRA adapters...")
lora_config = LoRAAdapter(
r=args.lora_rank,
lora_alpha=args.lora_alpha,
target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
lora_dropout=0.05,
bias="none",
)
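    # r=16 with alpha=32 gives an adapter scaling of alpha/r = 2; targeting
    # all attention and MLP projections follows the QLoRA recipe.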
base_model = apply_lora(base_model, lora_config)
# Create Zenith model
config = get_7b_config()
model = ZenithForCausalLM(config, base_model=base_model)
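    # With LoRA adapters active, only a small fraction of weights (typically
    # well under 1% of the 7B) should show up as trainable in the count below.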
logger.info(f"Model initialized: {sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e9:.2f}B trainable parameters")
# Data configuration
data_config = DataConfig(
openthoughts_dataset=args.openthoughts_dataset,
custom_datasets=args.custom_datasets,
tokenizer_name=args.base_model,
max_seq_length=args.max_seq_length,
use_curriculum=args.use_curriculum,
use_augmentation=args.use_augmentation,
cache_dir=args.cache_dir,
)
# Quality filter
quality_filter = QualityFilter() if args.use_quality_filter else None
data_config.quality_filter = quality_filter
# Training configuration
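    # Effective batch size = train_batch_size * gradient_accumulation_steps
    # (e.g. 4 * 8 = 32 with the defaults); when --effective_batch_size is set,
    # back the accumulation steps out of the requested total instead.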
    if args.effective_batch_size is not None:
        if args.effective_batch_size % args.train_batch_size != 0:
            raise ValueError("--effective_batch_size must be a multiple of --train_batch_size")
        gradient_accumulation_steps = args.effective_batch_size // args.train_batch_size
    else:
        gradient_accumulation_steps = args.gradient_accumulation_steps
training_config = TrainingConfig(
train_batch_size=args.train_batch_size,
gradient_accumulation_steps=gradient_accumulation_steps,
learning_rate=args.learning_rate,
num_train_epochs=args.num_train_epochs,
max_steps=args.max_steps,
save_steps=args.save_steps,
eval_steps=args.eval_steps,
logging_steps=args.logging_steps,
        optimizer=SimpleNamespace(
            type="adamw",
            learning_rate=args.learning_rate,
            weight_decay=args.weight_decay,
            clip_grad_norm=args.clip_grad_norm,
        ),
        scheduler=SimpleNamespace(
            type="cosine",
            warmup_steps=args.warmup_steps,
        ),
mixed_precision=args.mixed_precision,
gradient_ckpt=True,
report_to=args.report_to,
seed=args.seed,
resume_from_checkpoint=args.resume_from_checkpoint,
)
# Trainer configuration
trainer_config = TrainerConfig(
model_config=config,
data_config=data_config,
training_config=training_config,
output_dir=args.output_dir,
logging_dir=args.log_dir,
checkpoint_dir=f"{args.output_dir}/checkpoints",
gradient_accumulation_steps=gradient_accumulation_steps,
use_amp=args.mixed_precision != "no",
log_interval=args.logging_steps,
eval_interval=args.eval_steps,
save_interval=args.save_steps,
resume_from_checkpoint=args.resume_from_checkpoint,
)
# Load dataset
logger.info("Loading OpenThoughts dataset...")
openthoughts_config = OpenThoughtsConfig(
dataset_name=args.openthoughts_dataset,
cache_dir=args.cache_dir,
quality_filter=quality_filter,
use_curriculum=args.use_curriculum,
use_augmentation=args.use_augmentation,
max_seq_length=args.max_seq_length,
tokenizer=tokenizer,
)
processor = OpenThoughtsProcessor(openthoughts_config)
dataset = processor.load_dataset()
# Split dataset
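    # Hold out 5% for validation; the fixed seed keeps the split reproducible
    # across resumed runs.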
logger.info("Splitting dataset...")
split_dataset = dataset.train_test_split(test_size=0.05, seed=args.seed)
train_dataset = split_dataset["train"]
val_dataset = split_dataset["test"]
logger.info(f"Train samples: {len(train_dataset)}")
logger.info(f"Val samples: {len(val_dataset)}")
# Create curriculum sampler if needed
if args.use_curriculum:
        # Absolute import; relative imports fail when this file runs as a
        # script, and the script directory is already on sys.path.
        from data import create_curriculum_sampler
curriculum_sampler = create_curriculum_sampler(
train_dataset,
data_config.curriculum,
current_epoch=0,
seed=args.seed,
)
        if curriculum_sampler is not None:
            # Attach the sampler to the data config, mirroring how the quality
            # filter is attached above; the trainer is assumed to pick it up
            # when building its dataloaders.
            data_config.curriculum_sampler = curriculum_sampler
# Train
logger.info("Starting training...")
trainer = train_zenith_model(
model=model,
tokenizer=tokenizer,
config=trainer_config,
train_dataset=train_dataset,
val_dataset=val_dataset,
)
logger.info("Training complete!")
logger.info(f"Model saved to {args.output_dir}")
# Save final model
model.save_pretrained(f"{args.output_dir}/final")
tokenizer.save_pretrained(f"{args.output_dir}/final")
if __name__ == "__main__":
main()