""" Train ViT-B/16 IN-21K with MAE-style patch selection on multiple datasets. Keeps 50% patches via differentiable top-k router + reconstruction loss. Architecture: Image -> Patch Embed + Pos Embed -> Router -> Top-K (STE sigmoid, no Gumbel) -> Lightweight Encoder (2 ViT-B blocks) -> split: (a) Main backbone (10 blocks) -> CLS -> CE Loss (b) MAE Decoder (4 blocks, 512-dim) -> reconstruct discarded -> MSE Loss Usage: python train_patch_selection_mae.py --dataset cifar100 --gpu 6 python train_patch_selection_mae.py --dataset oxford_pets --gpu 3 python train_patch_selection_mae.py --dataset food101 --gpu 5 """ import torch import torch.nn as nn import time import os import sys import math import argparse sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) from datasets import get_cifar100_loader, get_oxford_pets_loader, get_food101_loader from models import create_model, patchify DATASETS = { 'cifar100': (get_cifar100_loader, 100, 100), 'oxford_pets': (get_oxford_pets_loader, 37, 100), 'food101': (get_food101_loader, 101, 30), } BASELINE_ACC = { 'cifar100': 91.69, 'oxford_pets': 93.81, 'food101': 91.37, } def get_mse_weight(epoch, total_epochs, start_w=1.0, end_w=0.1): """Cosine anneal MSE weight from start_w to end_w over epochs.""" frac = epoch / max(1, total_epochs - 1) return end_w + 0.5 * (start_w - end_w) * (1 + math.cos(math.pi * frac)) def train_one_epoch(model, loader, criterion, optimizer, device, accum_steps=1, mse_weight=0.5, epoch_time=None): model.train() total_ce_loss = 0 total_mse_loss = 0 correct = 0 total = 0 optimizer.zero_grad() epoch_start = time.time() for batch_idx, (images, targets) in enumerate(loader): images, targets = images.to(device), targets.to(device) # Forward: model returns (logits, pred, keep_mask) during training logits, pred, keep_mask = model(images) # CE loss ce_loss = criterion(logits, targets) # MSE loss on discarded patches (where keep_mask == 0) target_pixels = patchify(images) # (B, N, p*p*C) # Per-patch MSE: (B, N) mse_loss = ((pred - target_pixels) ** 2).mean(dim=-1) # Only on discarded patches discard_mask = 1.0 - keep_mask # 1 = discarded mse_loss = (mse_loss * discard_mask).sum() / discard_mask.sum() # Combined loss loss = ce_loss + mse_weight * mse_loss loss = loss / accum_steps loss.backward() if (batch_idx + 1) % accum_steps == 0: torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) optimizer.step() optimizer.zero_grad() total_ce_loss += ce_loss.item() total_mse_loss += mse_loss.item() _, predicted = logits.max(1) total += targets.size(0) correct += predicted.eq(targets).sum().item() # Handle remaining gradient accumulation if (batch_idx + 1) % accum_steps != 0: torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) optimizer.step() optimizer.zero_grad() epoch_duration = time.time() - epoch_start if epoch_time is not None: epoch_time.append(epoch_duration) n_batches = len(loader) return (total_ce_loss / n_batches, total_mse_loss / n_batches, 100. * correct / total, epoch_duration) @torch.no_grad() def evaluate(model, loader, criterion, device, track_patches=False): model.eval() total_loss = 0 correct = 0 total = 0 kept_patches = [] for images, targets in loader: images, targets = images.to(device), targets.to(device) logits = model(images) # eval returns only logits loss = criterion(logits, targets) total_loss += loss.item() _, predicted = logits.max(1) total += targets.size(0) correct += predicted.eq(targets).sum().item() if track_patches and hasattr(model, '_last_k'): kept_patches.append(model._last_k) result = (total_loss / len(loader), 100. * correct / total) if track_patches and kept_patches: avg_k = sum(kept_patches) / len(kept_patches) avg_n = getattr(model, '_last_n', 196) result = result + (avg_k, avg_n, avg_k / avg_n * 100) return result @torch.no_grad() def compute_efficiency_metrics(model, loader, device): """Measure latency (ms/sample) and throughput (samples/sec).""" model.eval() # Warmup for images, _ in loader: images = images.to(device) _ = model(images) break total_time = 0 total_samples = 0 for images, _ in loader: images = images.to(device) batch_size = images.size(0) if device != 'cpu': torch.cuda.synchronize() start = time.time() _ = model(images) if device != 'cpu': torch.cuda.synchronize() end = time.time() total_time += (end - start) total_samples += batch_size latency = total_time / total_samples * 1000 # ms per sample throughput = total_samples / total_time # samples/sec return latency, throughput def main(): parser = argparse.ArgumentParser() parser.add_argument('--dataset', type=str, required=True, choices=list(DATASETS.keys())) parser.add_argument('--gpu', type=int, default=0) parser.add_argument('--keep_ratio', type=float, default=0.5) parser.add_argument('--batch_size', type=int, default=32) parser.add_argument('--accum', type=int, default=4) parser.add_argument('--lr', type=float, default=3e-5) parser.add_argument('--weight_decay', type=float, default=0.05) parser.add_argument('--label_smoothing', type=float, default=0.1) parser.add_argument('--mse_start', type=float, default=1.0, help='Starting weight for MSE loss') parser.add_argument('--mse_end', type=float, default=0.1, help='Final weight for MSE loss (cosine anneal)') parser.add_argument('--decoder_dim', type=int, default=512, help='MAE decoder embedding dimension') parser.add_argument('--decoder_depth', type=int, default=4, help='MAE decoder transformer depth') parser.add_argument('--router_path', type=str, default=None, help='Load pretrained router weights from this path') parser.add_argument('--image_size', type=int, default=224, help='Input image size (default: 224)') args = parser.parse_args() device = f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu' print(f'Device: {device}', flush=True) if torch.cuda.is_available(): print(f'GPU: {torch.cuda.get_device_name(args.gpu)}', flush=True) loader_fn, num_classes, epochs = DATASETS[args.dataset] effective_bs = args.batch_size * args.accum k = int(196 * args.keep_ratio) print(f'Dataset: {args.dataset}', flush=True) print(f' Batch: {args.batch_size}, Accum: {args.accum}, Effective: {effective_bs}', flush=True) print(f' Epochs: {epochs}, LR: {args.lr}, WD: {args.weight_decay}', flush=True) print(f' Label smoothing: {args.label_smoothing}', flush=True) print(f' Keep ratio: {args.keep_ratio} ({k}/196 patches)', flush=True) print(f' MSE weight: {args.mse_start} -> {args.mse_end} (cosine anneal)', flush=True) print(f' Decoder: dim={args.decoder_dim}, depth={args.decoder_depth}', flush=True) print(f' Baseline Acc: {BASELINE_ACC[args.dataset]:.2f}%', flush=True) # Data result = loader_fn(batch_size=args.batch_size, data_dir='./data', num_workers=4, image_size=args.image_size) if len(result) == 4: train_loader, val_loader, test_loader, n_cls = result print(f' Train: {len(train_loader.dataset)}, Val: {len(val_loader.dataset)}, ' f'Test: {len(test_loader.dataset)}, Classes: {n_cls}', flush=True) else: train_loader, test_loader, n_cls = result val_loader = test_loader print(f' Train: {len(train_loader.dataset)}, Test/Val: {len(test_loader.dataset)}, ' f'Classes: {n_cls}', flush=True) # Model print('Creating MAE Patch Selection ViT-B/16...', flush=True) model = create_model( model_name='mae_patch_selection_vit_b16', num_classes=n_cls, keep_ratio=args.keep_ratio, pretrained=True, decoder_embed_dim=args.decoder_dim, decoder_depth=args.decoder_depth, img_size=args.image_size, ) model = model.to(device) router_params = sum(p.numel() for p in model.router.parameters()) / 1e3 decoder_params = sum(p.numel() for p in model.decoder.parameters()) / 1e6 total_params = sum(p.numel() for p in model.parameters()) / 1e6 print(f' Total params: {total_params:.2f}M', flush=True) print(f' Router params: {router_params:.1f}K', flush=True) print(f' Decoder params: {decoder_params:.2f}M', flush=True) # Load pretrained router if specified (from attention distillation) if args.router_path is not None: print(f'Loading pretrained router from {args.router_path}...', flush=True) ckpt = torch.load(args.router_path, map_location=device) model.router.load_state_dict(ckpt['router_state_dict']) model.router.requires_grad_(True) # still fine-tune during training print(' -> Loaded (router will be fine-tuned during training)', flush=True) # Training setup criterion = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing) optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs) best_val_acc = 0 best_epoch = -1 test_acc_at_best = 0 epoch_times = [] save_dir = f'./checkpoints/{args.dataset}_mae_patchsel_b16_keep{int(args.keep_ratio*100)}' if args.router_path is not None: save_dir += '_distill' os.makedirs(save_dir, exist_ok=True) print(f'\n=== MAE Patch Selection ViT-B/16 on {args.dataset} ({epochs} epochs) ===\n', flush=True) for epoch in range(epochs): # Anneal MSE weight mse_w = get_mse_weight(epoch, epochs, args.mse_start, args.mse_end) train_ce, train_mse, train_acc, ep_time = train_one_epoch( model, train_loader, criterion, optimizer, device, args.accum, mse_w, epoch_times) scheduler.step() val_loss, val_acc, avg_k, avg_n, keep_pct = evaluate( model, val_loader, criterion, device, track_patches=True) print(f'Epoch {epoch+1}/{epochs} ({ep_time:.1f}s, MSE w={mse_w:.3f})', flush=True) print(f' Train CE: {train_ce:.4f}, MSE: {train_mse:.6f}, Acc: {train_acc:.2f}%', flush=True) print(f' Val Loss: {val_loss:.4f}, Acc: {val_acc:.2f}%', flush=True) print(f' Keep: {int(avg_k)}/{int(avg_n)} patches ({keep_pct:.1f}%)', flush=True) print(f' LR: {optimizer.param_groups[0]["lr"]:.6f}', flush=True) if val_acc > best_val_acc: best_val_acc = val_acc best_epoch = epoch test_loss_at_best, test_acc_at_best = evaluate(model, test_loader, criterion, device) os.makedirs(save_dir, exist_ok=True) torch.save({ 'epoch': epoch, 'model_state_dict': model.state_dict(), 'val_acc': val_acc, 'test_acc': test_acc_at_best, }, f'{save_dir}/best_model.pth') print(f' -> Saved best model (Val Acc: {best_val_acc:.2f}%, ' f'Test Acc: {test_acc_at_best:.2f}%)', flush=True) print(flush=True) # Final evaluation print('=== Evaluating best model on test set ===', flush=True) ckpt = torch.load(f'{save_dir}/best_model.pth', map_location=device) model.load_state_dict(ckpt['model_state_dict']) test_loss, test_acc = evaluate(model, test_loader, criterion, device) # Efficiency print('=== Efficiency Metrics ===', flush=True) latency, throughput = compute_efficiency_metrics(model, test_loader, device) avg_epoch_time = sum(epoch_times) / len(epoch_times) baseline_acc = BASELINE_ACC[args.dataset] acc_diff = test_acc - baseline_acc print(f'\n========== Final Results ({args.dataset}) ==========', flush=True) print(f' Best Val Epoch: {best_epoch+1}', flush=True) print(f' Best Val Acc: {best_val_acc:.2f}%', flush=True) print(f' Test Acc: {test_acc:.2f}%', flush=True) print(f' Baseline Acc: {baseline_acc:.2f}%', flush=True) print(f' Acc Diff: {acc_diff:+.2f}%', flush=True) print(f' Keep Ratio: {args.keep_ratio} ({k}/196 patches)', flush=True) print(f' --------------------------------------------', flush=True) print(f' Avg Epoch Time: {avg_epoch_time:.1f}s', flush=True) print(f' Latency: {latency:.2f} ms/sample', flush=True) print(f' Throughput: {throughput:.2f} samples/sec', flush=True) print(f' Baseline FLOPs: 33.85G (ViT-B/16 full)', flush=True) print(f'===============================================\n', flush=True) if __name__ == '__main__': main()