| """
|
| Trainer: Main training loop for Vortex model.
|
| Handles gradient accumulation, mixed precision, checkpointing.
|
| """
|
|
|
import json
import logging
import os
from pathlib import Path
from typing import Callable, Dict, List, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset

from ..training.losses import VortexLoss
from ..training.curriculum import CurriculumScheduler
|
|
|
|
|
class VortexDataset(Dataset):
    """Thin Dataset wrapper that loads text samples from parquet shards."""

    def __init__(
        self,
        shard_files: List[str],
        tokenizer,
        max_seq_len: int = 16384,
    ):
        """
        Initialize dataset.

        Args:
            shard_files: List of parquet shard files
            tokenizer: Tokenizer for encoding text
            max_seq_len: Maximum sequence length
        """
        self.shard_files = shard_files
        self.tokenizer = tokenizer
        self.max_seq_len = max_seq_len

        self.samples = []
        self._load_shards()

    def _load_shards(self):
        """Read every shard into memory as a flat list of sample dicts."""
        import pandas as pd

        for path in self.shard_files:
            frame = pd.read_parquet(path)
            for _, record in frame.iterrows():
                entry = {
                    "text": record["text"],
                    "dataset": record.get("dataset", ""),
                    "domain": record.get("domain", ""),
                }
                self.samples.append(entry)

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, idx) -> Dict:
        entry = self.samples[idx]

        tokenized = self.tokenizer.encode(
            entry["text"],
            add_special_tokens=True,
            return_tensors="pt",
        )

        ids = tokenized["input_ids"].squeeze(0)
        mask = tokenized["attention_mask"].squeeze(0)

        # Hard-truncate anything longer than the configured context window.
        limit = self.max_seq_len
        if len(ids) > limit:
            ids = ids[:limit]
            mask = mask[:limit]

        return {
            "input_ids": ids,
            "attention_mask": mask,
            # Standard causal-LM setup: labels are a copy of the inputs.
            "labels": ids.clone(),
            "domain": entry["domain"],
        }
