| """FAE with CNN spatial pooling for token reduction. |
| |
| Encoder: CNN downsample (24×24 → H'×W') + self-attention + project to latent_dim |
| Decoder: project up + ViT layers at compressed resolution + CNN upsample (H'×W' → 24×24) |
| |
| pool_factor=2: 576 → 144 tokens (s2) |
| pool_factor=4: 576 → 36 tokens (s4) |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import math |
| import sys, os |
| sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) |
| from utils import RMSNorm |
| from models.feature_decoder import RotaryPositionalEmbedding2D, ViTDecoderBlock |
|
|
|
|
class CNNDownsample(nn.Module):
    """Spatial downsampling with strided convolutions.

    Stacks log2(pool_factor) layers; each is a stride-2 3x3 conv
    (padding 1, so each side is halved) followed by GELU. Channel
    count is preserved throughout.

    Args:
        dim: number of input/output channels.
        pool_factor: total spatial reduction per side; any power of
            two >= 2 (generalized from the original {2, 4} restriction —
            the layer-stacking construction already supports it).

    Raises:
        ValueError: if pool_factor is not a power of two >= 2.
    """

    def __init__(self, dim, pool_factor):
        super().__init__()
        # Explicit raise rather than assert: asserts vanish under `python -O`.
        if pool_factor < 2 or pool_factor & (pool_factor - 1):
            raise ValueError(
                f"pool_factor must be a power of two >= 2, got {pool_factor}"
            )
        num_layers = int(math.log2(pool_factor))
        layers = []
        for _ in range(num_layers):
            layers.extend([
                nn.Conv2d(dim, dim, kernel_size=3, stride=2, padding=1),
                nn.GELU(),
            ])
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """x: [B, C, H, W] -> [B, C, H/pool_factor, W/pool_factor]."""
        return self.net(x)
|
|
|
|
class CNNUpsample(nn.Module):
    """Spatial upsampling with transposed convolutions.

    Stacks log2(pool_factor) layers; each is a stride-2 4x4 transposed
    conv (padding 1, so each side is exactly doubled) followed by GELU.
    Channel count is preserved throughout.

    Args:
        dim: number of input/output channels.
        pool_factor: total spatial expansion per side; any power of
            two >= 2 (generalized from the original {2, 4} restriction,
            matching CNNDownsample).

    Raises:
        ValueError: if pool_factor is not a power of two >= 2.
    """

    def __init__(self, dim, pool_factor):
        super().__init__()
        # Explicit raise rather than assert: asserts vanish under `python -O`.
        if pool_factor < 2 or pool_factor & (pool_factor - 1):
            raise ValueError(
                f"pool_factor must be a power of two >= 2, got {pool_factor}"
            )
        num_layers = int(math.log2(pool_factor))
        layers = []
        for _ in range(num_layers):
            layers.extend([
                nn.ConvTranspose2d(dim, dim, kernel_size=4, stride=2, padding=1),
                nn.GELU(),
            ])
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """x: [B, C, H', W'] -> [B, C, H'*pool_factor, W'*pool_factor]."""
        return self.net(x)
|
|
|
|
class FAESpatialEncoder(nn.Module):
    """FAE encoder with CNN spatial pooling.

    Pipeline: reshape the token sequence to a 2-D grid -> CNN downsample
    (grid_size -> grid_size/pool_factor per side) -> one pre-norm
    self-attention block -> one pre-norm SwiGLU FFN block -> linear
    projection to latent_dim -> optional VAE heads.

    Input:  [B, grid_size^2, embed_dim]  (576 tokens for grid_size=24)
    Output: [B, (grid_size/pool_factor)^2, latent_dim]
    """

    def __init__(self, embed_dim=1152, latent_dim=32, num_heads=16,
                 pool_factor=2, grid_size=24, use_vae=True):
        super().__init__()
        self.embed_dim = embed_dim
        self.latent_dim = latent_dim
        self.pool_factor = pool_factor
        self.grid_size = grid_size
        self.compressed_grid = grid_size // pool_factor
        self.use_vae = use_vae

        # Spatial token reduction: grid_size^2 -> compressed_grid^2 tokens.
        self.downsample = CNNDownsample(embed_dim, pool_factor)

        # Pre-norm self-attention at the compressed resolution.
        self.norm1 = RMSNorm(embed_dim)
        self.self_attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)

        # SwiGLU feed-forward: w2(silu(w1 x) * w3 x).
        self.norm2 = RMSNorm(embed_dim)
        ffn_dim = int(embed_dim * 2.7)
        self.w1 = nn.Linear(embed_dim, ffn_dim, bias=False)
        self.w2 = nn.Linear(ffn_dim, embed_dim, bias=False)
        self.w3 = nn.Linear(embed_dim, ffn_dim, bias=False)

        # Bottleneck projection into the latent space.
        self.proj = nn.Linear(embed_dim, latent_dim)

        # VAE heads map the projected latent to posterior parameters.
        if use_vae:
            self.mu_head = nn.Linear(latent_dim, latent_dim)
            self.logvar_head = nn.Linear(latent_dim, latent_dim)

    def forward(self, x):
        """
        Args:
            x: [B, N, embed_dim] with N == grid_size^2.

        Returns:
            Tuple (z_sample, mu, logvar), each of shape
            [B, (grid_size/pool_factor)^2, latent_dim].
            With use_vae=False the deterministic latent is returned as both
            z_sample and mu, and logvar is all zeros.

        Raises:
            ValueError: if the token count N does not equal grid_size^2.
        """
        B, N, D = x.shape
        # Fail fast with a clear message; otherwise the reshape below either
        # errors cryptically or silently mis-interprets the token layout.
        if N != self.grid_size * self.grid_size:
            raise ValueError(
                f"expected {self.grid_size * self.grid_size} tokens "
                f"(grid_size={self.grid_size}), got {N}"
            )

        # Tokens -> 2-D feature map -> CNN downsample -> tokens.
        x = x.transpose(1, 2).reshape(B, D, self.grid_size, self.grid_size)
        x = self.downsample(x)
        x = x.flatten(2).transpose(1, 2)

        # Pre-norm self-attention with residual connection.
        normed = self.norm1(x)
        x = x + self.self_attn(normed, normed, normed)[0]

        # Pre-norm SwiGLU FFN with residual connection.
        h = self.norm2(x)
        x = x + self.w2(F.silu(self.w1(h)) * self.w3(h))

        z = self.proj(x)

        if not self.use_vae:
            # Deterministic autoencoder path: zero logvar keeps the return
            # signature identical for callers that compute a KL term.
            return z, z, torch.zeros_like(z)

        mu = self.mu_head(z)
        logvar = self.logvar_head(z)

        if self.training:
            # Reparameterization trick: z = mu + sigma * eps, eps ~ N(0, I).
            std = torch.exp(0.5 * logvar)
            z_sample = mu + std * torch.randn_like(std)
        else:
            # Deterministic at eval time: use the posterior mean.
            z_sample = mu

        return z_sample, mu, logvar
|
|
|
|
class FAESpatialDecoder(nn.Module):
    """FAE decoder with CNN spatial upsampling.

    A small ViT stack runs at the compressed resolution, then transposed
    convolutions restore the full grid.

    Input:  [B, (grid_size/pool_factor)^2, latent_dim]
    Output: [B, grid_size^2, output_dim]
    """

    def __init__(self, latent_dim=32, output_dim=1152, num_layers=6,
                 num_heads=16, ffn_mult=2.7, pool_factor=2, grid_size=24):
        super().__init__()
        self.output_dim = output_dim
        self.pool_factor = pool_factor
        self.grid_size = grid_size
        self.compressed_grid = grid_size // pool_factor

        # Latent -> model width.
        self.input_proj = nn.Linear(latent_dim, output_dim)

        # 2-D rotary embeddings sized for the compressed grid.
        self.rope = RotaryPositionalEmbedding2D(
            output_dim // num_heads, grid_size=self.compressed_grid)

        # Transformer stack operating on the compressed token sequence.
        self.layers = nn.ModuleList(
            ViTDecoderBlock(output_dim, num_heads, ffn_mult)
            for _ in range(num_layers)
        )
        self.pre_upsample_norm = RMSNorm(output_dim)

        # Transposed-conv upsampling back to the full grid.
        self.upsample = CNNUpsample(output_dim, pool_factor)

        self.final_norm = RMSNorm(output_dim)

    def forward(self, z):
        """
        Args:
            z: [B, N_compressed, latent_dim]

        Returns:
            x_hat: [B, grid_size^2, output_dim]
        """
        batch = z.shape[0]
        hidden = self.input_proj(z)

        # RoPE tables are computed once and shared across all layers.
        cos, sin = self.rope(hidden.shape[1], hidden.device)
        for block in self.layers:
            hidden = block(hidden, cos, sin)
        hidden = self.pre_upsample_norm(hidden)

        # Tokens -> 2-D feature map -> CNN upsample -> tokens.
        side = self.compressed_grid
        fmap = hidden.transpose(1, 2).reshape(batch, self.output_dim, side, side)
        fmap = self.upsample(fmap)
        tokens = fmap.flatten(2).transpose(1, 2)

        return self.final_norm(tokens)
|
|