"""Training script for the Zenith-7B model.

Fine-tunes on OpenThoughts3-1.2M plus custom datasets for code generation and
emotional intelligence (EQ).
"""

import argparse
import logging
import os
import sys
from pathlib import Path
from types import SimpleNamespace

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

sys.path.append(str(Path(__file__).parent))

from configs.zenith_config import get_7b_config, DataConfig, TrainingConfig, TrainerConfig
from data.openthoughts_processor import OpenThoughtsConfig, OpenThoughtsProcessor, QualityFilter, CurriculumSampler
# apply_lora is assumed to be exported by models.zenith_model alongside the adapter classes.
from models.zenith_model import ZenithForCausalLM, LoRAAdapter, QLoRAAdapter, apply_lora
from training.trainer import train_zenith_model, Trainer
from utils.checkpoint import setup_logging

logger = logging.getLogger(__name__)


def parse_args():
    parser = argparse.ArgumentParser(description="Train Zenith-7B model")

    # Paths
    parser.add_argument("--output_dir", type=str, default="./outputs/zenith-7b", help="Output directory")
    parser.add_argument("--data_dir", type=str, default="./data", help="Data directory")
    parser.add_argument("--cache_dir", type=str, default="./cache", help="Cache directory")
    parser.add_argument("--log_dir", type=str, default="./logs", help="Log directory")

    # Model
    parser.add_argument("--base_model", type=str, default="meta-llama/Llama-2-7b-hf", help="Base model to fine-tune")
    parser.add_argument("--use_lora", action="store_true", help="Use LoRA for efficient fine-tuning")
    parser.add_argument("--lora_rank", type=int, default=16, help="LoRA rank")
    parser.add_argument("--lora_alpha", type=int, default=32, help="LoRA alpha")
    parser.add_argument("--use_qlora", action="store_true", help="Use QLoRA (4-bit quantization)")

    # Data
    parser.add_argument("--openthoughts_dataset", type=str, default="open-thoughts/OpenThoughts3-1.2M", help="OpenThoughts dataset")
    parser.add_argument("--custom_datasets", type=str, nargs="+", default=[], help="Custom dataset paths")
    parser.add_argument("--max_seq_length", type=int, default=8192, help="Maximum sequence length")
    parser.add_argument("--train_batch_size", type=int, default=4, help="Training batch size")
    parser.add_argument("--gradient_accumulation_steps", type=int, default=8, help="Gradient accumulation steps")
    parser.add_argument("--effective_batch_size", type=int, default=None, help="Effective batch size; if set, overrides --gradient_accumulation_steps")

    # Optimization
    parser.add_argument("--learning_rate", type=float, default=2e-4, help="Learning rate")
    parser.add_argument("--num_train_epochs", type=int, default=3, help="Number of training epochs")
    parser.add_argument("--max_steps", type=int, default=-1, help="Maximum training steps (-1 to train by epochs)")
    parser.add_argument("--warmup_steps", type=int, default=1000, help="Warmup steps")
    parser.add_argument("--weight_decay", type=float, default=0.1, help="Weight decay")
    parser.add_argument("--clip_grad_norm", type=float, default=1.0, help="Gradient clipping norm")

    # Data processing and precision
    parser.add_argument("--use_curriculum", action="store_true", help="Enable curriculum learning")
    parser.add_argument("--use_quality_filter", action="store_true", help="Enable quality filtering")
    parser.add_argument("--use_augmentation", action="store_true", help="Enable data augmentation")
    parser.add_argument("--mixed_precision", type=str, default="bf16", choices=["no", "fp16", "bf16"], help="Mixed precision mode")
    parser.add_argument("--seed", type=int, default=42, help="Random seed")

    # Logging and checkpointing
    parser.add_argument("--logging_steps", type=int, default=10, help="Logging interval (steps)")
    parser.add_argument("--eval_steps", type=int, default=500, help="Evaluation interval (steps)")
    parser.add_argument("--save_steps", type=int, default=1000, help="Checkpoint save interval (steps)")
    parser.add_argument("--report_to", type=str, nargs="+", default=["tensorboard", "wandb"], help="Reporting platforms")

    parser.add_argument("--resume_from_checkpoint", type=str, default=None, help="Path to a checkpoint to resume from")

    return parser.parse_args()