|
|
|
|
|
class VortexTrainer:
    """
    Main trainer for Vortex model.

    Handles gradient accumulation, optional mixed precision (AMP),
    gradient clipping, LR scheduling, periodic evaluation, and
    checkpointing.
    """

    def __init__(
        self,
        model: nn.Module,
        tokenizer,
        train_dataset: Dataset,
        config: Dict,
        eval_dataset: Optional[Dataset] = None,
        optimizer: Optional[torch.optim.Optimizer] = None,
        scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
    ):
        """
        Initialize trainer.

        Args:
            model: VortexModel
            tokenizer: VortexScienceTokenizer
            train_dataset: Training dataset
            config: Training configuration
            eval_dataset: Optional evaluation dataset
            optimizer: Optional optimizer (created if None)
            scheduler: Optional LR scheduler (cosine annealing if None)
        """
        self.model = model
        self.tokenizer = tokenizer
        self.train_dataset = train_dataset
        self.eval_dataset = eval_dataset
        self.config = config

        self.device = torch.device(config["device"])
        self.use_amp = config.get("use_amp", True)
        # Resolve e.g. "bfloat16" -> torch.bfloat16; used by autocast below.
        self.amp_dtype = getattr(torch, config.get("amp_dtype", "bfloat16"))

        self.model.to(self.device)

        if optimizer is None:
            self.optimizer = self._create_optimizer()
        else:
            self.optimizer = optimizer

        if scheduler is None:
            self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                self.optimizer,
                T_max=config["max_steps"],
            )
        else:
            self.scheduler = scheduler

        # Loss scaling is only needed for AMP on CUDA; None disables it.
        self.scaler = torch.cuda.amp.GradScaler() if self.use_amp and self.device.type == "cuda" else None

        self.loss_fn = VortexLoss(config)
        self.curriculum = CurriculumScheduler(config, config["max_steps"])

        # Logging setup.
        self.log_dir = Path(config.get("log_dir", "logs"))
        self.log_dir.mkdir(parents=True, exist_ok=True)
        self.log_interval = config.get("log_interval", 100)

        # Checkpointing setup.
        self.checkpoint_dir = Path(config.get("checkpoint_dir", "checkpoints"))
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
        self.save_interval = config.get("save_interval", 5000)

        self.global_step = 0
        self.best_eval_loss = float('inf')

        num_workers = config.get("num_workers", 4)
        loader_kwargs = {
            "batch_size": config["micro_batch_size"],
            "num_workers": num_workers,
            "pin_memory": config.get("pin_memory", True),
        }
        # prefetch_factor is only valid with worker processes; passing it
        # with num_workers=0 raises a ValueError in torch's DataLoader.
        if num_workers > 0:
            loader_kwargs["prefetch_factor"] = config.get("prefetch_factor", 2)

        self.train_loader = DataLoader(
            train_dataset,
            shuffle=True,
            **loader_kwargs,
        )

        if eval_dataset:
            self.eval_loader = DataLoader(
                eval_dataset,
                batch_size=config["micro_batch_size"],
                shuffle=False,
                num_workers=num_workers,
            )

    def _create_optimizer(self) -> torch.optim.Optimizer:
        """Create AdamW optimizer from config hyperparameters."""
        return torch.optim.AdamW(
            self.model.parameters(),
            lr=self.config["learning_rate"],
            betas=(self.config["beta1"], self.config["beta2"]),
            weight_decay=self.config["weight_decay"],
        )

    def train_step(
        self,
        batch: Dict,
        current_step: int,
    ) -> Dict[str, torch.Tensor]:
        """
        Single forward/backward micro-step.

        The loss is divided by gradient_accumulation_steps before backward
        so the accumulated gradient equals a full-batch average; the
        returned loss values are left unscaled for logging.

        Args:
            batch: Batch dictionary
            current_step: Current step number

        Returns:
            Dictionary of (unscaled) losses
        """
        self.model.train()

        input_ids = batch["input_ids"].to(self.device)
        attention_mask = batch["attention_mask"].to(self.device)
        labels = batch["labels"].to(self.device)

        # Domain conditioning is not wired up yet.
        domain_ids = None
        domain_tags = None

        amp_enabled = self.use_amp and self.device.type == "cuda"
        # Pass the configured dtype — previously amp_dtype was read from
        # config but never applied, so autocast silently used its default.
        with torch.cuda.amp.autocast(enabled=amp_enabled, dtype=self.amp_dtype):
            outputs = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                domain_ids=domain_ids,
                domain_tags=domain_tags,
                return_dict=True,
            )
            logits = outputs["logits"]

            losses = self.loss_fn(
                logits=logits,
                labels=labels,
            )

        # Scale for gradient accumulation so N accumulated micro-batches
        # produce the same gradient magnitude as one large batch.
        accum_steps = self.config.get("gradient_accumulation_steps", 1)
        scaled_loss = losses["total_loss"] / accum_steps

        if self.scaler:
            self.scaler.scale(scaled_loss).backward()
        else:
            scaled_loss.backward()

        return losses

    def train_epoch(self):
        """Train for one epoch, stopping early once max_steps is reached."""
        self.model.train()

        for batch_idx, batch in enumerate(self.train_loader):

            losses = self.train_step(batch, self.global_step)

            # Only step the optimizer every gradient_accumulation_steps
            # micro-batches; gradients accumulate in between.
            if (self.global_step + 1) % self.config["gradient_accumulation_steps"] == 0:

                if self.config.get("clip_grad_norm", 0) > 0:
                    if self.scaler:
                        # Unscale first so the clip threshold applies to
                        # true gradient magnitudes, not scaled ones.
                        self.scaler.unscale_(self.optimizer)
                    torch.nn.utils.clip_grad_norm_(
                        self.model.parameters(),
                        self.config["clip_grad_norm"],
                    )

                if self.scaler:
                    self.scaler.step(self.optimizer)
                    self.scaler.update()
                else:
                    self.optimizer.step()

                self.optimizer.zero_grad()
                self.scheduler.step()

            if self.global_step % self.log_interval == 0:
                self._log_losses(losses, batch_idx)

            if self.eval_dataset and self.global_step % self.config.get("eval_interval", 1000) == 0:
                metrics = self.evaluate()
                # Track the best eval loss and persist the best model —
                # previously best_eval_loss/is_best were never used.
                if metrics["eval_loss"] < self.best_eval_loss:
                    self.best_eval_loss = metrics["eval_loss"]
                    self.save_checkpoint(is_best=True)

            if self.global_step % self.save_interval == 0:
                self.save_checkpoint()

            self.global_step += 1

            if self.global_step >= self.config["max_steps"]:
                print("Reached max steps")
                return

    def evaluate(self) -> Dict[str, float]:
        """
        Run evaluation over the full eval loader.

        Returns:
            Dict with the mean cross-entropy loss under key "eval_loss".
        """
        self.model.eval()
        total_loss = 0.0
        num_batches = 0

        with torch.no_grad():
            for batch in self.eval_loader:
                input_ids = batch["input_ids"].to(self.device)
                attention_mask = batch["attention_mask"].to(self.device)
                labels = batch["labels"].to(self.device)

                amp_enabled = self.use_amp and self.device.type == "cuda"
                with torch.cuda.amp.autocast(enabled=amp_enabled, dtype=self.amp_dtype):
                    outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, return_dict=True)
                    logits = outputs["logits"]
                    # F is torch.nn.functional (module-level import); the
                    # previous version referenced F without importing it.
                    loss = F.cross_entropy(
                        logits.view(-1, logits.size(-1)),
                        labels.view(-1),
                        ignore_index=-100,
                    )

                total_loss += loss.item()
                num_batches += 1

        avg_loss = total_loss / num_batches if num_batches > 0 else 0.0
        print(f"Evaluation at step {self.global_step}: loss = {avg_loss:.4f}")

        return {"eval_loss": avg_loss}

    def save_checkpoint(self, is_best: bool = False):
        """
        Save model checkpoint (numbered, latest, and optionally best).

        Args:
            is_best: Also write the checkpoint to best_model.pt.
        """
        checkpoint = {
            "step": self.global_step,
            "model_state_dict": self.model.state_dict(),
            "optimizer_state_dict": self.optimizer.state_dict(),
            "scheduler_state_dict": self.scheduler.state_dict(),
            "config": self.config,
            "best_eval_loss": self.best_eval_loss,
        }

        if self.scaler:
            checkpoint["scaler_state_dict"] = self.scaler.state_dict()

        checkpoint_path = self.checkpoint_dir / f"checkpoint_{self.global_step:06d}.pt"
        torch.save(checkpoint, checkpoint_path)
        print(f"Saved checkpoint to {checkpoint_path}")

        if is_best:
            best_path = self.checkpoint_dir / "best_model.pt"
            torch.save(checkpoint, best_path)
            print(f"Saved best model to {best_path}")

        # Always refresh latest.pt so resumption can find the newest state.
        latest_path = self.checkpoint_dir / "latest.pt"
        torch.save(checkpoint, latest_path)

    def load_checkpoint(self, checkpoint_path: str):
        """
        Load a checkpoint and restore model/optimizer/scheduler state.

        Args:
            checkpoint_path: Path to a checkpoint produced by save_checkpoint.
        """
        # weights_only=False: checkpoint includes the config dict, not just
        # tensors. Only load checkpoints from trusted sources.
        checkpoint = torch.load(checkpoint_path, map_location=self.device, weights_only=False)
        self.model.load_state_dict(checkpoint["model_state_dict"])
        self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        self.scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
        self.global_step = checkpoint["step"]
        self.best_eval_loss = checkpoint.get("best_eval_loss", float('inf'))

        if self.scaler and "scaler_state_dict" in checkpoint:
            self.scaler.load_state_dict(checkpoint["scaler_state_dict"])

        print(f"Loaded checkpoint from {checkpoint_path} at step {self.global_step}")

    def _log_losses(self, losses: Dict[str, torch.Tensor], batch_idx: int):
        """Log current loss values to the console."""
        loss_str = " | ".join([f"{k}: {v.item():.4f}" for k, v in losses.items()])
        print(f"Step {self.global_step} | {loss_str}")

    def train(self):
        """Main training loop; always saves a checkpoint on exit."""
        print("Starting training...")
        print(f"Total steps: {self.config['max_steps']}")
        print(f"Device: {self.device}")
        print(f"Batch size: {self.config['micro_batch_size']}")
        print(f"Gradient accumulation steps: {self.config['gradient_accumulation_steps']}")

        try:
            self.train_epoch()
        except KeyboardInterrupt:
            print("Training interrupted")
        finally:
            # Persist state even on interrupt so training can resume.
            self.save_checkpoint()
