# SDSC6001_HW3 / train.py: semi-supervised image classification training script
# %%
import os
import random
import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.optim.lr_scheduler import ReduceLROnPlateau
from PIL import Image
from torch.utils.data import ConcatDataset, DataLoader, Dataset
from torchvision.datasets import DatasetFolder
from tqdm import tqdm
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
random.seed(1234)
np.random.seed(1234)
torch.manual_seed(1234)
torch.cuda.manual_seed_all(1234)  # also seed CUDA devices for reproducibility
folder = "./datasets"
NUM_CLASSES = 14
train_tfm = transforms.Compose(
[
        # Resize the image to a fixed shape (height = width = 224)
transforms.Resize((224, 224)),
transforms.Lambda(lambda x: x.convert("RGB")),
# Random horizontal flip to increase robustness to object orientation
transforms.RandomHorizontalFlip(),
# Random rotation, a common transformation to handle rotated images
transforms.RandomRotation(
20
), # Rotate image by a random angle between -20 and 20 degrees
# Random cropping to simulate random scene zoom
transforms.RandomResizedCrop(
224, scale=(0.8, 1.0)
), # Crop and resize to 224x224
# Random color jitter to make the model robust to lighting changes
# transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),
# Random affine transformation (translation, scaling, rotation)
# transforms.RandomAffine(degrees=10, translate=(0.1, 0.1), scale=(0.8, 1.2)),
# Convert to tensor
transforms.ToTensor(),
# Normalize the image with mean and standard deviation for better convergence
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
)
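# A quick sanity check of the augmentation pipeline (the path below is a hypothetical
# placeholder; point it at any real file under ./datasets/train/labeled/<class>/):
#   sample = Image.open("./datasets/train/labeled/<class>/example.jpg")
#   x = train_tfm(sample)
#   print(x.shape, x.dtype)  # expected: torch.Size([3, 224, 224]) torch.float32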
test_tfm = transforms.Compose(
[
# Resize to the fixed size
transforms.Resize((224, 224)),
transforms.Lambda(lambda x: x.convert("RGB")),
# Convert to tensor
transforms.ToTensor(),
# Normalize the image with mean and standard deviation (same as in training)
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
)
def get_dataset():
train_set = DatasetFolder(
folder + "/train/labeled",
loader=lambda x: Image.open(x),
extensions="jpg",
# transform=train_tfm,
)
valid_set = DatasetFolder(
folder + "/val",
loader=lambda x: Image.open(x),
extensions="jpg",
transform=test_tfm,
)
unlabeled_set = DatasetFolder(
folder + "/train/unlabeled",
loader=lambda x: Image.open(x),
extensions="jpg",
# transform=train_tfm,
)
test_set = DatasetFolder(
folder + "/test",
loader=lambda x: Image.open(x),
extensions="jpg",
transform=test_tfm,
)
return train_set, valid_set, unlabeled_set, test_set
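# DatasetFolder requires one subdirectory per class, so the assumed on-disk layout is:
#   datasets/train/labeled/<class_name>/*.jpg    (14 class folders)
#   datasets/val/<class_name>/*.jpg              (same class folders; labels used for validation)
#   datasets/train/unlabeled/<any_folder>/*.jpg  (folder labels ignored; pseudo-labels generated below)
#   datasets/test/<any_folder>/*.jpg             (folder labels ignored; only predictions are written)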
def train_collate_fn(batch):
    # Keep the raw PIL images (no stacking): train_tfm is applied per image inside
    # the training loop, so the random augmentations differ on every pass.
    data, labels = zip(*batch)
    labels = torch.tensor(labels)
    return data, labels
def test_collate_fn(batch):
    # Validation/test images are already tensors (test_tfm), so stack them into a batch.
    data, labels = zip(*batch)
    data = torch.stack(data)
    labels = torch.tensor(labels)
    return data, labels
from utils import CustomDataset
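# CustomDataset is provided by utils and is assumed to wrap parallel lists of PIL
# images and integer labels with the standard Dataset interface. A minimal reference
# sketch (not the actual utils implementation) would look like:
#
#   class CustomDataset(Dataset):
#       def __init__(self, images, labels):
#           self.images = images
#           self.labels = labels
#       def __len__(self):
#           return len(self.images)
#       def __getitem__(self, idx):
#           return self.images[idx], self.labels[idx]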
def update_dataset(
train_set, unlabeled_set, model, threshold, batch_size=128, num_workers=8
) -> Dataset:
"""
    This is the core function that generates a pseudo-labeled dataset using the given model.
inputs:
- train_set: The labeled training set
- unlabeled_set: The unlabeled dataset to be pseudo-labeled
- model: The trained model to generate pseudo-labels
- threshold: Confidence threshold for pseudo-labeling
- batch_size: Batch size for DataLoader
- num_workers: Number of workers for DataLoader
outputs:
- new_set: The updated dataset with pseudo-labels
"""
device = "cuda" if torch.cuda.is_available() else "cpu"
# Make sure the model is in eval mode.
model.eval()
# Define softmax function.
softmax = nn.Softmax(dim=-1)
# Create a dataloader for the unlabeled data
unlabeled_loader = DataLoader(
unlabeled_set,
batch_size=batch_size,
shuffle=False,
num_workers=num_workers,
pin_memory=True,
collate_fn=train_collate_fn,
)
# List to store the most confident predictions
confident_samples = []
confident_labels = []
with torch.no_grad():
for batch_idx, (images, _) in enumerate(
tqdm(unlabeled_loader, desc="Generating pseudo-labels")
):
# Apply test transform to each image in the batch
new_images = torch.stack([test_tfm(img) for img in images])
# Forward pass through the model
outputs = model(new_images.to(device))
# Apply softmax to get probabilities
probabilities = softmax(outputs)
# Get the maximum probability and corresponding class for each sample
max_probs, pseudo_labels = torch.max(probabilities, dim=1)
# For each sample in the batch, check confidence threshold
for i, (prob, label) in enumerate(zip(max_probs, pseudo_labels)):
# If the prediction is confident enough, add to confident set
if prob.item() > threshold:
# Get the actual index in the unlabeled_set
idx = batch_idx * batch_size + i
if idx < len(unlabeled_set):
img, _ = unlabeled_set[idx]
confident_samples.append(img)
confident_labels.append(int(label.cpu()))
# Create new dataset from the confident predictions
if confident_samples:
pseudo_set = CustomDataset(images=confident_samples, labels=confident_labels)
# Combine with existing labeled data
new_set = ConcatDataset([train_set, pseudo_set])
print(f"Added {len(confident_samples)} pseudo-labeled samples to training set")
else:
print("No confident pseudo-labels found.")
new_set = train_set
return new_set
# %%
from models import * # noqa: F403
from torch import optim
# "cuda" only when GPUs are available.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
# Initialize a model, and put it on the device specified.
# model = Classifier().to(device)
# model = VGG16Classifier(num_classes=NUM_CLASSES).to(device)
# model = ResNet50Classifier(num_classes=NUM_CLASSES).to(device)
model = ResNet101Classifier(num_classes=NUM_CLASSES).to(device)
# model = InceptionV3Classifier(num_classes=NUM_CLASSES).to(device)
# model = ViTLargeClassifier(num_classes=NUM_CLASSES).to(device)
# !!! ONLY USE THIS FOR RESUME TRAINING !!!
# Guard the load so a fresh run (no checkpoint yet) does not crash.
if os.path.exists("best_model.pth"):
    model.load_state_dict(torch.load("best_model.pth", map_location=device))
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-5, weight_decay=1e-5)
scheduler = ReduceLROnPlateau(
optimizer,
mode="min",
factor=0.1,
patience=2,
threshold=0.001,
threshold_mode="rel",
cooldown=0,
min_lr=0,
eps=1e-08,
)
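# With these settings the learning rate is cut by 10x once the validation loss has
# failed to improve by more than a 0.1% relative threshold for 2 consecutive epochs;
# see scheduler.step(metrics=valid_loss) in the training loop below.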
# %%
# TRAINING
do_semi = True
batch_size = 128
num_workers = 8
train_set, valid_set, unlabeled_set, test_set = get_dataset()
train_loader = DataLoader(
dataset=train_set,
batch_size=batch_size,
shuffle=True,
num_workers=num_workers,
collate_fn=train_collate_fn,
)
valid_loader = DataLoader(
dataset=valid_set,
batch_size=batch_size,
    shuffle=False,
num_workers=num_workers,
collate_fn=test_collate_fn,
)
test_loader = DataLoader(
dataset=test_set,
batch_size=batch_size,
shuffle=False,
num_workers=num_workers,
collate_fn=test_collate_fn,
)
best_valid_loss = float("inf")
start_epoch = 0  # set to the last completed epoch when resuming (must stay below `epochs`, or the loop is skipped)
epochs = 100
threshold = 0.8
early_stop = False
for epoch in range(start_epoch, epochs):
if do_semi and epoch > epochs // 4 and epoch % 2 == 0:
new_set = update_dataset(
train_set=train_set,
unlabeled_set=unlabeled_set,
model=model,
threshold=threshold,
batch_size=batch_size,
num_workers=num_workers,
)
train_loader = DataLoader(
dataset=new_set,
batch_size=batch_size,
shuffle=True,
num_workers=num_workers,
pin_memory=True,
collate_fn=train_collate_fn,
)
# ---------- Training ----------
model.train()
train_loss = []
train_accs = []
for batch in tqdm(train_loader):
imgs, labels = batch
new_images = torch.stack([train_tfm(img) for img in imgs])
logits = model(new_images.to(device))
loss = criterion(logits, labels.to(device))
optimizer.zero_grad()
loss.backward()
optimizer.step()
acc = (logits.argmax(dim=-1) == labels.to(device)).float().mean()
train_loss.append(loss.item())
train_accs.append(acc)
train_loss = sum(train_loss) / len(train_loss)
train_acc = sum(train_accs) / len(train_accs)
print(
f"[ Train | {epoch + 1:03d}/{epochs:03d} ] loss = {train_loss:.5f}, acc = {train_acc:.5f}"
)
# ---------- Validation ----------
model.eval()
valid_loss = []
valid_accs = []
for batch in tqdm(valid_loader):
imgs, labels = batch
with torch.no_grad():
logits = model(imgs.to(device))
loss = criterion(logits, labels.to(device))
acc = (logits.argmax(dim=-1) == labels.to(device)).float().mean()
valid_loss.append(loss.item())
valid_accs.append(acc)
valid_loss = sum(valid_loss) / len(valid_loss)
valid_acc = sum(valid_accs) / len(valid_accs)
print(
f"[ Valid | {epoch + 1:03d}/{epochs:03d} ] loss = {valid_loss:.5f}, acc = {valid_acc:.5f}"
)
scheduler.step(metrics=valid_loss)
if valid_loss < best_valid_loss:
best_valid_loss = valid_loss
torch.save(model.state_dict(), "best_model.pth")
print(f"Model saved with loss: {valid_loss:.5f}, acc: {valid_acc:.5f}")
elif early_stop:
if epoch > epochs // 2 and valid_loss > best_valid_loss * 1.2:
print("Early stopping")
break
# %%
# torch.save(model.state_dict(), "best_model.pth")
# %%
# LOAD BEST MODEL
model.load_state_dict(torch.load("best_model.pth", map_location=device))
model.eval()
# %%
# TEST
model.eval()
# Initialize a list to store the predictions.
predictions = []
# Iterate the testing set by batches.
for batch in tqdm(test_loader):
imgs, labels = batch
with torch.no_grad():
logits = model(imgs.to(device))
predictions.extend(logits.argmax(dim=-1).cpu().numpy().tolist())
# %%
# Save predictions into the file.
with open("predict.csv", "w") as f:
    # The first row must be "Id,Category"
f.write("Id,Category\n")
# For the rest of the rows, each image id corresponds to a predicted class.
for i, pred in enumerate(predictions):
f.write(f"{i},{pred}\n")
print("Predictions saved to predict.csv")