| | """ |
| | Vision Transformer (ViT) training script for CIFAR-10. |
| | |
| | Reference: |
| | - Dosovitskiy, A., et al. (2021). An Image is Worth 16x16 Words: |
| | Transformers for Image Recognition at Scale. ICLR 2021. |
| | https://arxiv.org/abs/2010.11929 |
| | |
| | This script covers: |
| | 1) Loading CIFAR-10 |
| | 2) Resizing images (default: 64x64) |
| | 3) Normalizing pixel values to [-1, 1] |
| | 4) Creating batched DataLoaders |
| | 5) Building a ViT encoder + classification head |
| | 6) Training with CrossEntropy + AdamW + LR scheduler |
| | 7) Evaluation accuracy + misclassification visualization |
| | """ |
import os

# Pin the visible GPU BEFORE torch is imported: CUDA reads these environment
# variables at initialization time, so setting them later has no effect.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "2"
| | from pathlib import Path |
| | from typing import Any, Dict, List, Tuple |
| |
|
| | import torch |
| | import torch.nn as nn |
| | from torch.utils.data import DataLoader, random_split |
| | from torchvision import datasets, transforms |
| |
|
| | |
| | |
| | |
# CIFAR-10 label names, ordered by the dataset's integer class id (0-9),
# so CLASS_NAMES[label] maps a target index to its human-readable name.
CLASS_NAMES: Tuple[str, ...] = (
    "airplane",
    "automobile",
    "bird",
    "cat",
    "deer",
    "dog",
    "frog",
    "horse",
    "ship",
    "truck",
)
| |
|
| |
|
def get_cifar10_dataloaders(
    data_root: str = "./data",
    image_size: int = 64,
    batch_size: int = 128,
    num_workers: int = 2,
    pin_memory: bool = True,
    val_ratio: float = 0.1,
    seed: int = 42,
) -> Tuple[DataLoader, DataLoader, DataLoader]:
    """Create CIFAR-10 train/val/test DataLoaders with resize + normalization.

    The official CIFAR-10 split is used (train=True -> 50,000 images,
    train=False -> 10,000 images); `val_ratio` of the official training set
    is carved out deterministically for validation.

    Data source:
        https://www.cs.toronto.edu/~kriz/cifar.html

    Args:
        data_root: Directory to download/store CIFAR-10.
        image_size: Target square image size after resizing.
        batch_size: Samples per mini-batch.
        num_workers: Number of data-loading subprocesses.
        pin_memory: Pin host memory for faster transfers when CUDA is present.
        val_ratio: Fraction of the official train split used for validation.
        seed: Seed for the deterministic train/val split.

    Returns:
        (train_loader, val_loader, test_loader)

    Raises:
        ValueError: If `val_ratio` is not strictly between 0 and 1.
    """
    if not 0.0 < val_ratio < 1.0:
        raise ValueError("val_ratio must be between 0 and 1.")

    preprocessing = transforms.Compose(
        [
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            # Map [0, 1] pixel values to [-1, 1] on every channel.
            transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
        ]
    )

    root = Path(data_root)
    root.mkdir(parents=True, exist_ok=True)

    official_train = datasets.CIFAR10(
        root=str(root), train=True, download=True, transform=preprocessing
    )
    official_test = datasets.CIFAR10(
        root=str(root), train=False, download=True, transform=preprocessing
    )

    # Deterministic carve-out of a validation subset from the official train set.
    n_val = int(len(official_train) * val_ratio)
    n_train = len(official_train) - n_val
    split_generator = torch.Generator().manual_seed(seed)
    train_subset, val_subset = random_split(
        official_train, [n_train, n_val], generator=split_generator
    )

    # Pinning host memory only pays off when a CUDA device is actually present.
    shared_kwargs = dict(
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_memory and torch.cuda.is_available(),
    )
    train_loader = DataLoader(train_subset, shuffle=True, **shared_kwargs)
    val_loader = DataLoader(val_subset, shuffle=False, **shared_kwargs)
    test_loader = DataLoader(official_test, shuffle=False, **shared_kwargs)

    return train_loader, val_loader, test_loader
