shahidul034's picture
Add files using upload-large-folder tool
a16c07b verified
"""
Separate testing/inference script for CIFAR-10 ViT model.
Loads a saved checkpoint, runs inference on the test set,
prints final performance, and saves misclassification analysis.
Also supports an optional transfer-learning experiment with
a pre-trained torchvision ViT model.
That experiment fine-tunes a pre-trained Transformer (ViT-B/16) on CIFAR-10
and compares its test performance against the baseline checkpoint, to
measure the impact of transfer learning.
"""
import argparse
from pathlib import Path
from typing import List
from typing import Any, Dict, Tuple
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, models, transforms
from torchvision.models import ViT_B_16_Weights
from c1 import (
CLASS_NAMES,
ViTClassifier,
collect_misclassified,
visualize_misclassified,
)
# ---------------------------------------------------------------------------
# Evaluation and analysis helpers
# ---------------------------------------------------------------------------
@torch.no_grad()
def evaluate_model(
    model: nn.Module,
    dataloader: DataLoader,
    device: torch.device,
) -> Tuple[float, float]:
    """
    Compute average loss and accuracy for a model on a dataset split.

    The model is switched to eval mode and gradients are disabled, so this
    can be called between training epochs.

    Args:
        model: Trained model to evaluate; it must map an image batch to
            class logits of shape (batch, num_classes).
        dataloader: Iterable of (images, labels) batches, e.g. the
            validation or test split.
        device: CPU or CUDA device for inference.

    Returns:
        (avg_loss, accuracy) aggregated over all samples in `dataloader`.
        `avg_loss` is the per-sample mean cross-entropy and `accuracy` is
        a fraction in [0, 1].

    Raises:
        ValueError: If `dataloader` yields no samples.
    """
    model.eval()
    # Sum (rather than average) per-sample losses so dividing by `total`
    # gives the per-sample mean regardless of individual batch sizes.
    criterion = nn.CrossEntropyLoss(reduction="sum")
    total_loss = 0.0
    correct = 0
    total = 0
    for images, labels in dataloader:
        images = images.to(device)
        labels = labels.to(device)
        logits = model(images)
        total_loss += criterion(logits, labels).item()
        correct += (logits.argmax(dim=1) == labels).sum().item()
        total += labels.size(0)
    if total == 0:
        # Without this guard the averages below divide by zero.
        raise ValueError("evaluate_model received an empty dataloader.")
    return total_loss / total, correct / total
def load_model_from_checkpoint(
    checkpoint_path: str,
    device: torch.device,
) -> ViTClassifier:
    """
    Rebuild a `ViTClassifier` and load its weights from `checkpoint_path`.

    Accepted checkpoint layouts:
      - a dict holding `model_state_dict` (learned parameters) and,
        optionally, `model_config` (keyword arguments for `ViTClassifier`);
      - a bare state dict, used as-is when `model_state_dict` is absent.

    When `model_config` is missing or empty, the training defaults used in
    `c1.py` are substituted.

    Returns:
        The restored model, moved to `device` and put in eval mode.
    """
    # NOTE(review): torch.load unpickles the file contents; only load
    # checkpoints from trusted sources.
    ckpt = torch.load(checkpoint_path, map_location=device)

    # An absent or empty config falls through to the hard-coded defaults.
    arch_kwargs: Dict[str, Any] = ckpt.get("model_config", {}) or {
        "image_size": 64,
        "patch_size": 4,
        "in_channels": 3,
        "embed_dim": 256,
        "depth": 6,
        "num_heads": 8,
        "mlp_ratio": 4.0,
        "dropout": 0.1,
        "num_classes": 10,
    }

    restored = ViTClassifier(**arch_kwargs)
    weights = ckpt.get("model_state_dict", ckpt)
    restored.load_state_dict(weights)
    restored.to(device)
    restored.eval()
    return restored
