import torch import torch.nn as nn import torch.nn.functional as F import timm import math import numpy as np def create_vit_small(num_classes=100, pretrained=True, drop_path_rate=0.0): """ViT-S/16 baseline - all patches, no pruning.""" model = timm.create_model( 'vit_small_patch16_224.augreg_in1k', pretrained=pretrained, num_classes=num_classes, drop_path_rate=drop_path_rate, ) return model def create_swin_tiny(num_classes=100, pretrained=True): """Swin-Tiny baseline.""" model = timm.create_model( 'swin_tiny_patch4_window7_224.ms_in1k', pretrained=pretrained, num_classes=num_classes, ) return model def _replace_patch_embed(model, patch_size, patch_stride, img_size=224): """Replace the patch embedding layer for custom patch_size/stride.""" old_embed = model.patch_embed in_chans = old_embed.proj.in_channels embed_dim = model.embed_dim new_conv = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_stride) # Compute new number of patches h = w = img_size h_out = (h - patch_size) // patch_stride + 1 w_out = (w - patch_size) // patch_stride + 1 num_patches = h_out * w_out # Initialize new conv from old (interpolate if sizes differ) if old_embed.proj.weight.shape[2] == patch_size == patch_stride == old_embed.proj.stride[0]: new_conv.weight.data.copy_(old_embed.proj.weight.data) if old_embed.proj.bias is not None: new_conv.bias.data.copy_(old_embed.proj.bias.data) model.patch_embed.proj = new_conv model.patch_embed.num_patches = num_patches model.patch_embed.grid_size = (h_out, w_out) # New positional embedding old_pos = model.pos_embed # (1, N+1, D) cls_pos = old_pos[:, 0:1, :] patch_pos = old_pos[:, 1:, :] # Interpolate to new number of patches old_h = int(math.sqrt(patch_pos.shape[1])) patch_pos_3d = patch_pos.transpose(1, 2).reshape(1, -1, old_h, old_h) new_patch_pos = F.interpolate(patch_pos_3d, size=(h_out, w_out), mode='bicubic', align_corners=False) new_patch_pos = new_patch_pos.reshape(1, -1, h_out * w_out).transpose(1, 2) model.pos_embed = nn.Parameter(torch.cat([cls_pos, new_patch_pos], dim=1)) return model class GumbelSelection(nn.Module): """Differentiable patch selection supporting top-k and adaptive modes.""" def __init__(self, num_patches, keep_ratio=0.5, selection_mode='topk', adaptive_alpha=0.5, min_keep=16): super().__init__() self.num_patches = num_patches self.num_keep = max(1, int(num_patches * keep_ratio)) self.selection_mode = selection_mode self.adaptive_alpha = adaptive_alpha self.min_keep = min_keep def forward(self, scores): """ Args: scores: (B, N) importance scores per patch Returns: selected_indices: (B, K) indices mask: (B, N) soft/hard mask k: int (number kept) """ B, N = scores.shape if self.selection_mode == 'topk': return self._topk_selection(scores, B, N) elif self.selection_mode == 'adaptive': return self._adaptive_selection(scores, B, N) else: raise ValueError(f'Unknown selection_mode: {self.selection_mode}') def _topk_selection(self, scores, B, N): k = min(self.num_keep, N) if self.training: gumbel_noise = -torch.log(-torch.log(torch.rand_like(scores) + 1e-8) + 1e-8) noisy_scores = scores + gumbel_noise _, indices = torch.topk(noisy_scores, k, dim=1) hard_mask = torch.zeros_like(scores) hard_mask.scatter_(1, indices, 1.0) soft_mask = F.softmax(scores / 0.1, dim=1) soft_mask = soft_mask / soft_mask.sum(dim=1, keepdim=True) * N mask = hard_mask.detach() + soft_mask - soft_mask.detach() return indices, mask, k else: _, indices = torch.topk(scores, k, dim=1) mask = torch.zeros_like(scores) mask.scatter_(1, indices, 1.0) return indices, mask, k def _adaptive_selection(self, scores, B, N): """Adaptive: keep patches depending on score distribution. Threshold = mean(scores) + alpha * std(scores). Each image determines its own keep count; we pad to batch max. """ if self.training: # Training: use concrete distribution (Gumbel-Sigmoid) threshold = scores.mean(dim=1, keepdim=True) + \ self.adaptive_alpha * scores.std(dim=1, keepdim=True) # Gumbel-Sigmoid for differentiable per-patch decisions gumbel_noise = -torch.log(-torch.log(torch.rand_like(scores) + 1e-8) + 1e-8) logits = (scores - threshold) / 0.1 + gumbel_noise probs = torch.sigmoid(logits) # keep probability per patch # STE: hard decisions in forward, soft gradients in backward hard = (probs > 0.5).float() mask = hard.detach() + probs - probs.detach() # Compute per-image K (for gathering) k_per_image = mask.sum(dim=1).long() # (B,) k_per_image = torch.clamp(k_per_image, min=self.min_keep, max=N) k = k_per_image.max().item() # Expand/collapse for variable K: pad each to k # Simple: just use per-image k for topk from scores # This keeps the gradient flow through gumbel scores indices_list = [] for i in range(B): ki = k_per_image[i].item() _, idx_i = torch.topk(scores[i], ki) if ki < k: idx_i = F.pad(idx_i, (0, k - ki), value=0) indices_list.append(idx_i) indices = torch.stack(indices_list) return indices, mask, k else: # Inference: per-image threshold, pad to max k_max = 0 indices_list = [] for i in range(B): thresh = scores[i].mean() + self.adaptive_alpha * scores[i].std() keep_mask = scores[i] > thresh ki = max(self.min_keep, keep_mask.sum().item()) ki = min(ki, N) k_max = max(k_max, ki) _, idx_i = torch.topk(scores[i], ki) indices_list.append((ki, idx_i)) # Pad all to k_max pad_indices = [] for ki, idx_i in indices_list: if ki < k_max: idx_i = F.pad(idx_i, (0, k_max - ki), value=0) pad_indices.append(idx_i) indices = torch.stack(pad_indices) mask = torch.zeros(B, N, device=scores.device) for i in range(B): mask[i, indices[i][:indices_list[i][0]]] = 1.0 return indices, mask, k_max class SemanticRouter(nn.Module): """ Lightweight router for scoring patch importance. Uses per-patch MLP + 1-layer self-attention for context-aware scoring. """ def __init__(self, embed_dim=384, hidden_dim=192, num_heads=4): super().__init__() self.per_patch_mlp = nn.Sequential( nn.LayerNorm(embed_dim), nn.Linear(embed_dim, hidden_dim), nn.GELU(), nn.Linear(hidden_dim, hidden_dim), nn.GELU(), ) self.cross_attn = nn.MultiheadAttention( hidden_dim, num_heads, batch_first=True, dropout=0.1 ) self.ln = nn.LayerNorm(hidden_dim) self.score_head = nn.Sequential( nn.Linear(hidden_dim, hidden_dim // 2), nn.GELU(), nn.Linear(hidden_dim // 2, 1), ) def forward(self, x): """ Args: x: (B, N, D) patch tokens Returns: scores: (B, N) raw importance scores """ B, N, D = x.shape # Per-patch MLP feat = self.per_patch_mlp(x) # (B, N, H) # Self-attention across patches attn_feat, _ = self.cross_attn(feat, feat, feat) feat = self.ln(feat + attn_feat) # Score each patch scores = self.score_head(feat).squeeze(-1) # (B, N) return scores class PatchSelectionViT(nn.Module): """ ViT-S/16 or ViT-B/16 with pre-tokenization patch selection. Supports: - topk: fixed K patches per image (based on keep_ratio) - adaptive: threshold-based, per-image variable K - Custom patch_size and patch_stride """ def __init__(self, num_classes=100, keep_ratio=0.5, pretrained=True, selection_mode='topk', adaptive_alpha=0.5, patch_size=16, patch_stride=None, drop_path_rate=0.0, backbone_name='vit_small_patch16_224.augreg_in1k'): super().__init__() self.keep_ratio = keep_ratio self.selection_mode = selection_mode self.adaptive_alpha = adaptive_alpha self.patch_size = patch_size self.patch_stride = patch_stride if patch_stride is not None else patch_size # Load pretrained backbone (S/16 or B/16 depending on backbone_name) self.backbone = timm.create_model( backbone_name, pretrained=pretrained, num_classes=num_classes, drop_path_rate=drop_path_rate, ) # Override patch embedding if custom size if patch_size != 16 or self.patch_stride != 16: _replace_patch_embed(self.backbone, patch_size, self.patch_stride) self.patch_embed = self.backbone.patch_embed self.cls_token = self.backbone.cls_token self.pos_drop = self.backbone.pos_drop self.pos_embed = self.backbone.pos_embed embed_dim = self.backbone.embed_dim num_patches = self.patch_embed.num_patches # Selection mechanism self.selection = GumbelSelection( num_patches=num_patches, keep_ratio=keep_ratio, selection_mode=selection_mode, adaptive_alpha=adaptive_alpha, min_keep=16, ) self._last_k = num_patches self._last_n = num_patches # Router (hidden_dim scales with embed_dim) router_hidden_dim = max(192, embed_dim // 2) router_num_heads = max(4, embed_dim // 64) self.router = SemanticRouter( embed_dim=embed_dim, hidden_dim=router_hidden_dim, num_heads=router_num_heads, ) self.blocks = self.backbone.blocks self.norm = self.backbone.norm self.head = self.backbone.head def load_mae_pretrained(self, mae_encoder_path): """Load encoder weights from MAE pretrained checkpoint.""" ckpt = torch.load(mae_encoder_path, map_location='cpu') # Support both full model and encoder-only checkpoints if 'encoder_state_dict' in ckpt: state_dict = ckpt['encoder_state_dict'] else: state_dict = ckpt msg = self.backbone.load_state_dict(state_dict, strict=False) print(f'[PatchSelectionViT] Loaded MAE encoder: {msg}') def forward(self, x): B = x.shape[0] # 1. Patch Embedding x = self.patch_embed(x) # (B, N, D) N = x.shape[1] # 2. Score patches with semantic router scores = self.router(x) # (B, N) # 3. Select patches selected_indices, mask, k = self.selection(scores) k = min(k, N) self._last_k = k # track for logging self._last_n = N # Gather selected patches with mask gradient (STE: hard forward, soft backward) batch_indices = torch.arange(B, device=x.device).unsqueeze(1).expand(-1, k) selected_patches = (x * mask.unsqueeze(-1))[batch_indices, selected_indices] # (B, K, D) # 4. Add [CLS] token cls_tokens = self.cls_token.expand(B, -1, -1) # (B, 1, D) x = torch.cat([cls_tokens, selected_patches], dim=1) # (B, K+1, D) # Add positional embeddings cls_pos = self.pos_embed[:, 0:1, :] # (1, 1, D) all_patch_pos = self.pos_embed[:, 1:, :].expand(B, -1, -1) # (B, N, D) selected_pos = all_patch_pos[batch_indices, selected_indices] # (B, K, D) pos_embed = torch.cat([cls_pos.expand(B, -1, -1), selected_pos], dim=1) x = x + pos_embed x = self.pos_drop(x) # 5. Transformer blocks for block in self.blocks: x = block(x) x = self.norm(x) # 6. Classification head (use [CLS] token) x = x[:, 0] x = self.head(x) return x class RandomPruneViT(nn.Module): """ViT-S/16 with random patch pruning - lower bound baseline.""" def __init__(self, num_classes=100, keep_ratio=0.5, pretrained=True, patch_size=16, patch_stride=None, drop_path_rate=0.0): super().__init__() self.keep_ratio = keep_ratio self.patch_stride = patch_stride if patch_stride is not None else patch_size self.backbone = timm.create_model( 'vit_small_patch16_224.augreg_in1k', pretrained=pretrained, num_classes=num_classes, drop_path_rate=drop_path_rate, ) if patch_size != 16 or self.patch_stride != 16: _replace_patch_embed(self.backbone, patch_size, self.patch_stride) self.patch_embed = self.backbone.patch_embed self.cls_token = self.backbone.cls_token self.pos_drop = self.backbone.pos_drop self.pos_embed = self.backbone.pos_embed num_patches = self.patch_embed.num_patches self.num_keep = max(1, int(num_patches * keep_ratio)) self.blocks = self.backbone.blocks self.norm = self.backbone.norm self.head = self.backbone.head def forward(self, x): B = x.shape[0] N = self.patch_embed.num_patches k = self.num_keep x = self.patch_embed(x) # Random selection if self.training or k < N: indices = torch.randperm(N, device=x.device)[:k] indices = indices.unsqueeze(0).expand(B, -1) else: indices = torch.arange(N, device=x.device).unsqueeze(0).expand(B, -1) batch_indices = torch.arange(B, device=x.device).unsqueeze(1).expand(-1, k) x = x[batch_indices, indices] cls_tokens = self.cls_token.expand(B, -1, -1) x = torch.cat([cls_tokens, x], dim=1) cls_pos = self.pos_embed[:, 0:1, :] all_patch_pos = self.pos_embed[:, 1:, :].expand(B, -1, -1) selected_pos = all_patch_pos[batch_indices, indices] pos_embed = torch.cat([cls_pos.expand(B, -1, -1), selected_pos], dim=1) x = x + pos_embed x = self.pos_drop(x) for block in self.blocks: x = block(x) x = self.norm(x) x = x[:, 0] x = self.head(x) return x # ========== MAE-style pretraining ========== def random_masking(x, mask_ratio=0.75): """ Randomly mask patches. Encoder only sees unmasked ones. Args: x: (B, N, D) patch tokens mask_ratio: fraction to mask Returns: x_masked: (B, N_keep, D) visible patches mask: (B, N) 0/1 mask (1 = masked) ids_restore: (B, N) indices to restore original order """ B, N, D = x.shape n_keep = int(N * (1 - mask_ratio)) # Random shuffle indices ids_shuffle = torch.rand(B, N, device=x.device).argsort(dim=1) ids_restore = ids_shuffle.argsort(dim=1) # Keep the first n_keep, mask the rest ids_keep = ids_shuffle[:, :n_keep] batch_idx = torch.arange(B, device=x.device).unsqueeze(1) x_masked = x[batch_idx, ids_keep] # (B, n_keep, D) # Mask: 1 = masked, 0 = visible mask = torch.ones(B, N, device=x.device) mask[batch_idx, ids_keep] = 0 return x_masked, mask, ids_restore def patchify(images, patch_size=16): """Convert images to patch pixels. images: (B, 3, H, W) Returns: (B, N, patch_size*patch_size*3) """ B, C, H, W = images.shape p = patch_size assert H % p == 0 and W % p == 0 h = H // p w = W // p x = images.reshape(B, C, h, p, w, p) x = x.permute(0, 2, 4, 3, 5, 1).contiguous() x = x.reshape(B, h * w, p * p * C) return x def unpatchify(x, patch_size=16, channels=3): """Convert patch pixels back to images. x: (B, N, p*p*C) Returns: (B, C, H, W) """ B, N, _ = x.shape p = patch_size h = w = int(N ** 0.5) assert h * w == N x = x.reshape(B, h, w, p, p, channels) x = x.permute(0, 5, 1, 3, 2, 4).contiguous() x = x.reshape(B, channels, h * p, w * p) return x class MAEDecoder(nn.Module): """Lightweight decoder for MAE pretraining.""" def __init__(self, embed_dim=384, decoder_embed_dim=192, decoder_depth=4, decoder_num_heads=6, num_patches=196, patch_size=16, in_chans=3): super().__init__() self.num_patches = num_patches self.patch_size = patch_size # Project encoder output to decoder dimension self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim) # Mask token shared across all masked positions self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim)) # Positional embeddings for decoder (all patches) self.decoder_pos_embed = nn.Parameter( torch.zeros(1, num_patches + 1, decoder_embed_dim)) # +1 for cls # Decoder transformer blocks decoder_layer = nn.TransformerEncoderLayer( d_model=decoder_embed_dim, nhead=decoder_num_heads, dim_feedforward=decoder_embed_dim * 4, dropout=0.0, activation='gelu', batch_first=True, norm_first=True, ) self.decoder_blocks = nn.TransformerEncoder( decoder_layer, num_layers=decoder_depth ) # Prediction head self.decoder_norm = nn.LayerNorm(decoder_embed_dim) self.decoder_pred = nn.Linear( decoder_embed_dim, patch_size * patch_size * in_chans) def forward(self, x, ids_restore): """ Args: x: (B, n_keep, D) encoder output (after norm) ids_restore: (B, N) indices to restore original order Returns: pred: (B, N, p*p*C) pixel predictions """ B, n_keep, D = x.shape N = self.num_patches # Project to decoder dim x = self.decoder_embed(x) # (B, n_keep, decoder_dim) # Append mask tokens n_mask = N - n_keep mask_tokens = self.mask_token.repeat(B, n_mask, 1) # (B, n_mask, decoder_dim) x = torch.cat([x, mask_tokens], dim=1) # (B, N, decoder_dim) # Restore original order x = torch.gather(x, 1, ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])) # Add positional embedding x = x + self.decoder_pos_embed[:, 1:, :] # skip cls pos # Decoder transformer x = self.decoder_blocks(x) # Prediction x = self.decoder_norm(x) pred = self.decoder_pred(x) # (B, N, p*p*C) return pred class MAEViT(nn.Module): """ MAE-style pretraining for ViT-S/16. Encoder processes visible patches, decoder reconstructs masked ones. After pretraining, the encoder backbone can be used for downstream tasks. """ def __init__(self, num_classes=100, mask_ratio=0.75, decoder_depth=4, decoder_embed_dim=192, pretrained=True): super().__init__() self.mask_ratio = mask_ratio # Encoder: ViT-S/16 self.backbone = timm.create_model( 'vit_small_patch16_224.augreg_in1k', pretrained=pretrained, num_classes=num_classes, ) self.patch_embed = self.backbone.patch_embed self.cls_token = self.backbone.cls_token self.pos_drop = self.backbone.pos_drop self.pos_embed = self.backbone.pos_embed embed_dim = self.backbone.embed_dim num_patches = self.patch_embed.num_patches patch_size = 16 # ViT-S/16 self.blocks = self.backbone.blocks self.encoder_norm = self.backbone.norm # Decoder self.decoder = MAEDecoder( embed_dim=embed_dim, decoder_embed_dim=decoder_embed_dim, decoder_depth=decoder_depth, decoder_num_heads=6, num_patches=num_patches, patch_size=patch_size, in_chans=3, ) def forward(self, x): """ Args: x: (B, 3, H, W) input images Returns: loss: MSE reconstruction loss pred: (B, N, p*p*3) pixel predictions mask: (B, N) 1= masked, 0=visible """ B, C, H, W = x.shape # Save original for loss computation images = x # Patch embedding x = self.patch_embed(x) # (B, N, D) N = x.shape[1] # Add positional embedding x = x + self.pos_embed[:, 1:, :] x = self.pos_drop(x) # Random masking x_masked, mask, ids_restore = random_masking(x, self.mask_ratio) # x_masked: (B, n_keep, D), mask: (B, N), ids_restore: (B, N) # Add cls token to masked sequence cls_tokens = self.cls_token.expand(B, -1, -1) # (B, 1, D) cls_pos = self.pos_embed[:, 0:1, :] x_masked = torch.cat([cls_tokens + cls_pos, x_masked], dim=1) # (B, 1+n_keep, D) # Encoder for block in self.blocks: x_masked = block(x_masked) enc_out = self.encoder_norm(x_masked) # Remove cls token for decoder enc_patches = enc_out[:, 1:, :] # (B, n_keep, D) # Decoder prediction pred = self.decoder(enc_patches, ids_restore) # (B, N, p*p*3) # Compute loss on masked patches only target = patchify(images, patch_size=16) # (B, N, p*p*3) # Per-patch normalization target_mean = target.mean(dim=-1, keepdim=True) target_var = target.var(dim=-1, keepdim=True) + 1e-6 target_norm = (target - target_mean) / target_var.sqrt() # Loss: only on masked patches loss = (pred - target_norm) ** 2 loss = loss.mean(dim=-1) # (B, N) loss = (loss * mask).sum() / mask.sum() # average over masked patches return loss, pred, mask def get_encoder(self): """Return the encoder backbone for downstream tasks.""" return self.backbone class MAEPatchSelectionViT(nn.Module): """ ViT-B/16 with MAE-style patch selection. Architecture: Image -> Patch Embed + Pos Embed -> Router (MLP scores) -> Differentiable Top-K (Sigmoid STE, no Gumbel) -> keep K patches -> Lightweight Encoder (first 2 ViT-B blocks, pretrained) -> split: (a) Main backbone (remaining 10 blocks) -> CLS head -> CE Loss (b) MAE Decoder (4 blocks, 512-dim) -> reconstruct discarded -> MSE Loss Training forward returns (logits, pred_pixels, keep_mask). Eval forward returns logits only. """ def __init__(self, num_classes=100, keep_ratio=0.5, pretrained=True, drop_path_rate=0.0, decoder_embed_dim=512, decoder_depth=4, img_size=224): super().__init__() self.keep_ratio = keep_ratio self.selection_temperature = 1.0 self.img_size = img_size # Load pretrained ViT-B/16 backbone (source of all weights) backbone = timm.create_model( 'vit_base_patch16_224.augreg_in21k', pretrained=pretrained, num_classes=num_classes, drop_path_rate=drop_path_rate, img_size=img_size, ) self.patch_embed = backbone.patch_embed self.cls_token = backbone.cls_token self.pos_embed = backbone.pos_embed self.pos_drop = backbone.pos_drop embed_dim = backbone.embed_dim # 768 num_patches = self.patch_embed.num_patches # 196 self.num_patches = num_patches self.num_keep = max(1, int(num_patches * keep_ratio)) # Router: simple MLP (no self-attention, ~300K params) self.router = nn.Sequential( nn.Linear(embed_dim, embed_dim // 2), nn.LayerNorm(embed_dim // 2), nn.GELU(), nn.Linear(embed_dim // 2, 1), ) # Lightweight encoder: first 2 ViT-B blocks (pretrained weights) blocks = list(backbone.blocks) self.lightweight_encoder = nn.ModuleList(blocks[:2]) # Main backbone: remaining 10 blocks self.main_blocks = nn.ModuleList(blocks[2:]) self.norm = backbone.norm self.head = backbone.head # MAE Decoder (random init, trained from scratch) self.decoder = MAEDecoder( embed_dim=embed_dim, decoder_embed_dim=decoder_embed_dim, decoder_depth=decoder_depth, decoder_num_heads=8, num_patches=num_patches, patch_size=16, in_chans=3, ) # Clean up to avoid duplicate parameter ownership del backbone # Tracking self._last_k = num_patches self._last_n = num_patches def forward(self, x): """ Args: x: (B, 3, H, W) input images Returns: Training: (logits, pred, keep_mask) logits: (B, num_classes) pred: (B, N, p*p*C) decoder pixel predictions (all patches) keep_mask: (B, N) 1=selected/kept, 0=discarded Eval: logits (B, num_classes) """ B = x.shape[0] # 1. Patch embedding + positional embedding (on ALL patches, before selection) x = self.patch_embed(x) # (B, N, D) N = x.shape[1] x = x + self.pos_embed[:, 1:, :] # pos-aware features for router x = self.pos_drop(x) # 2. Router scores scores = self.router(x).squeeze(-1) # (B, N) # 3. Differentiable Top-K selection (no Gumbel noise) k = min(self.num_keep, N) _, indices = torch.topk(scores, k, dim=1) # (B, K) # Build hard mask: 1 = selected hard_mask = torch.zeros(B, N, device=x.device) hard_mask.scatter_(1, indices, 1.0) if self.training: # Soft mask via sigmoid STE threshold = scores.topk(k, dim=1)[0][:, -1:] # (B, 1) soft_mask = torch.sigmoid( (scores - threshold) / self.selection_temperature) keep_mask = hard_mask.detach() + soft_mask - soft_mask.detach() else: keep_mask = hard_mask self._last_k = k self._last_n = N # Gather selected patches batch_idx = torch.arange(B, device=x.device).unsqueeze(1).expand(-1, k) selected = x[batch_idx, indices] # (B, K, D) # 4. Lightweight encoder (2 ViT-B blocks) — WITH CLS token cls_tokens = self.cls_token.expand(B, -1, -1) # (B, 1, D) h = torch.cat([cls_tokens, selected], dim=1) # (B, K+1, D) for block in self.lightweight_encoder: h = block(h) # Save patch features for decoder (before main backbone changes them) h_patches = h[:, 1:, :] # (B, K, D), remove CLS # 5-A. Main backbone for classification (CLS goes through all 12 blocks now) for block in self.main_blocks: h = block(h) h = self.norm(h) logits = self.head(h[:, 0]) # Eval: no decoder needed if not self.training: return logits # 5-B. MAE Decoder for reconstruction (training only) all_idx = torch.arange(N, device=x.device).unsqueeze(0).expand(B, -1) # Discarded indices (not in selected) discard_mask = torch.ones(B, N, device=x.device, dtype=torch.bool) discard_mask.scatter_(1, indices, False) discarded = all_idx[discard_mask].reshape(B, -1) # (B, N-K) # [selected | discarded] -> argsort -> ids_restore ids_sort = torch.cat([indices, discarded], dim=1) # (B, N) ids_restore = ids_sort.argsort(dim=1) # (B, N) # Decoder: reconstruct from lightweight encoder output (CLS removed) pred = self.decoder(h_patches, ids_restore) # (B, N, p*p*C) return logits, pred, hard_mask.detach() def create_model(model_name, num_classes=100, keep_ratio=0.5, pretrained=True, selection_mode='topk', adaptive_alpha=0.5, patch_size=16, patch_stride=None, mask_ratio=0.75, decoder_depth=4, decoder_embed_dim=192, drop_path_rate=0.0, img_size=224): """Factory function for all models.""" if model_name == 'vit_small': return create_vit_small(num_classes, pretrained, drop_path_rate=drop_path_rate) elif model_name == 'swin_tiny': return create_swin_tiny(num_classes, pretrained) elif model_name == 'patch_selection_vit': return PatchSelectionViT( num_classes, keep_ratio, pretrained, selection_mode=selection_mode, adaptive_alpha=adaptive_alpha, patch_size=patch_size, patch_stride=patch_stride, drop_path_rate=drop_path_rate, backbone_name='vit_small_patch16_224.augreg_in1k', ) elif model_name == 'patch_selection_vit_b16': return PatchSelectionViT( num_classes, keep_ratio, pretrained, selection_mode=selection_mode, adaptive_alpha=adaptive_alpha, patch_size=patch_size, patch_stride=patch_stride, drop_path_rate=drop_path_rate, backbone_name='vit_base_patch16_224.augreg_in21k', ) elif model_name == 'patch_selection_vit_b16_in1k': return PatchSelectionViT( num_classes, keep_ratio, pretrained, selection_mode=selection_mode, adaptive_alpha=adaptive_alpha, patch_size=patch_size, patch_stride=patch_stride, drop_path_rate=drop_path_rate, backbone_name='vit_base_patch16_224.augreg_in1k', ) elif model_name == 'random_prune_vit': return RandomPruneViT( num_classes, keep_ratio, pretrained, patch_size=patch_size, patch_stride=patch_stride, drop_path_rate=drop_path_rate, ) elif model_name == 'mae_vit': return MAEViT( num_classes=num_classes, mask_ratio=mask_ratio, decoder_depth=decoder_depth, decoder_embed_dim=decoder_embed_dim, pretrained=pretrained, ) elif model_name == 'mae_patch_selection_vit_b16': return MAEPatchSelectionViT( num_classes=num_classes, keep_ratio=keep_ratio, pretrained=pretrained, drop_path_rate=drop_path_rate, decoder_embed_dim=decoder_embed_dim, decoder_depth=decoder_depth, img_size=img_size, ) else: raise ValueError(f"Unknown model: {model_name}")