|
|
|
|
|
def test_trainer():
    """Smoke-test the trainer end to end with a tiny model on CPU."""
    from models.vortex_model import VortexModel
    from tokenizer.vortex_tokenizer import VortexScienceTokenizer
    from configs.vortex_7b_config import VORTEX_7B_CONFIG

    # Shrink the 7B config down to something that runs in seconds.
    cfg = VORTEX_7B_CONFIG.copy()
    cfg["d_model"] = 256
    cfg["num_layers"] = 2
    cfg["num_heads"] = 4
    cfg["vocab_size"] = 1000
    cfg["max_steps"] = 10
    cfg["device"] = "cpu"

    tiny_model = VortexModel(cfg)

    class DummyTokenizer:
        """Returns fixed-shape random token ids regardless of input text."""
        def encode(self, text, add_special_tokens=True, return_tensors="pt"):
            return {"input_ids": torch.randint(0, 1000, (1, 10)), "attention_mask": torch.ones(1, 10)}

    class DummyDataset(torch.utils.data.Dataset):
        """Ten random fixed-length samples tagged with a physics domain."""
        def __len__(self):
            return 10

        def __getitem__(self, idx):
            return {
                "input_ids": torch.randint(0, 1000, (32,)),
                "attention_mask": torch.ones(32),
                "labels": torch.randint(0, 1000, (32,)),
                "domain": "physics",
            }

    trainer = VortexTrainer(
        model=tiny_model,
        tokenizer=DummyTokenizer(),
        train_dataset=DummyDataset(),
        config=cfg,
        eval_dataset=DummyDataset(),
    )

    trainer.train()

    print("Trainer test passed!")
|
|
|
|
|
# Allow running this module directly as a quick smoke test.
if __name__ == "__main__":
    test_trainer()
|
|
|