import torch
import time
import os
from aetheris.utils import save_checkpoint, load_latest_checkpoint, calculate_model_stats
class Trainer:
    def __init__(self, model, optimizer, scaler, config, device, checkpoint_dir, logger=None):
        self.model = model
        self.optimizer = optimizer
        self.scaler = scaler
        self.config = config
        self.device = device
        self.checkpoint_dir = checkpoint_dir
        self.logger = logger
        self.model.to(self.device)
    def validate(self, val_loader, global_step):
        self.model.eval()
        total_loss = 0.0
        total_items = 0
        num_batches = 100  # Validate on 100 batches to save time
        print(f"\n[Validation] Starting validation at step {global_step}...")
        with torch.no_grad():
            for i, batch in enumerate(val_loader):
                if i >= num_batches:
                    break
                input_ids, labels = batch
                input_ids = input_ids.to(self.device, non_blocking=True)
                labels = labels.to(self.device, non_blocking=True)
                # Autocast context: float16 on CUDA, bfloat16 on CPU; skip if the model runs in float32
                autocast_dtype = torch.float16 if self.device.type == 'cuda' else torch.bfloat16
                use_autocast = self.config.torch_dtype != torch.float32
                if use_autocast:
                    with torch.amp.autocast('cuda' if self.device.type == 'cuda' else 'cpu', dtype=autocast_dtype):
                        output = self.model(input_ids, labels)
                else:
                    output = self.model(input_ids, labels)
                total_loss += output["loss"].item()
                total_items += 1
        avg_loss = total_loss / total_items if total_items > 0 else 0.0
        perplexity = torch.exp(torch.tensor(avg_loss)).item()
        print(f"[Validation] Step {global_step} | Loss: {avg_loss:.4f} | PPL: {perplexity:.4f}")
        self.model.train()
        return avg_loss
    def train_epoch(self, train_loader, total_steps, start_step=0, stage_name="Training", val_loader=None, eval_every=500):
        print(f"\n{'='*70}\nStarting {stage_name}: Target Steps={total_steps}\n{'='*70}")
        self.model.train()
        global_step = start_step
        running_loss = 0.0
        print("Initializing data iterator...")
        train_iter = iter(train_loader)
        print("Fetching first batch...")
        while global_step < total_steps:
            step_start = time.time()
            # Removed periodic cache clearing for performance
            self.optimizer.zero_grad(set_to_none=True)
            try:
                batch = next(train_iter)
                if global_step == start_step:
                    print("✓ First batch loaded! Starting training loop...")
            except StopIteration:
                # Dataset exhausted: restart the iterator and keep training
                train_iter = iter(train_loader)
                batch = next(train_iter)
            input_ids, labels = batch
            input_ids = input_ids.to(self.device, non_blocking=True)
            labels = labels.to(self.device, non_blocking=True)
            # Autocast dtype: float16 on CUDA, bfloat16 on CPU; skip autocast if the model runs in float32
            autocast_dtype = torch.float16 if self.device.type == 'cuda' else torch.bfloat16
            use_autocast = self.config.torch_dtype != torch.float32
            if use_autocast:
                with torch.amp.autocast('cuda' if self.device.type == 'cuda' else 'cpu', dtype=autocast_dtype):
                    output = self.model(input_ids, labels)
                    loss = output["loss"]
            else:
                output = self.model(input_ids, labels)
                loss = output["loss"]
            self.scaler.scale(loss).backward()
            self.scaler.unscale_(self.optimizer)
            # Gradient clipping (gradients are already unscaled)
            grad_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=0.5)
            if torch.isnan(grad_norm) or torch.isinf(grad_norm):
                print(f"WARNING: NaN/Inf gradient at step {global_step}, skipping update")
                # Still call update() so the GradScaler resets its per-optimizer state
                # (otherwise the next unscale_() raises) and backs off the loss scale.
                self.scaler.update()
            else:
                self.scaler.step(self.optimizer)
                self.scaler.update()
            global_step += 1
            running_loss += loss.item()
            if global_step % 10 == 0:
                avg_loss = running_loss / 10
                t_diff = time.time() - step_start
                if self.device.type == 'cuda':
                    mem = torch.cuda.memory_allocated() / 1e9
                    max_mem = torch.cuda.max_memory_allocated() / 1e9
                    mem_str = f"VRAM: {mem:.1f}GB (peak: {max_mem:.1f}GB)"
                else:
                    mem_str = "CPU Mode"
                tokens_per_sec = (self.config.max_seq_len * input_ids.size(0)) / t_diff
                print(f"  Step {global_step}/{total_steps} | Loss: {avg_loss:.4f} | "
                      f"{mem_str} | {tokens_per_sec:.0f} tok/s")
                running_loss = 0.0
            if global_step % 500 == 0:
                save_checkpoint(self.model, self.optimizer, self.scaler, global_step, stage_name, self.checkpoint_dir)
            if val_loader is not None and global_step % eval_every == 0 and global_step > start_step:
                self.validate(val_loader, global_step)
        return global_step
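

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not the project's real entry point).
# ToyLM, the SimpleNamespace config, and the random dataset below are
# placeholders; they only mirror the interface Trainer relies on:
#   - model(input_ids, labels) returns a dict containing a "loss" tensor
#   - config exposes `torch_dtype` and `max_seq_len`
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace

    import torch.nn as nn
    from torch.utils.data import DataLoader, TensorDataset

    class ToyLM(nn.Module):
        """Stand-in model exposing the forward signature Trainer expects."""
        def __init__(self, vocab_size=100, dim=32):
            super().__init__()
            self.embed = nn.Embedding(vocab_size, dim)
            self.head = nn.Linear(dim, vocab_size)

        def forward(self, input_ids, labels):
            logits = self.head(self.embed(input_ids))
            loss = nn.functional.cross_entropy(logits.view(-1, logits.size(-1)), labels.view(-1))
            return {"loss": loss}

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    config = SimpleNamespace(torch_dtype=torch.float32, max_seq_len=16)

    # Random token ids used as both inputs and labels, just to exercise the loop
    data = torch.randint(0, 100, (64, config.max_seq_len))
    loader = DataLoader(TensorDataset(data, data), batch_size=8)

    model = ToyLM()
    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
    # torch.amp.GradScaler needs a recent PyTorch; fall back to torch.cuda.amp.GradScaler on older versions
    scaler = torch.amp.GradScaler(enabled=(device.type == "cuda"))

    trainer = Trainer(model, optimizer, scaler, config, device, checkpoint_dir="checkpoints")
    trainer.train_epoch(loader, total_steps=20, val_loader=loader, eval_every=10)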