@torch.no_grad()
def collect_predictions(
    model: nn.Module,
    dataloader: DataLoader,
    device: torch.device,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Collect predicted labels and ground-truth labels for full-dataset analysis.

    Args:
        model: Classifier mapping an image batch to logits of shape
            (batch, num_classes).
        dataloader: Iterable of (images, labels) batches.
        device: Device on which to run the forward pass.

    Returns:
        (preds, labels) as 1-D CPU tensors of equal length, in dataloader
        order. `preds` is int64 (from `argmax`); `labels` keeps the dtype
        yielded by the dataloader. Both are empty (int64) when the
        dataloader yields nothing.
    """
    model.eval()
    all_preds: List[torch.Tensor] = []
    all_labels: List[torch.Tensor] = []
    for images, labels in dataloader:
        # Only the images take part in computation; labels stay on the CPU
        # instead of being sent to the device and immediately copied back.
        logits = model(images.to(device))
        all_preds.append(logits.argmax(dim=1).cpu())
        all_labels.append(labels.cpu())
    if not all_preds:
        # torch.cat rejects an empty list, so return empty tensors instead.
        return torch.empty(0, dtype=torch.long), torch.empty(0, dtype=torch.long)
    return torch.cat(all_preds), torch.cat(all_labels)
def build_confusion_matrix(
    preds: torch.Tensor,
    labels: torch.Tensor,
    num_classes: int,
) -> torch.Tensor:
    """
    Build confusion matrix where rows=true class and cols=predicted class.

    Args:
        preds: Predicted class indices, one per sample (any shape; flattened).
        labels: Ground-truth class indices with the same number of elements
            as `preds`.
        num_classes: Number of classes; every index must lie in
            [0, num_classes).

    Returns:
        int64 CPU tensor of shape (num_classes, num_classes) whose entry
        [t, p] counts samples of true class t predicted as class p.

    Raises:
        ValueError: If `preds` and `labels` differ in length, or if any index
            falls outside [0, num_classes).
    """
    preds_flat = preds.reshape(-1).to("cpu", torch.int64)
    labels_flat = labels.reshape(-1).to("cpu", torch.int64)
    if preds_flat.numel() != labels_flat.numel():
        raise ValueError(
            "preds and labels differ in length: "
            f"{preds_flat.numel()} vs {labels_flat.numel()}"
        )
    if labels_flat.numel() > 0:
        # Out-of-range indices would otherwise be silently miscounted
        # (e.g. a negative label landing in the last row).
        lowest = min(int(preds_flat.min()), int(labels_flat.min()))
        highest = max(int(preds_flat.max()), int(labels_flat.max()))
        if lowest < 0 or highest >= num_classes:
            raise ValueError(
                f"Class indices must lie in [0, {num_classes}); "
                f"got values in [{lowest}, {highest}]."
            )
    # Encode each (true, pred) pair as one flat index and count all pairs in
    # a single vectorized pass instead of a Python loop over samples.
    flat_index = labels_flat * num_classes + preds_flat
    counts = torch.bincount(flat_index, minlength=num_classes * num_classes)
    return counts.reshape(num_classes, num_classes)
def format_error_analysis(
    preds: torch.Tensor,
    labels: torch.Tensor,
    class_names: Tuple[str, ...],
    *,
    max_pairs: int = 8,
) -> str:
    """
    Create a readable report with per-class accuracy and top confusion pairs.

    Args:
        preds: Predicted class indices, one per sample.
        labels: Ground-truth class indices aligned with `preds`.
        class_names: Display name for each class index; its length defines
            the number of classes.
        max_pairs: Maximum number of off-diagonal (true -> predicted) pairs
            to list, most frequent first. Defaults to 8.

    Returns:
        A multi-line report. Classes are listed from lowest to highest
        accuracy (ties keep class-index order); a class with no samples is
        reported as 0.00% accuracy.
    """
    num_classes = len(class_names)
    confusion = build_confusion_matrix(preds=preds, labels=labels, num_classes=num_classes)
    # Nested Python lists are far cheaper to index than per-element `.item()`.
    matrix = confusion.tolist()

    per_class_scores = []
    for idx, class_name in enumerate(class_names):
        row = matrix[idx]
        total = sum(row)
        acc = (row[idx] / total) if total > 0 else 0.0
        per_class_scores.append((acc, class_name, total))
    # Stable sort: classes with equal accuracy keep their index order.
    per_class_scores.sort(key=lambda item: item[0])

    lines: List[str] = ["Per-class accuracy (lower = harder classes):"]
    for acc, class_name, total in per_class_scores:
        lines.append(f" {class_name:<10} | acc={acc * 100:6.2f}% | n={total}")
    lines.append("")
    lines.append("Top confusion pairs (true -> predicted):")

    confusions = [
        (matrix[true_idx][pred_idx], true_idx, pred_idx)
        for true_idx in range(num_classes)
        for pred_idx in range(num_classes)
        if true_idx != pred_idx and matrix[true_idx][pred_idx] > 0
    ]
    # Most frequent first; the stable sort keeps (true, pred) order on ties.
    confusions.sort(key=lambda item: item[0], reverse=True)

    total_samples = sum(sum(row) for row in matrix)
    if total_samples == 0:
        # Distinguish "nothing to analyze" from a genuinely perfect result.
        lines.append(" No samples to analyze.")
    elif not confusions:
        lines.append(" No confusions found (perfect classification).")
    else:
        for count, true_idx, pred_idx in confusions[: max(max_pairs, 0)]:
            lines.append(
                f" {class_names[true_idx]} -> {class_names[pred_idx]}: {count} samples"
            )
    return "\n".join(lines)
def print_error_analysis(
    preds: torch.Tensor,
    labels: torch.Tensor,
    class_names: Tuple[str, ...],
) -> str:
    """
    Format the error-analysis report, echo it to stdout, and return it.

    The printed copy is preceded by a blank line so it stands apart from
    earlier console output; the returned string has no leading newline.
    """
    summary = format_error_analysis(
        preds=preds,
        labels=labels,
        class_names=class_names,
    )
    print("\n" + summary)
    return summary
def get_imagenet_style_cifar10_dataloaders(
    data_root: str = "./data",
    batch_size: int = 128,
    num_workers: int = 2,
    pin_memory: bool = True,
    val_ratio: float = 0.1,
    seed: int = 42,
) -> Tuple[DataLoader, DataLoader, DataLoader]:
    """
    Build train/val/test CIFAR-10 DataLoaders preprocessed for ViT-B/16.

    Preprocessing choices:
    - Every image is resized to 224x224, the input resolution torchvision's
      ViT-B/16 expects.
    - Pixels are normalized with ImageNet mean/std so input statistics match
      those seen during pre-training.

    The official training split is divided into train/val with a seeded
    `random_split`, so the split is reproducible for a given `seed`. Only
    the train loader shuffles. Pinned memory is used only when requested
    and CUDA is available.

    Raises:
        ValueError: If `val_ratio` is not strictly between 0 and 1.
    """
    # Chained comparison also rejects NaN.
    if not (0.0 < val_ratio < 1.0):
        raise ValueError("val_ratio must be between 0 and 1.")

    imagenet_preprocess = transforms.Compose(
        [
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=(0.485, 0.456, 0.406),
                std=(0.229, 0.224, 0.225),
            ),
        ]
    )

    root = Path(data_root)
    root.mkdir(parents=True, exist_ok=True)

    def _cifar10(train):
        # Both splits share the same root, download flag and preprocessing.
        return datasets.CIFAR10(
            root=str(root),
            train=train,
            download=True,
            transform=imagenet_preprocess,
        )

    full_train_dataset = _cifar10(train=True)
    test_dataset = _cifar10(train=False)

    num_val = int(len(full_train_dataset) * val_ratio)
    num_train = len(full_train_dataset) - num_val
    train_dataset, val_dataset = random_split(
        full_train_dataset,
        [num_train, num_val],
        generator=torch.Generator().manual_seed(seed),
    )

    loader_kwargs = dict(
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_memory and torch.cuda.is_available(),
    )
    return (
        DataLoader(train_dataset, shuffle=True, **loader_kwargs),
        DataLoader(val_dataset, shuffle=False, **loader_kwargs),
        DataLoader(test_dataset, shuffle=False, **loader_kwargs),
    )
def build_pretrained_vit_classifier(num_classes: int = 10) -> nn.Module:
    """
    Return torchvision's ViT-B/16 initialized from the IMAGENET1K_V1 weights,
    with its final classification layer (`heads.head`) replaced by a freshly
    initialized `nn.Linear` that outputs `num_classes` logits.
    """
    backbone = models.vit_b_16(weights=ViT_B_16_Weights.IMAGENET1K_V1)
    # Swap only the last linear layer; every other pre-trained weight is kept.
    old_head = backbone.heads.head
    backbone.heads.head = nn.Linear(old_head.in_features, num_classes)
    return backbone
def _train_one_epoch(
    model: nn.Module,
    train_loader: DataLoader,
    criterion: nn.Module,
    optimizer: torch.optim.Optimizer,
    device: torch.device,
) -> Tuple[float, float]:
    """Run one optimization pass over `train_loader`; return (mean loss, accuracy)."""
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    for images, labels in train_loader:
        images = images.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        logits = model(images)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()
        # criterion returns a batch mean; re-weight by batch size to sum.
        running_loss += loss.item() * images.size(0)
        correct += (logits.argmax(dim=1) == labels).sum().item()
        total += labels.size(0)
    if total == 0:
        # Without this guard the averages below divide by zero.
        raise ValueError("fine_tune_pretrained received an empty train_loader.")
    return running_loss / total, correct / total


def fine_tune_pretrained(
    model: nn.Module,
    train_loader: DataLoader,
    val_loader: DataLoader,
    device: torch.device,
    epochs: int = 2,
    lr: float = 1e-4,
    weight_decay: float = 1e-4,
) -> None:
    """
    Fine-tune a pre-trained ViT on CIFAR-10 and print epoch-level metrics.

    `model` is updated in place and moved to `device`. After each epoch the
    training loss/accuracy (accumulated during that pass, while weights were
    still changing) and the validation loss/accuracy are printed.

    Hyperparameters:
    - epochs: Number of fine-tuning passes over training data.
    - lr: AdamW learning rate for adaptation from ImageNet to CIFAR-10.
    - weight_decay: AdamW weight decay, to reduce overfitting.

    Raises:
        ValueError: If `train_loader` yields no samples in an epoch.
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    model.to(device)
    for epoch in range(epochs):
        train_loss, train_acc = _train_one_epoch(
            model=model,
            train_loader=train_loader,
            criterion=criterion,
            optimizer=optimizer,
            device=device,
        )
        val_loss, val_acc = evaluate_model(model=model, dataloader=val_loader, device=device)
        print(
            f"[Pretrained ViT] Epoch {epoch + 1}/{epochs} | "
            f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc * 100:.2f}% | "
            f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc * 100:.2f}%"
        )
