import os
import time

import pandas as pd
import torch
from torch.nn.utils import clip_grad_norm_


def _ratio(numerator, denominator):
    """Return numerator / denominator, or NaN when the denominator is 0 (empty loader)."""
    return numerator / denominator if denominator else float('nan')


def _flatten_outputs(outputs):
    """Reshape model outputs to 1-D so they line up with a 1-D batch of labels.

    ``squeeze()`` would collapse an output of shape ``(1, 1)`` (a final batch
    holding a single sample) to a 0-d tensor, which no longer matches labels
    of shape ``(1,)`` and makes losses such as ``BCELoss`` raise.
    """
    return outputs.reshape(-1)


def _train_one_epoch(model, loader, criterion, optimizer, device, epoch,
                     max_grad_norm, threshold):
    """Run one optimisation pass over ``loader``.

    Prints the processing time of every batch.

    Returns:
        (mean batch loss, accuracy in percent); both NaN if ``loader`` is empty.
    """
    model.train()
    total_loss, correct, seen, num_batches = 0.0, 0, 0, 0
    total_batches = len(loader)
    for batch_idx, (titles, texts, labels) in enumerate(loader):
        start_time = time.perf_counter()  # monotonic clock for interval timing
        titles, texts = titles.to(device), texts.to(device)
        labels = labels.to(device).float()

        outputs = _flatten_outputs(model(titles, texts))
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        if max_grad_norm:
            clip_grad_norm_(model.parameters(), max_norm=max_grad_norm)
        optimizer.step()

        total_loss += loss.item()
        preds = (outputs > threshold).float()
        correct += (preds == labels).sum().item()
        seen += labels.size(0)
        num_batches += 1

        batch_time = time.perf_counter() - start_time
        print(
            f'Epoch: {epoch+1}, Batch: {batch_idx+1}/{total_batches}, Batch Processing Time: {batch_time:.4f} seconds')
    return _ratio(total_loss, num_batches), 100 * _ratio(correct, seen)


def _evaluate(model, loader, criterion, device, threshold):
    """Evaluate ``model`` on ``loader`` without gradient tracking.

    Returns:
        (mean batch loss, accuracy in percent); both NaN if ``loader`` is empty.
    """
    model.eval()
    total_loss, correct, seen, num_batches = 0.0, 0, 0, 0
    with torch.no_grad():
        for titles, texts, labels in loader:
            titles, texts = titles.to(device), texts.to(device)
            labels = labels.to(device).float()
            outputs = _flatten_outputs(model(titles, texts))
            total_loss += criterion(outputs, labels).item()
            preds = (outputs > threshold).float()
            correct += (preds == labels).sum().item()
            seen += labels.size(0)
            num_batches += 1
    return _ratio(total_loss, num_batches), 100 * _ratio(correct, seen)


def train(model, train_loader, val_loader, criterion, optimizer, epochs,
          device, version, max_grad_norm=1.0, early_stopping_patience=5,
          early_stopping_delta=0.001, threshold=0.5, output_dir='./output'):
    """Train a binary classifier taking ``(titles, texts)`` inputs, with early stopping.

    Both loaders must yield ``(titles, texts, labels)`` batches. Model outputs
    are flattened to 1-D and compared against ``threshold`` to get predictions.
    The default of 0.5 assumes the model emits probabilities; pass
    ``threshold=0.0`` if it emits logits (e.g. with ``BCEWithLogitsLoss``).

    Whenever validation accuracy beats the best so far by more than
    ``early_stopping_delta``, the model's ``state_dict`` is saved to
    ``{output_dir}/version_{version}/best_model_{version}.pth``. Training stops
    after ``early_stopping_patience`` consecutive epochs without such an
    improvement. Per-epoch metrics are written to
    ``{output_dir}/version_{version}/training_metrics_{version}.csv`` at the end.

    Args:
        max_grad_norm: gradient-norm clipping bound; falsy disables clipping.
        threshold: decision threshold applied to model outputs.
        output_dir: root directory for checkpoints and metrics (created if missing).

    Returns:
        ``(model, best_accuracy, best_epoch)``. ``model`` holds the weights of
        the LAST epoch run, not the best one; load the saved checkpoint for
        those. ``best_epoch`` is 1-based, and 0 if no epoch improved.
    """
    run_dir = os.path.join(output_dir, f'version_{version}')
    os.makedirs(run_dir, exist_ok=True)  # torch.save / to_csv fail on a missing dir
    best_model_path = os.path.join(run_dir, f'best_model_{version}.pth')
    metrics_path = os.path.join(run_dir, f'training_metrics_{version}.csv')

    best_accuracy = 0.0
    best_epoch = 0
    early_stopping_counter = 0
    # Key order fixes the CSV column order.
    metrics = {
        'epoch': [],
        'train_loss': [],
        'val_loss': [],
        'train_accuracy': [],
        'val_accuracy': [],
    }

    for epoch in range(epochs):
        train_loss, train_accuracy = _train_one_epoch(
            model, train_loader, criterion, optimizer, device, epoch,
            max_grad_norm, threshold)
        val_loss, val_accuracy = _evaluate(
            model, val_loader, criterion, device, threshold)

        metrics['epoch'].append(epoch + 1)
        metrics['train_loss'].append(train_loss)
        metrics['val_loss'].append(val_loss)
        metrics['train_accuracy'].append(train_accuracy)
        metrics['val_accuracy'].append(val_accuracy)

        # Report before the early-stopping check so the final epoch is shown too.
        print(
            f'Epoch [{epoch+1}/{epochs}], Loss: {train_loss:.4f}, Validation Accuracy: {val_accuracy:.2f}%')

        # A NaN accuracy (empty val_loader) compares False, i.e. "no improvement".
        if val_accuracy > best_accuracy + early_stopping_delta:
            best_accuracy = val_accuracy
            best_epoch = epoch + 1
            early_stopping_counter = 0
            torch.save(model.state_dict(), best_model_path)
        else:
            early_stopping_counter += 1
            if early_stopping_counter >= early_stopping_patience:
                print(f"Early stopping triggered at epoch {epoch + 1}")
                break

    pd.DataFrame(metrics).to_csv(metrics_path, index=False)
    return model, best_accuracy, best_epoch