File size: 3,526 Bytes
c5cd586 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 |
import os
import time

import pandas as pd
import torch
from torch.nn.utils import clip_grad_norm_
def _percent(correct, total):
    """Return ``correct`` as a percentage of ``total`` (0.0 when ``total`` is 0)."""
    return 100 * correct / total if total else 0.0


def _forward(model, titles, texts, labels, device):
    """Move one batch to ``device``; return ``(outputs, labels)`` with matching shapes."""
    titles, texts = titles.to(device), texts.to(device)
    labels = labels.to(device).float()
    # reshape rather than squeeze(): squeeze() collapses a batch of one to a
    # 0-d tensor, which no longer matches labels' shape and breaks the loss.
    outputs = model(titles, texts).reshape(labels.shape)
    return outputs, labels


def _train_one_epoch(model, loader, criterion, optimizer, device, epoch, max_grad_norm):
    """Run one optimisation pass over ``loader``; return ``(mean batch loss, accuracy %)``."""
    model.train()
    total_batches = len(loader)
    total_loss, correct, seen = 0.0, 0, 0
    for batch_idx, (titles, texts, labels) in enumerate(loader):
        start_time = time.perf_counter()  # Start time for the batch
        outputs, labels = _forward(model, titles, texts, labels, device)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        if max_grad_norm:  # falsy (0 / None) disables clipping
            clip_grad_norm_(model.parameters(), max_norm=max_grad_norm)
        optimizer.step()

        total_loss += loss.item()
        # Thresholding at 0.5 presumes the model emits probabilities — confirm
        # against the model/criterion pairing in use.
        correct += ((outputs > 0.5).float() == labels).sum().item()
        seen += labels.size(0)

        batch_time = time.perf_counter() - start_time
        print(
            f'Epoch: {epoch+1}, Batch: {batch_idx+1}/{total_batches}, Batch Processing Time: {batch_time:.4f} seconds')
    return total_loss / max(total_batches, 1), _percent(correct, seen)


def _evaluate(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` without gradients; return ``(mean batch loss, accuracy %)``."""
    model.eval()
    total_loss, correct, seen = 0.0, 0, 0
    with torch.no_grad():
        for titles, texts, labels in loader:
            outputs, labels = _forward(model, titles, texts, labels, device)
            total_loss += criterion(outputs, labels).item()
            correct += ((outputs > 0.5).float() == labels).sum().item()
            seen += labels.size(0)
    return total_loss / max(len(loader), 1), _percent(correct, seen)


def train(model, train_loader, val_loader, criterion, optimizer, epochs, device, version,
          max_grad_norm=1.0, early_stopping_patience=5, early_stopping_delta=0.001):
    """Train a binary classifier with per-epoch validation and early stopping.

    Each batch from ``train_loader`` / ``val_loader`` is unpacked as
    ``(titles, texts, labels)`` and the model is called as
    ``model(titles, texts)``. Predictions are the model output thresholded at
    0.5.

    Whenever validation accuracy beats the best so far by more than
    ``early_stopping_delta`` (in percentage points), the model's
    ``state_dict`` is saved to
    ``./output/version_{version}/best_model_{version}.pth``. Per-epoch metrics
    are written to ``./output/version_{version}/training_metrics_{version}.csv``
    once training ends. The output directory is created if it is missing.

    Args:
        model: module called as ``model(titles, texts)``.
        train_loader: iterable of ``(titles, texts, labels)`` training batches.
        val_loader: iterable of ``(titles, texts, labels)`` validation batches.
        criterion: loss called as ``criterion(outputs, labels)``.
        optimizer: optimizer stepping ``model``'s parameters.
        epochs: maximum number of epochs to run.
        device: device every batch tensor is moved to.
        version: tag used in the output directory and file names.
        max_grad_norm: gradient-norm clipping threshold; falsy disables it.
        early_stopping_patience: consecutive non-improving epochs before stopping.
        early_stopping_delta: minimum accuracy gain (percentage points) that
            counts as an improvement.

    Returns:
        ``(model, best_accuracy, best_epoch)``. ``model`` is the model as it
        stands after the last epoch run (it is NOT reloaded from the best
        checkpoint). ``best_accuracy`` is the best validation accuracy in
        percent. ``best_epoch`` is its 1-based epoch, or 0 if no checkpoint was
        ever saved.
    """
    output_dir = f'./output/version_{version}'
    # torch.save and DataFrame.to_csv both fail if the directory is missing.
    os.makedirs(output_dir, exist_ok=True)
    best_model_path = f'{output_dir}/best_model_{version}.pth'
    metrics_path = f'{output_dir}/training_metrics_{version}.csv'

    best_accuracy = 0.0
    best_epoch = 0
    early_stopping_counter = 0
    # Key order fixes the CSV column order.
    metrics = {
        'epoch': [], 'train_loss': [], 'val_loss': [], 'train_accuracy': [], 'val_accuracy': [],
    }

    for epoch in range(epochs):
        train_loss, train_accuracy = _train_one_epoch(
            model, train_loader, criterion, optimizer, device, epoch, max_grad_norm)
        val_loss, val_accuracy = _evaluate(model, val_loader, criterion, device)

        metrics['epoch'].append(epoch + 1)
        metrics['train_loss'].append(train_loss)
        metrics['val_loss'].append(val_loss)
        metrics['train_accuracy'].append(train_accuracy)
        metrics['val_accuracy'].append(val_accuracy)

        # Report before the early-stopping check so the final epoch is logged too.
        print(
            f'Epoch [{epoch+1}/{epochs}], Loss: {train_loss:.4f}, Validation Accuracy: {val_accuracy:.2f}%')

        if val_accuracy > best_accuracy + early_stopping_delta:
            best_accuracy = val_accuracy
            early_stopping_counter = 0
            best_epoch = epoch + 1
            torch.save(model.state_dict(), best_model_path)
        else:
            early_stopping_counter += 1
            if early_stopping_counter >= early_stopping_patience:
                print(f"Early stopping triggered at epoch {epoch + 1}")
                break

    pd.DataFrame(metrics).to_csv(metrics_path, index=False)
    return model, best_accuracy, best_epoch
|