| |
|
| |
|
class PatchifyEmbedding(nn.Module):
    """Split an image into non-overlapping P x P patches and linearly project
    each flattened patch to a D-dimensional embedding (ViT step 2).
    """

    def __init__(
        self,
        image_size: int = 64,
        patch_size: int = 4,
        in_channels: int = 3,
        embed_dim: int = 256,
    ) -> None:
        super().__init__()
        if image_size % patch_size != 0:
            raise ValueError("image_size must be divisible by patch_size.")

        self.image_size = image_size
        self.patch_size = patch_size
        self.in_channels = in_channels
        self.embed_dim = embed_dim

        # The patch grid is (image_size / patch_size) cells per side.
        self.num_patches_per_side = image_size // patch_size
        self.num_patches = self.num_patches_per_side ** 2

        # nn.Unfold extracts each P x P block as a flattened column of length
        # C * P * P, which the Linear layer then maps to embed_dim.
        self.unfold = nn.Unfold(kernel_size=patch_size, stride=patch_size)
        self.proj = nn.Linear(in_channels * patch_size * patch_size, embed_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map a batch of images (B, C, H, W) to patch embeddings (B, N, D)."""
        flat_patches = self.unfold(x).transpose(1, 2)  # (B, N, C*P*P)
        return self.proj(flat_patches)
| |
|
| |
|
class TransformerEncoderBlock(nn.Module):
    """Single pre-norm Transformer encoder block (ViT step 4).

    Layout: LayerNorm -> multi-head self-attention -> residual add,
    then LayerNorm -> MLP -> residual add.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        dropout: float = 0.0,
    ) -> None:
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = nn.MultiheadAttention(
            embed_dim=embed_dim,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True,
        )
        self.norm2 = nn.LayerNorm(embed_dim)

        # Feed-forward sub-layer expands to mlp_ratio * D and projects back.
        hidden = int(embed_dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, embed_dim),
            nn.Dropout(dropout),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply both sub-layers to (B, N, D) tokens, each with a residual."""
        normed = self.norm1(x)
        attended, _ = self.attn(normed, normed, normed, need_weights=False)
        x = x + attended
        return x + self.mlp(self.norm2(x))
| |
|
| |
|
class ViTEncoder(nn.Module):
    """ViT backbone (steps 2-4): patchify + project, prepend a learnable CLS
    token, add learnable positional embeddings, then run a stack of
    Transformer encoder blocks followed by a final LayerNorm.
    """

    def __init__(
        self,
        image_size: int = 64,
        patch_size: int = 4,
        in_channels: int = 3,
        embed_dim: int = 256,
        depth: int = 6,
        num_heads: int = 8,
        mlp_ratio: float = 4.0,
        dropout: float = 0.0,
    ) -> None:
        super().__init__()
        self.patch_embed = PatchifyEmbedding(
            image_size=image_size,
            patch_size=patch_size,
            in_channels=in_channels,
            embed_dim=embed_dim,
        )

        # One extra sequence position for the CLS token prepended in forward().
        seq_len = self.patch_embed.num_patches + 1
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, seq_len, embed_dim))
        self.pos_drop = nn.Dropout(dropout)

        self.blocks = nn.ModuleList(
            TransformerEncoderBlock(
                embed_dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                dropout=dropout,
            )
            for _ in range(depth)
        )
        self.norm = nn.LayerNorm(embed_dim)

        self._init_parameters()

    def _init_parameters(self) -> None:
        # Truncated-normal init with std 0.02 for the learnable tokens.
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        nn.init.trunc_normal_(self.pos_embed, std=0.02)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode images (B, C, H, W) into post-norm CLS representations (B, D)."""
        tokens = self.patch_embed(x)                          # (B, N, D)
        cls = self.cls_token.expand(tokens.size(0), -1, -1)   # (B, 1, D)
        tokens = torch.cat((cls, tokens), dim=1)              # (B, N + 1, D)
        tokens = self.pos_drop(tokens + self.pos_embed)

        for blk in self.blocks:
            tokens = blk(tokens)

        return self.norm(tokens)[:, 0]
