# ApexID — train.py
# Author: Michael Rey
# Trains MonkeyResNet with a class-weighted loss, LR-on-plateau scheduling,
# and early stopping; saves the model weights and loss/accuracy plots.
import torch
import torch.nn as nn
import torch.optim as optim
from utils.data_loader import get_data_loaders
from models.resnet_model import MonkeyResNet
import os
import matplotlib.pyplot as plt
from sklearn.utils.class_weight import compute_class_weight
import numpy as np
class EarlyStopping:
    """Stop training when validation loss has not improved for `patience` epochs.

    Call the instance once per epoch with the current validation loss; when
    `early_stop` becomes True the training loop should break.
    """

    def __init__(self, patience=5, min_delta=0.0):
        # patience: epochs to wait after the last improvement before stopping.
        # min_delta: minimum decrease in loss that counts as an improvement
        #            (default 0.0 keeps the original "any decrease" behavior).
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_loss = float('inf')
        self.early_stop = False

    def __call__(self, val_loss):
        """Record this epoch's validation loss and update the stop flag."""
        if val_loss < self.best_loss - self.min_delta:
            # Improvement: remember the new best and reset the counter.
            self.best_loss = val_loss
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
# ---------------------------------------------------------------------------
# Hyperparameters
# ---------------------------------------------------------------------------
data_dir = "data"
epochs = 50
batch_size = 32
lr = 0.001
patience = 5

# Data loaders for training/validation plus the ordered list of class names.
train_loader, val_loader, class_names = get_data_loaders(data_dir, batch_size)

# Prefer the GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Collect every training label (one full pass over the loader) so that
# balanced class weights can be derived for the loss function.
train_labels = np.concatenate(
    [labels.numpy() for _, labels in train_loader]
)
class_weights = compute_class_weight(
    class_weight='balanced',
    classes=np.unique(train_labels),
    y=train_labels,
)
class_weights = torch.tensor(class_weights, dtype=torch.float).to(device)

# Model, class-weighted loss, optimizer, LR scheduler, and early stopping.
model = MonkeyResNet(num_classes=len(class_names)).to(device)
criterion = nn.CrossEntropyLoss(weight=class_weights)
optimizer = optim.Adam(model.parameters(), lr=lr)
# Shrink the learning rate 10x once validation loss plateaus for 2 epochs.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=2)
early_stopper = EarlyStopping(patience=patience)

# History buffers consumed by the plotting code after training finishes.
train_losses, val_losses = [], []
train_accuracies, val_accuracies = [], []
# ---------------------------------------------------------------------------
# Training loop
# ---------------------------------------------------------------------------
for epoch in range(epochs):
    # --- Training pass ---
    model.train()
    train_loss = 0.0
    correct = 0
    total = 0
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    # Average the summed loss over the number of batches: the train and val
    # loaders have different batch counts, so raw sums are not comparable
    # (the original code plotted/printed the sums directly).
    train_loss /= max(len(train_loader), 1)
    train_accuracy = 100 * correct / total
    train_losses.append(train_loss)
    train_accuracies.append(train_accuracy)

    # --- Validation pass (no gradient tracking) ---
    model.eval()
    val_loss = 0.0
    correct_val = 0
    total_val = 0
    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            val_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total_val += labels.size(0)
            correct_val += (predicted == labels).sum().item()

    val_loss /= max(len(val_loader), 1)
    val_accuracy = 100 * correct_val / total_val
    val_losses.append(val_loss)
    val_accuracies.append(val_accuracy)

    # Drive the LR scheduler and early-stopping state from validation loss.
    scheduler.step(val_loss)
    early_stopper(val_loss)

    # Report both accuracies (the original print omitted validation accuracy).
    print(
        f"Epoch {epoch+1}/{epochs} - Train Loss: {train_loss:.4f} "
        f"- Val Loss: {val_loss:.4f} - Train Acc: {train_accuracy:.2f}% "
        f"- Val Acc: {val_accuracy:.2f}%"
    )

    if early_stopper.early_stop:
        print(f"Early stopping triggered at epoch {epoch+1}")
        break
# ---------------------------------------------------------------------------
# Persist the trained weights and the training-history plots
# ---------------------------------------------------------------------------
os.makedirs("models", exist_ok=True)
torch.save(model.state_dict(), "models/monkey_resnet.pth")
print("Training done. Model saved.")

os.makedirs("plots", exist_ok=True)


def _save_history_plot(series, title, ylabel, path):
    """Plot each (values, label) pair in `series` and save the figure to `path`."""
    plt.figure(figsize=(10, 5))
    for values, label in series:
        plt.plot(values, label=label)
    plt.xlabel("Epoch")
    plt.ylabel(ylabel)
    plt.title(title)
    plt.legend()
    plt.grid(True)
    plt.savefig(path)
    plt.close()


# Loss curves (one point per completed epoch).
_save_history_plot(
    [(train_losses, "Train Loss"), (val_losses, "Val Loss")],
    "Training and Validation Loss",
    "Loss",
    "plots/loss_plot.png",
)
# Accuracy curves.
_save_history_plot(
    [(train_accuracies, "Train Accuracy"), (val_accuracies, "Val Accuracy")],
    "Training and Validation Accuracy",
    "Accuracy (%)",
    "plots/accuracy_plot.png",
)