def build_comparison_report(
    baseline_loss: float,
    baseline_acc: float,
    pretrained_loss: float,
    pretrained_acc: float,
) -> str:
    """
    Build a compact side-by-side comparison report for baseline vs pre-trained ViT.

    Accuracies are fractions in [0, 1] and are rendered as percentages. The
    model-name column is sized to the longest row label, so the loss and
    accuracy columns line up with the header and the rules.

    Returns:
        The report as newline-joined lines, ending with a trailing newline.
    """
    rows = [
        ("Baseline ViT (custom checkpoint)", baseline_loss, baseline_acc),
        ("Pre-trained ViT-B/16", pretrained_loss, pretrained_acc),
    ]
    # A fixed 28-char column was narrower than the 32-char baseline label,
    # which shifted that row out of alignment; size the column to fit.
    name_width = max(len("Model"), *(len(name) for name, _, _ in rows)) + 2
    rule = "-" * (name_width + 12 + 14)
    acc_delta = (pretrained_acc - baseline_acc) * 100.0
    loss_delta = pretrained_loss - baseline_loss

    lines = [
        "Model comparison (baseline vs transfer learning)",
        rule,
        f"{'Model':<{name_width}}{'Test Loss':>12}{'Test Acc':>14}",
    ]
    for name, loss, acc in rows:
        lines.append(f"{name:<{name_width}}{loss:>12.4f}{acc * 100:>13.2f}%")
    lines += [
        rule,
        f"Accuracy gain (pretrained - baseline): {acc_delta:+.2f} percentage points",
        f"Loss delta (pretrained - baseline): {loss_delta:+.4f}",
        "",
    ]
    return "\n".join(lines)