| |
|
| |
|
class ViTClassifier(nn.Module):
    """ViT encoder plus a linear head mapping the CLS representation to
    class logits (ViT step 5).
    """

    def __init__(
        self,
        image_size: int = 64,
        patch_size: int = 4,
        in_channels: int = 3,
        embed_dim: int = 256,
        depth: int = 6,
        num_heads: int = 8,
        mlp_ratio: float = 4.0,
        dropout: float = 0.1,
        num_classes: int = 10,
    ) -> None:
        super().__init__()
        self.encoder = ViTEncoder(
            image_size=image_size,
            patch_size=patch_size,
            in_channels=in_channels,
            embed_dim=embed_dim,
            depth=depth,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            dropout=dropout,
        )
        self.head = nn.Linear(embed_dim, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map images (B, C, H, W) to class logits (B, num_classes)."""
        return self.head(self.encoder(x))
| |
|
| |
|
| | |
| | |
| | |
def train_one_epoch(
    model: nn.Module,
    dataloader: DataLoader,
    criterion: nn.Module,
    optimizer: torch.optim.Optimizer,
    device: torch.device,
) -> Tuple[float, float]:
    """
    Run one optimization epoch over the training set.

    Args:
        model: Classifier to optimize.
        dataloader: Training mini-batches.
        criterion: Loss function (typically CrossEntropyLoss for CIFAR-10).
        optimizer: Parameter optimizer (AdamW in this project).
        device: CPU or CUDA device.

    Returns:
        (avg_loss, avg_accuracy) over all training samples in this epoch,
        or (0.0, 0.0) when the dataloader yields no batches.
    """
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    for images, labels in dataloader:
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        logits = model(images)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

        # Weight the per-batch mean loss by batch size so the final average
        # is per-sample even when the last batch is smaller.
        running_loss += loss.item() * images.size(0)
        preds = logits.argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)

    # BUG FIX: guard against ZeroDivisionError on an empty dataloader.
    if total == 0:
        return 0.0, 0.0
    return running_loss / total, correct / total
| |
|
| |
|
@torch.no_grad()
def evaluate(
    model: nn.Module,
    dataloader: DataLoader,
    criterion: nn.Module,
    device: torch.device,
) -> Tuple[float, float]:
    """
    Evaluate model performance without gradient updates.

    Args:
        model: Classifier to evaluate.
        dataloader: Validation or test mini-batches.
        criterion: Loss function used for reporting.
        device: CPU or CUDA device.

    Returns:
        (avg_loss, avg_accuracy) over all samples from `dataloader`,
        or (0.0, 0.0) when the dataloader yields no batches.
    """
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0

    for images, labels in dataloader:
        images = images.to(device)
        labels = labels.to(device)

        logits = model(images)
        loss = criterion(logits, labels)

        # Accumulate a per-sample average: per-batch mean loss times batch size.
        running_loss += loss.item() * images.size(0)
        preds = logits.argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)

    # BUG FIX: guard against ZeroDivisionError on an empty dataloader.
    if total == 0:
        return 0.0, 0.0
    return running_loss / total, correct / total
| |
|
| |
|
def train_model(
    model: nn.Module,
    train_loader: DataLoader,
    val_loader: DataLoader,
    device: torch.device,
    num_epochs: int = 10,
    lr: float = 3e-4,
    weight_decay: float = 1e-4,
    save_dir: str = "./saved_model",
    checkpoint_name: str = "vit_cifar10_best.pt",
    model_config: Dict[str, Any] | None = None,
    early_stopping_patience: int = 5,
) -> Tuple[Dict[str, List[float]], str]:
    """
    Step 6:
    - Loss: CrossEntropy
    - Optimizer: AdamW
    - LR scheduler: StepLR decay
    - Validation each epoch
    - Early stopping on validation accuracy

    Hyperparameters:
    - num_epochs: Max number of epochs before early stopping.
    - lr: Initial learning rate for AdamW updates.
    - weight_decay: L2-style regularization term in AdamW.
    - early_stopping_patience: Number of non-improving epochs allowed.
      This limits overfitting and unnecessary computation.

    Returns:
        (history dict with per-epoch train/val loss and accuracy,
         path to the best-validation-accuracy checkpoint)
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)

    history: Dict[str, List[float]] = {
        "train_loss": [],
        "train_acc": [],
        "val_loss": [],
        "val_acc": [],
    }
    # BUG FIX: start below any achievable accuracy so the first epoch always
    # writes a best checkpoint. With the old init of 0.0, a val_acc of exactly
    # 0.0 never triggered a save, and the later torch.load of
    # best_checkpoint_path could fail on a missing file.
    best_val_acc = float("-inf")
    epochs_without_improvement = 0
    last_completed_epoch = 0
    save_dir_path = Path(save_dir)
    save_dir_path.mkdir(parents=True, exist_ok=True)
    best_checkpoint_path = str(save_dir_path / checkpoint_name)

    model.to(device)

    for epoch in range(num_epochs):
        train_loss, train_acc = train_one_epoch(
            model=model,
            dataloader=train_loader,
            criterion=criterion,
            optimizer=optimizer,
            device=device,
        )
        val_loss, val_acc = evaluate(
            model=model,
            dataloader=val_loader,
            criterion=criterion,
            device=device,
        )
        scheduler.step()
        last_completed_epoch = epoch + 1

        history["train_loss"].append(train_loss)
        history["train_acc"].append(train_acc)
        history["val_loss"].append(val_loss)
        history["val_acc"].append(val_acc)

        current_lr = optimizer.param_groups[0]["lr"]
        print(
            f"Epoch [{epoch + 1}/{num_epochs}] | "
            f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc * 100:.2f}% | "
            f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc * 100:.2f}% | "
            f"LR: {current_lr:.6f}"
        )

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            epochs_without_improvement = 0
            checkpoint = {
                "epoch": epoch + 1,
                "best_val_acc": best_val_acc,
                "model_state_dict": model.state_dict(),
                "model_config": model_config or {},
            }
            torch.save(checkpoint, best_checkpoint_path)
            print(f"Saved best checkpoint to: {best_checkpoint_path}")
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= early_stopping_patience:
                print(
                    "Early stopping triggered "
                    f"(no validation improvement for {early_stopping_patience} epochs)."
                )
                break

    final_checkpoint_path = str(save_dir_path / "vit_cifar10_last.pt")
    torch.save(
        {
            # BUG FIX: record the epoch actually completed; the old code wrote
            # num_epochs even when early stopping ended training earlier.
            "epoch": last_completed_epoch,
            "best_val_acc": best_val_acc,
            "model_state_dict": model.state_dict(),
            "model_config": model_config or {},
        },
        final_checkpoint_path,
    )
    print(f"Saved last checkpoint to: {final_checkpoint_path}")

    return history, best_checkpoint_path
