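"""Training script for the seq2seq translation model defined in models/seq2seq.py.

Builds source/target vocabularies from the dataset on disk, trains with
mixed precision (AMP), gradient accumulation, and optional gradient clipping,
and saves a checkpoint after every epoch so interrupted runs can be resumed.
"""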
import os

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from datasets import load_from_disk
from tqdm import tqdm

from config import Config
from utils.tokenizer import build_vocab
from utils.preprocessing import collate_fn
from models.seq2seq import Encoder, Decoder, Seq2Seq

def save_checkpoint(epoch, model, optimizer, scaler, loss, path):
    """Save training checkpoint"""
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scaler_state_dict': scaler.state_dict(),
        'loss': loss,
    }
    torch.save(checkpoint, path)
    print(f"✅ Checkpoint saved at epoch {epoch}")

def load_checkpoint(model, optimizer, scaler, path, device):
    """Load training checkpoint"""
    if os.path.exists(path):
        checkpoint = torch.load(path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        scaler.load_state_dict(checkpoint['scaler_state_dict'])
        start_epoch = checkpoint['epoch'] + 1
        best_loss = checkpoint['loss']
        print(f"✅ Checkpoint loaded. Resuming from epoch {start_epoch}")
        return start_epoch, best_loss
    return 0, float('inf')  # Start from beginning if no checkpoint

def train_one_epoch(model, dataloader, optimizer, criterion, device, scaler, epoch, cfg):
    model.train()
    total_loss = 0
    optimizer.zero_grad()  # Zero gradients at start

    loop = tqdm(dataloader, desc=f"Epoch {epoch+1}", leave=False)
    for batch_idx, (src, trg) in enumerate(loop):
        src, trg = src.to(device), trg.to(device)

        # Mixed precision forward pass
        with torch.cuda.amp.autocast(enabled=cfg.use_amp):
            output = model(src, trg)
            output_dim = output.shape[-1]
            output = output[1:].reshape(-1, output_dim)  # drop the <sos> position
            trg = trg[1:].reshape(-1)
            loss = criterion(output, trg) / cfg.gradient_accumulation_steps  # Normalize loss

        scaler.scale(loss).backward()

        # Gradient accumulation: step the optimizer every N batches
        if (batch_idx + 1) % cfg.gradient_accumulation_steps == 0:
            if cfg.use_gradient_clipping:
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.max_grad_norm)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

        total_loss += loss.item() * cfg.gradient_accumulation_steps
        loop.set_postfix(loss=loss.item() * cfg.gradient_accumulation_steps)

    # Flush gradients left over from a final partial accumulation group
    if len(dataloader) % cfg.gradient_accumulation_steps != 0:
        if cfg.use_gradient_clipping:
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.max_grad_norm)
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()

    return total_loss / len(dataloader)

def main():
    cfg = Config()
    device = cfg.device
    print(f"Using device: {device}")

    # Create directories if they don't exist
    os.makedirs("models", exist_ok=True)
    os.makedirs("models/tokenizers", exist_ok=True)

    # Load dataset (full dataset)
    dataset = load_from_disk("data/raw/")

    # Build vocab using full dataset
    src_tokenizer, src_vocab = build_vocab(dataset, cfg.source_lang)
    trg_tokenizer, trg_vocab = build_vocab(dataset, cfg.target_lang)

    # Save tokenizers and vocab for future use
    torch.save({
        'src_tokenizer': src_tokenizer,
        'src_vocab': src_vocab,
        'trg_tokenizer': trg_tokenizer,
        'trg_vocab': trg_vocab
    }, cfg.tokenizer_save_path + "tokenizers.pth")

    # DataLoader over the train split
    collate = lambda batch: collate_fn(
        batch, src_tokenizer, trg_tokenizer, src_vocab, trg_vocab, cfg.max_length,
        src_lang=cfg.source_lang, trg_lang=cfg.target_lang
    )
    dataloader = DataLoader(dataset["train"], batch_size=cfg.batch_size, collate_fn=collate, shuffle=True)

    # Model
    enc = Encoder(len(src_vocab), cfg.embedding_dim, cfg.hidden_dim, cfg.num_layers)
    dec = Decoder(len(trg_vocab), cfg.embedding_dim, cfg.hidden_dim, cfg.num_layers)
    model = Seq2Seq(enc, dec, device).to(device)

    optimizer = optim.Adam(model.parameters(), lr=cfg.learning_rate)
    # Ignore target-side padding when computing the loss
    criterion = nn.CrossEntropyLoss(ignore_index=trg_vocab["<pad>"])
    scaler = torch.cuda.amp.GradScaler(enabled=cfg.use_amp)

    # Try to load checkpoint
    start_epoch, best_loss = load_checkpoint(model, optimizer, scaler, cfg.checkpoint_path, device)

    for epoch in range(start_epoch, cfg.num_epochs):
        print(f"\nEpoch {epoch+1}/{cfg.num_epochs}")
        try:
            loss = train_one_epoch(model, dataloader, optimizer, criterion, device, scaler, epoch, cfg)
            print(f"Epoch {epoch+1}/{cfg.num_epochs} | Loss: {loss:.3f}")

            # Save checkpoint after each epoch
            save_checkpoint(epoch, model, optimizer, scaler, loss, cfg.checkpoint_path)

            # Save best model
            if loss < best_loss:
                best_loss = loss
                torch.save(model.state_dict(), cfg.best_model_path)
                print(f"🎉 New best model saved with loss: {loss:.3f}")
        except RuntimeError as e:
            if "CUDA out of memory" in str(e):
                print("⚠️ GPU out of memory. Saving checkpoint and exiting...")
                # `loss` may not be set if the failure happened mid-epoch, so save best_loss
                save_checkpoint(epoch, model, optimizer, scaler, best_loss, cfg.checkpoint_path)
                print("✅ Checkpoint saved. You can resume training later.")
                break
            else:
                raise

    print("✅ Training completed!")

if __name__ == "__main__":
    main()
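
# Usage sketch (assuming this file is saved as train.py and config.py provides
# the paths referenced above):
#   python train.py
# If cfg.checkpoint_path already exists, training resumes from the saved epoch.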