if __name__ == "__main__":
    # -----------------------------------------------------------------------
    # CLI arguments and runtime setup
    # -----------------------------------------------------------------------
    parser = argparse.ArgumentParser(description="Evaluate baseline ViT and run analysis.")
    parser.add_argument(
        "--checkpoint-path",
        type=str,
        default="./saved_model/vit_cifar10_best.pt",
        help="Path to custom ViT checkpoint.",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=128,
        help="Evaluation batch size.",
    )
    parser.add_argument(
        "--run-pretrained-experiment",
        action="store_true",
        help="If set, fine-tune a pre-trained ViT-B/16 on CIFAR-10 and compare.",
    )
    parser.add_argument(
        "--results-dir",
        type=str,
        default="./results",
        help="Directory to save plots and analysis reports.",
    )
    parser.add_argument(
        "--data-root",
        type=str,
        default="./data",
        help="Directory where CIFAR-10 is stored or downloaded to.",
    )
    parser.add_argument(
        "--pretrained-epochs",
        type=int,
        default=2,
        help="Fine-tuning epochs for the pre-trained ViT-B/16 experiment.",
    )
    args = parser.parse_args()
    if args.pretrained_epochs < 1:
        parser.error("--pretrained-epochs must be >= 1.")
    checkpoint_path = args.checkpoint_path
    # Fail fast on a missing checkpoint, before any dataset download starts.
    if not Path(checkpoint_path).is_file():
        parser.error(f"Checkpoint not found: {checkpoint_path}")
    results_dir = Path(args.results_dir)
    results_dir.mkdir(parents=True, exist_ok=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    print(f"Loading checkpoint: {checkpoint_path}")
    print(f"Saving results to: {results_dir}")

    # -----------------------------------------------------------------------
    # Baseline evaluation using custom ViT checkpoint
    # -----------------------------------------------------------------------
    # Keep preprocessing identical to training.
    # NOTE(review): image_size=64 is assumed to match the checkpoint's
    # model_config["image_size"]; confirm if checkpoints trained at another
    # resolution are evaluated with this script.
    from c1 import get_cifar10_dataloaders

    _, _, test_loader = get_cifar10_dataloaders(
        data_root=args.data_root,
        image_size=64,
        batch_size=args.batch_size,
        val_ratio=0.1,
    )
    model = load_model_from_checkpoint(checkpoint_path=checkpoint_path, device=device)
    test_loss, test_acc = evaluate_model(model=model, dataloader=test_loader, device=device)
    print(f"Test Loss: {test_loss:.4f}")
    print(f"Test Accuracy: {test_acc * 100:.2f}%")
    preds, labels = collect_predictions(model=model, dataloader=test_loader, device=device)
    baseline_analysis = print_error_analysis(
        preds=preds, labels=labels, class_names=CLASS_NAMES
    )
    wrong_samples = collect_misclassified(
        model=model,
        dataloader=test_loader,
        device=device,
        max_samples=24,
    )
    visualize_misclassified(
        samples=wrong_samples,
        class_names=CLASS_NAMES,
        save_path=str(results_dir / "misclassified_examples_test.png"),
    )
    baseline_report_path = results_dir / "baseline_analysis.txt"
    baseline_report_path.write_text(
        "\n".join(
            [
                "Baseline ViT (custom checkpoint) results",
                f"Checkpoint: {checkpoint_path}",
                f"Test Loss: {test_loss:.4f}",
                f"Test Accuracy: {test_acc * 100:.2f}%",
                "",
                baseline_analysis,
                "",
            ]
        ),
        encoding="utf-8",
    )
    print(f"Saved baseline analysis to: {baseline_report_path}")

    if args.run_pretrained_experiment:
        # -------------------------------------------------------------------
        # Optional transfer-learning experiment (pre-trained ViT-B/16)
        # -------------------------------------------------------------------
        # pretrained_epochs controls adaptation budget for ImageNet weights.
        # In practice, 2-5 epochs is a quick sanity range for assignment runs.
        pretrained_epochs = args.pretrained_epochs
        print("\nRunning transfer-learning experiment with pre-trained ViT-B/16...")
        train_loader_pt, val_loader_pt, test_loader_pt = (
            get_imagenet_style_cifar10_dataloaders(
                data_root=args.data_root,
                batch_size=args.batch_size,
                val_ratio=0.1,
            )
        )
        pretrained_model = build_pretrained_vit_classifier(num_classes=len(CLASS_NAMES))
        fine_tune_pretrained(
            model=pretrained_model,
            train_loader=train_loader_pt,
            val_loader=val_loader_pt,
            device=device,
            epochs=pretrained_epochs,
            lr=1e-4,
            weight_decay=1e-4,
        )
        pt_test_loss, pt_test_acc = evaluate_model(
            model=pretrained_model,
            dataloader=test_loader_pt,
            device=device,
        )
        print(f"[Pretrained ViT] Test Loss: {pt_test_loss:.4f}")
        print(f"[Pretrained ViT] Test Accuracy: {pt_test_acc * 100:.2f}%")
        comparison_report = build_comparison_report(
            baseline_loss=test_loss,
            baseline_acc=test_acc,
            pretrained_loss=pt_test_loss,
            pretrained_acc=pt_test_acc,
        )
        print("\n" + comparison_report)
        pt_preds, pt_labels = collect_predictions(
            model=pretrained_model, dataloader=test_loader_pt, device=device
        )
        pretrained_analysis = print_error_analysis(
            preds=pt_preds, labels=pt_labels, class_names=CLASS_NAMES
        )
        pt_wrong_samples = collect_misclassified(
            model=pretrained_model,
            dataloader=test_loader_pt,
            device=device,
            max_samples=24,
        )
        # NOTE(review): these images carry ImageNet normalization; confirm that
        # visualize_misclassified un-normalizes them correctly, since it was
        # presumably written for the baseline pipeline's preprocessing.
        visualize_misclassified(
            samples=pt_wrong_samples,
            class_names=CLASS_NAMES,
            save_path=str(results_dir / "misclassified_examples_pretrained_vit.png"),
        )
        pretrained_report_path = results_dir / "pretrained_vit_analysis.txt"
        pretrained_report_path.write_text(
            "\n".join(
                [
                    "Pre-trained ViT-B/16 transfer learning results",
                    f"Fine-tuning epochs: {pretrained_epochs}",
                    f"Test Loss: {pt_test_loss:.4f}",
                    f"Test Accuracy: {pt_test_acc * 100:.2f}%",
                    "",
                    pretrained_analysis,
                    "",
                ]
            ),
            encoding="utf-8",
        )
        print(f"Saved pre-trained analysis to: {pretrained_report_path}")
        comparison_report_path = results_dir / "comparison_report.txt"
        comparison_report_path.write_text(comparison_report, encoding="utf-8")
        print(f"Saved model comparison report to: {comparison_report_path}")