| |
|
| |
|
| | |
| | |
| | |
@torch.no_grad()
def collect_misclassified(
    model: nn.Module,
    dataloader: DataLoader,
    device: torch.device,
    max_samples: int = 16,
) -> List[Tuple[torch.Tensor, int, int]]:
    """
    Step 7 (Error analysis helper):
    Gather up to `max_samples` wrongly-predicted examples, in dataloader
    order, as (image_tensor_on_cpu, true_label, predicted_label) triples.
    """
    model.eval()
    collected: List[Tuple[torch.Tensor, int, int]] = []

    for images, labels in dataloader:
        images = images.to(device)
        labels = labels.to(device)
        preds = model(images).argmax(dim=1)

        # Indices (ascending) of the samples this batch got wrong.
        wrong_idx = (preds != labels).nonzero(as_tuple=True)[0]
        for idx in wrong_idx:
            collected.append(
                (
                    images[idx].detach().cpu(),
                    int(labels[idx].item()),
                    int(preds[idx].item()),
                )
            )
            if len(collected) >= max_samples:
                return collected

    return collected
| |
|
| |
|
def denormalize_image(img: torch.Tensor) -> torch.Tensor:
    """Undo the [-1, 1] normalization, returning a [0, 1] tensor for display.

    The input tensor is not modified; a new tensor is returned.
    """
    return img.mul(0.5).add(0.5).clamp(0.0, 1.0)
| |
|
| |
|
def visualize_misclassified(
    samples: List[Tuple[torch.Tensor, int, int]],
    class_names: Tuple[str, ...],
    save_path: str = "misclassified_examples.png",
) -> None:
    """
    Visualize wrongly predicted images in a grid and save the figure to disk.

    Args:
        samples: (image_tensor, true_label, pred_label) triples, as produced
            by `collect_misclassified`; images are denormalized for display.
        class_names: Label-index-to-name mapping.
        save_path: Output image path; parent directories are created if needed.
    """
    if len(samples) == 0:
        print("No misclassified samples to visualize.")
        return

    try:
        import matplotlib.pyplot as plt
    except ImportError:
        print("matplotlib is not installed. Skipping visualization.")
        return

    n = len(samples)
    cols = min(4, n)
    rows = (n + cols - 1) // cols
    fig, axes = plt.subplots(rows, cols, figsize=(4 * cols, 4 * rows))

    # plt.subplots returns a bare Axes, a 1-D array, or a 2-D array depending
    # on the grid shape; normalize to a flat list so the loop below is uniform.
    if rows == 1 and cols == 1:
        axes = [axes]
    elif rows == 1 or cols == 1:
        axes = list(axes)
    else:
        axes = axes.flatten()

    for idx, ax in enumerate(axes):
        if idx < n:
            img, true_lbl, pred_lbl = samples[idx]
            img = denormalize_image(img).permute(1, 2, 0).numpy()
            ax.imshow(img)
            ax.set_title(f"True: {class_names[true_lbl]}\nPred: {class_names[pred_lbl]}")
        # Hide axes on every cell, including the unused trailing ones.
        ax.axis("off")

    # BUG FIX: savefig raises when the target directory does not exist
    # (the __main__ section passes "./results/...", which is never created).
    Path(save_path).parent.mkdir(parents=True, exist_ok=True)
    fig.tight_layout()
    fig.savefig(save_path, dpi=150)
    # Release the figure so repeated calls do not accumulate memory.
    plt.close(fig)
    print(f"Saved misclassified visualization to: {save_path}")