def main():
    args = parse_args()

    # Logging
    setup_logging(log_dir=args.log_dir)
    logger.info("Starting Zenith-7B training")
    logger.info(f"Arguments: {args}")

    # Reproducibility
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)

    # Output directories
    os.makedirs(args.output_dir, exist_ok=True)
    os.makedirs(args.cache_dir, exist_ok=True)

    # Tokenizer
    logger.info(f"Loading tokenizer: {args.base_model}")
    tokenizer = AutoTokenizer.from_pretrained(
        args.base_model,
        cache_dir=args.cache_dir,
        use_fast=True,
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Base model
    logger.info(f"Loading base model: {args.base_model}")
    if args.mixed_precision == "bf16":
        torch_dtype = torch.bfloat16
    elif args.mixed_precision == "fp16":
        torch_dtype = torch.float16
    else:
        torch_dtype = torch.float32

    model_kwargs = {
        "cache_dir": args.cache_dir,
        "torch_dtype": torch_dtype,
        "device_map": "auto" if torch.cuda.is_available() else None,
    }

    if args.use_qlora:
        # QLoRA: load the base model in 4-bit NF4 with double quantization.
        model_kwargs["quantization_config"] = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_use_double_quant=True,
        )

    base_model = AutoModelForCausalLM.from_pretrained(args.base_model, **model_kwargs)

    # Optionally wrap the base model with LoRA adapters
    if args.use_lora or args.use_qlora:
        logger.info("Applying LoRA adapters...")
        lora_config = LoRAAdapter(
            r=args.lora_rank,
            lora_alpha=args.lora_alpha,
            target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
            lora_dropout=0.05,
            bias="none",
        )
        base_model = apply_lora(base_model, lora_config)

    # Build the Zenith model around the (possibly adapted) base model
    config = get_7b_config()
    model = ZenithForCausalLM(config, base_model=base_model)

    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    logger.info(f"Model initialized: {trainable_params / 1e9:.2f}B trainable parameters")

    data_config = DataConfig(
        openthoughts_dataset=args.openthoughts_dataset,
        custom_datasets=args.custom_datasets,
        tokenizer_name=args.base_model,
        max_seq_length=args.max_seq_length,
        use_curriculum=args.use_curriculum,
        use_augmentation=args.use_augmentation,
        cache_dir=args.cache_dir,
    )

    quality_filter = QualityFilter() if args.use_quality_filter else None
    data_config.quality_filter = quality_filter

    # Derive gradient accumulation from the requested effective batch size, if given
    if args.effective_batch_size is not None:
        gradient_accumulation_steps = max(1, args.effective_batch_size // args.train_batch_size)
    else:
        gradient_accumulation_steps = args.gradient_accumulation_steps

    training_config = TrainingConfig(
        train_batch_size=args.train_batch_size,
        gradient_accumulation_steps=gradient_accumulation_steps,
        learning_rate=args.learning_rate,
        num_train_epochs=args.num_train_epochs,
        max_steps=args.max_steps,
        save_steps=args.save_steps,
        eval_steps=args.eval_steps,
        logging_steps=args.logging_steps,
        optimizer=SimpleNamespace(
            type="adamw",
            learning_rate=args.learning_rate,
            weight_decay=args.weight_decay,
            clip_grad_norm=args.clip_grad_norm,
        ),
        scheduler=SimpleNamespace(
            type="cosine",
            warmup_steps=args.warmup_steps,
        ),
        mixed_precision=args.mixed_precision,
        gradient_ckpt=True,
        report_to=args.report_to,
        seed=args.seed,
        resume_from_checkpoint=args.resume_from_checkpoint,
    )

    trainer_config = TrainerConfig(
        model_config=config,
        data_config=data_config,
        training_config=training_config,
        output_dir=args.output_dir,
        logging_dir=args.log_dir,
        checkpoint_dir=f"{args.output_dir}/checkpoints",
        gradient_accumulation_steps=gradient_accumulation_steps,
        use_amp=args.mixed_precision != "no",
        log_interval=args.logging_steps,
        eval_interval=args.eval_steps,
        save_interval=args.save_steps,
        resume_from_checkpoint=args.resume_from_checkpoint,
    )

    # Load and preprocess the OpenThoughts dataset
    logger.info("Loading OpenThoughts dataset...")
    openthoughts_config = OpenThoughtsConfig(
        dataset_name=args.openthoughts_dataset,
        cache_dir=args.cache_dir,
        quality_filter=quality_filter,
        use_curriculum=args.use_curriculum,
        use_augmentation=args.use_augmentation,
        max_seq_length=args.max_seq_length,
        tokenizer=tokenizer,
    )

    processor = OpenThoughtsProcessor(openthoughts_config)
    dataset = processor.load_dataset()

    # Hold out 5% of the data for validation
    logger.info("Splitting dataset...")
    split_dataset = dataset.train_test_split(test_size=0.05, seed=args.seed)
    train_dataset = split_dataset["train"]
    val_dataset = split_dataset["test"]

    logger.info(f"Train samples: {len(train_dataset)}")
    logger.info(f"Val samples: {len(val_dataset)}")

    # Optional curriculum learning
    if args.use_curriculum:
        # Imported via the path added to sys.path above; a package-relative
        # import would fail when this file is run as a script.
        from data import create_curriculum_sampler

        curriculum_sampler = create_curriculum_sampler(
            train_dataset,
            data_config.curriculum,
            current_epoch=0,
            seed=args.seed,
        )
        if curriculum_sampler is not None:
            # Placeholder hook: the sampler is not yet wired into the trainer here.
            logger.info("Curriculum sampler created")

    logger.info("Starting training...")
    trainer = train_zenith_model(
        model=model,
        tokenizer=tokenizer,
        config=trainer_config,
        train_dataset=train_dataset,
        val_dataset=val_dataset,
    )

    logger.info("Training complete!")
    logger.info(f"Model saved to {args.output_dir}")

    model.save_pretrained(f"{args.output_dir}/final")
    tokenizer.save_pretrained(f"{args.output_dir}/final")


if __name__ == "__main__":
    main()