import torch
import time
import os
from aetheris.utils import save_checkpoint, load_latest_checkpoint, calculate_model_stats
class Trainer:
    def __init__(self, model, optimizer, scaler, config, device, checkpoint_dir, logger=None):
        self.model = model
        self.optimizer = optimizer
        self.scaler = scaler
        self.config = config
        self.device = device
        self.checkpoint_dir = checkpoint_dir
        self.logger = logger
        self.model.to(self.device)
    def validate(self, val_loader, global_step):
        self.model.eval()
        total_loss = 0
        total_items = 0
        num_batches = 100  # Validate on 100 batches to save time
        print(f"\n[Validation] Starting validation at step {global_step}...")

        # Autocast setup: float16 on CUDA, bfloat16 on CPU; skipped entirely
        # when the model itself runs in float32.
        autocast_dtype = torch.float16 if self.device.type == 'cuda' else torch.bfloat16
        use_autocast = self.config.torch_dtype != torch.float32

        with torch.no_grad():
            for i, batch in enumerate(val_loader):
                if i >= num_batches:
                    break
                input_ids, labels = batch
                input_ids = input_ids.to(self.device, non_blocking=True)
                labels = labels.to(self.device, non_blocking=True)

                if use_autocast:
                    with torch.amp.autocast('cuda' if self.device.type == 'cuda' else 'cpu', dtype=autocast_dtype):
                        output = self.model(input_ids, labels)
                else:
                    output = self.model(input_ids, labels)

                total_loss += output["loss"].item()
                total_items += 1

        avg_loss = total_loss / total_items if total_items > 0 else 0
        perplexity = torch.exp(torch.tensor(avg_loss)).item()
        print(f"[Validation] Step {global_step} | Loss: {avg_loss:.4f} | PPL: {perplexity:.4f}")
        self.model.train()
        return avg_loss
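
    # NOTE (interface assumption): both validate() and train_epoch() expect
    # the model's forward pass to take (input_ids, labels) positionally and
    # return a dict containing a scalar "loss" tensor, e.g.
    # {"loss": ..., "logits": ...}. The exact aetheris model signature is not
    # shown in this file, so this contract is inferred from the call sites.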
    def train_epoch(self, train_loader, total_steps, start_step=0, stage_name="Training",
                    val_loader=None, eval_every=500):
        print(f"\n{'='*70}\nStarting {stage_name}: Target Steps={total_steps}\n{'='*70}")
        self.model.train()
        global_step = start_step
        running_loss = 0

        print("Initializing data iterator...")
        train_iter = iter(train_loader)
        print("Fetching first batch...")

        # Autocast setup mirrors validate(): float16 on CUDA, bfloat16 on CPU,
        # disabled when the model runs in float32.
        autocast_dtype = torch.float16 if self.device.type == 'cuda' else torch.bfloat16
        use_autocast = self.config.torch_dtype != torch.float32

        while global_step < total_steps:
            step_start = time.time()
            # Removed periodic cache clearing for performance
            self.optimizer.zero_grad(set_to_none=True)

            try:
                batch = next(train_iter)
                if global_step == start_step:
                    print("✓ First batch loaded! Starting training loop...")
            except StopIteration:
                # Data exhausted: restart the iterator for another pass
                train_iter = iter(train_loader)
                batch = next(train_iter)

            input_ids, labels = batch
            input_ids = input_ids.to(self.device, non_blocking=True)
            labels = labels.to(self.device, non_blocking=True)

            if use_autocast:
                with torch.amp.autocast('cuda' if self.device.type == 'cuda' else 'cpu', dtype=autocast_dtype):
                    output = self.model(input_ids, labels)
                    loss = output["loss"]
            else:
                output = self.model(input_ids, labels)
                loss = output["loss"]
            self.scaler.scale(loss).backward()
            self.scaler.unscale_(self.optimizer)

            # Gradient clipping (after unscaling, so the norm is in real units)
            grad_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=0.5)
            if torch.isnan(grad_norm) or torch.isinf(grad_norm):
                print(f"WARNING: NaN/Inf gradient at step {global_step}, skipping update")
            else:
                self.scaler.step(self.optimizer)
            # update() must run every iteration: skipping it after unscale_()
            # makes the next unscale_() call raise a RuntimeError, and it also
            # lets the scaler back off the loss scale after an overflow.
            self.scaler.update()

            global_step += 1
            running_loss += loss.item()
            if global_step % 10 == 0:
                avg_loss = running_loss / 10
                t_diff = time.time() - step_start  # duration of the last step only
                if self.device.type == 'cuda':
                    mem = torch.cuda.memory_allocated() / 1e9
                    max_mem = torch.cuda.max_memory_allocated() / 1e9
                    mem_str = f"VRAM: {mem:.1f}GB (peak: {max_mem:.1f}GB)"
                else:
                    mem_str = "CPU Mode"
                tokens_per_sec = (self.config.max_seq_len * input_ids.size(0)) / t_diff
                print(f"  Step {global_step}/{total_steps} | Loss: {avg_loss:.4f} | "
                      f"{mem_str} | {tokens_per_sec:.0f} tok/s")
                running_loss = 0

            if global_step % 500 == 0:
                save_checkpoint(self.model, self.optimizer, self.scaler, global_step,
                                stage_name, self.checkpoint_dir)
            if val_loader is not None and global_step % eval_every == 0 and global_step > start_step:
                self.validate(val_loader, global_step)

        return global_step
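
# --- Example wiring below: a minimal smoke-test sketch, not part of the
# training recipe. TinyLM, the SimpleNamespace config, and the random-token
# loader are hypothetical stand-ins that satisfy the interfaces Trainer
# expects; with only 20 steps the checkpoint branch is never reached, so
# aetheris.utils just needs to be importable.
if __name__ == "__main__":
    from types import SimpleNamespace
    from torch.utils.data import DataLoader, TensorDataset

    class TinyLM(torch.nn.Module):
        """Toy model matching the (input_ids, labels) -> {"loss": ...} contract."""
        def __init__(self, vocab_size=256, dim=64):
            super().__init__()
            self.embed = torch.nn.Embedding(vocab_size, dim)
            self.head = torch.nn.Linear(dim, vocab_size)

        def forward(self, input_ids, labels):
            logits = self.head(self.embed(input_ids))
            loss = torch.nn.functional.cross_entropy(
                logits.view(-1, logits.size(-1)), labels.view(-1))
            return {"loss": loss, "logits": logits}

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # float32 disables autocast in Trainer; max_seq_len feeds the tok/s log
    config = SimpleNamespace(torch_dtype=torch.float32, max_seq_len=32)

    # Random next-token data: reuse the inputs as labels for the smoke test
    tokens = torch.randint(0, 256, (64, config.max_seq_len))
    loader = DataLoader(TensorDataset(tokens, tokens), batch_size=8)

    model = TinyLM()
    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
    # Disabled scaler on CPU: scale()/unscale_()/update() become no-ops
    scaler = torch.amp.GradScaler('cuda', enabled=(device.type == 'cuda'))

    trainer = Trainer(model, optimizer, scaler, config, device,
                      checkpoint_dir="checkpoints")
    trainer.train_epoch(loader, total_steps=20, stage_name="Smoke Test")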