| |
|
| |
|
if __name__ == "__main__":
    # ---- Step 1: data ---------------------------------------------------
    # Downloads CIFAR-10 on first run, resizes to 64x64, normalizes to
    # [-1, 1], and splits 10% of the official train set into validation.
    train_loader, val_loader, test_loader = get_cifar10_dataloaders(
        data_root="./data",
        image_size=64,
        batch_size=128,
        val_ratio=0.1,
    )

    # Smoke-check one batch from each split before building the model.
    train_images, train_labels = next(iter(train_loader))
    val_images, val_labels = next(iter(val_loader))
    test_images, test_labels = next(iter(test_loader))

    print(f"Train batch images shape: {train_images.shape}")
    print(f"Train batch labels shape: {train_labels.shape}")
    print(f"Val batch images shape: {val_images.shape}")
    print(f"Val batch labels shape: {val_labels.shape}")
    print(f"Test batch images shape: {test_images.shape}")
    print(f"Test batch labels shape: {test_labels.shape}")
    print(f"Train dataset size: {len(train_loader.dataset)}")
    print(f"Val dataset size: {len(val_loader.dataset)}")
    print(f"Test dataset size: {len(test_loader.dataset)}")
    print(
        "Image value range (approx after normalization): "
        f"[{train_images.min().item():.3f}, {train_images.max().item():.3f}]"
    )

    # ---- Steps 2-5: model -----------------------------------------------
    # model_kwargs is persisted into checkpoints as model_config so the
    # architecture can be reconstructed at load time; it must stay in sync
    # with the ViTClassifier arguments below.
    model_kwargs: Dict[str, Any] = {
        "image_size": 64,
        "patch_size": 4,
        "in_channels": 3,
        "embed_dim": 256,
        "depth": 6,
        "num_heads": 8,
        "mlp_ratio": 4.0,
        "dropout": 0.1,
        "num_classes": 10,
    }
    model = ViTClassifier(
        image_size=64,
        patch_size=4,
        in_channels=3,
        embed_dim=256,
        depth=6,
        num_heads=8,
        mlp_ratio=4.0,
        dropout=0.1,
        num_classes=10,
    )

    # Shape sanity checks on the untrained model using the first train batch.
    patch_embeddings = model.encoder.patch_embed(train_images)
    cls_features = model.encoder(train_images)
    logits = model(train_images)

    print(f"Patch embeddings shape (B, N, D): {patch_embeddings.shape}")
    print(f"CLS feature shape (B, D): {cls_features.shape}")
    print(f"Logits shape (B, num_classes): {logits.shape}")

    # ---- Step 6: training -----------------------------------------------
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    history, best_ckpt_path = train_model(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        device=device,
        num_epochs=10,
        lr=3e-4,
        weight_decay=1e-4,
        save_dir="./saved_model",
        checkpoint_name="vit_cifar10_best.pt",
        model_config=model_kwargs,
        early_stopping_patience=5,
    )

    final_val_acc = history["val_acc"][-1] * 100 if history["val_acc"] else 0.0
    print(f"Final validation accuracy: {final_val_acc:.2f}%")
    print(f"Best model checkpoint: {best_ckpt_path}")

    # ---- Step 7: final evaluation + error analysis ----------------------
    # Reload the best (not last) weights before scoring on the test split.
    best_checkpoint = torch.load(best_ckpt_path, map_location=device)
    model.load_state_dict(best_checkpoint["model_state_dict"])
    test_criterion = nn.CrossEntropyLoss()
    test_loss, test_acc = evaluate(
        model=model,
        dataloader=test_loader,
        criterion=test_criterion,
        device=device,
    )
    print(f"Final test loss (best checkpoint): {test_loss:.4f}")
    print(f"Final test accuracy (best checkpoint): {test_acc * 100:.2f}%")

    wrong_samples = collect_misclassified(
        model=model,
        dataloader=test_loader,
        device=device,
        max_samples=12,
    )
    # NOTE(review): "./results" is not created anywhere in this script;
    # confirm the directory exists (or is created by the callee) before
    # the figure is saved, otherwise savefig will fail.
    visualize_misclassified(
        samples=wrong_samples,
        class_names=CLASS_NAMES,
        save_path="./results/misclassified_examples.png",
    )
|