# fae-spatial-s4-chess / pixel_decoder_mae.py
# Uploaded by mk322 with huggingface_hub (commit b90fb2e, verified).
"""
Pixel Decoder: ViT-MAE style decoder following RAE architecture.
Takes 576×embed_dim ViT features and reconstructs 384×384×3 images.
Architecture: ViT-L decoder (24 layers, hidden=1024, heads=16, intermediate=4096).
"""
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
# ─── Sincos Positional Embeddings ───────────────────────────────────────────
def get_2d_sincos_pos_embed(embed_dim, grid_size, add_cls_token=False):
    """Build fixed 2D sine-cosine positional embeddings for a square grid.

    Each grid cell gets a vector made of two halves of width
    ``embed_dim // 2``; each half is a 1D sin/cos embedding of one of the
    cell's two coordinates.

    Args:
        embed_dim: Width of every embedding vector. Must be a multiple of 4
            so that each half splits evenly into sine and cosine parts.
        grid_size: Number of cells along each side of the grid.
        add_cls_token: If True, prepend one all-zero row for a CLS token.

    Returns:
        NumPy array of shape ``[grid_size ** 2, embed_dim]``, or
        ``[1 + grid_size ** 2, embed_dim]`` when ``add_cls_token`` is True.

    Raises:
        ValueError: If ``embed_dim`` is not a multiple of 4. Without this
            check the result would silently be narrower than ``embed_dim``.
    """
    if embed_dim % 4 != 0:
        raise ValueError(
            f"embed_dim must be a multiple of 4 for 2D sincos embeddings, got {embed_dim}"
        )

    grid_h = np.arange(grid_size, dtype=np.float32)
    grid_w = np.arange(grid_size, dtype=np.float32)
    # meshgrid is called width-first, so grid[0] holds column indices and
    # grid[1] holds row indices. The "emb_h"/"emb_w" names below therefore do
    # not match the axes they encode; the order is kept as-is because the
    # concatenation order defines the embedding layout.
    grid = np.meshgrid(grid_w, grid_h)
    grid = np.stack(grid, axis=0).reshape([2, 1, grid_size, grid_size])

    emb_h = get_1d_sincos_pos_embed(embed_dim // 2, grid[0].reshape(-1))
    emb_w = get_1d_sincos_pos_embed(embed_dim // 2, grid[1].reshape(-1))
    emb = np.concatenate([emb_h, emb_w], axis=1)

    if add_cls_token:
        # The CLS slot carries no spatial position, so it gets a zero vector.
        emb = np.concatenate([np.zeros([1, embed_dim]), emb], axis=0)
    return emb
def get_1d_sincos_pos_embed(embed_dim, pos):
    """Encode scalar positions as 1D sine-cosine embeddings.

    Frequencies are ``1 / 10000 ** (i / (embed_dim / 2))`` for
    ``i = 0 .. embed_dim // 2 - 1``. Each output row is the sines of
    ``position * frequency`` followed by the cosines.

    Args:
        embed_dim: Width of each embedding vector. Must be even.
        pos: Positions to encode; any array-like (ndarray, list, scalar).
            It is flattened to one dimension of length ``M``.

    Returns:
        NumPy array of shape ``[M, embed_dim]``.

    Raises:
        ValueError: If ``embed_dim`` is odd. Without this check the result
            would silently have ``embed_dim - 1`` columns.
    """
    if embed_dim % 2 != 0:
        raise ValueError(f"embed_dim must be even, got {embed_dim}")

    omega = np.arange(embed_dim // 2, dtype=float)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000**omega

    # asarray lets callers pass lists or scalars, not only ndarrays.
    pos = np.asarray(pos).reshape(-1)
    # Outer product: one row per position, one column per frequency.
    out = np.einsum("m,d->md", pos, omega)
    return np.concatenate([np.sin(out), np.cos(out)], axis=1)
# ─── Transformer Components ────────────────────────────────────────────────
class MAESelfAttention(nn.Module):
    """Multi-head self-attention with separate query/key/value projections.

    Args:
        hidden_size: Channel width of the input and output tokens.
        num_heads: Number of attention heads; must divide ``hidden_size``.
        qkv_bias: Whether the query/key/value projections have a bias term.
        attn_drop: Dropout probability on the attention weights, applied
            only while the module is in training mode.
        proj_drop: Dropout probability applied to the output projection.

    Raises:
        ValueError: If ``hidden_size`` is not divisible by ``num_heads``.
    """

    def __init__(self, hidden_size, num_heads, qkv_bias=True, attn_drop=0.0, proj_drop=0.0):
        super().__init__()
        if hidden_size % num_heads != 0:
            # Fail at construction instead of with a reshape error in forward().
            raise ValueError(
                f"hidden_size ({hidden_size}) must be divisible by num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.query = nn.Linear(hidden_size, hidden_size, bias=qkv_bias)
        self.key = nn.Linear(hidden_size, hidden_size, bias=qkv_bias)
        self.value = nn.Linear(hidden_size, hidden_size, bias=qkv_bias)
        self.out_proj = nn.Linear(hidden_size, hidden_size)
        self.attn_drop = attn_drop
        # proj_drop used to be accepted but ignored. nn.Dropout has no
        # parameters or buffers, so state_dict keys are unchanged.
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        """Apply self-attention over the token axis.

        Args:
            x: Tensor of shape ``[B, N, hidden_size]``.

        Returns:
            Tensor of shape ``[B, N, hidden_size]``.
        """
        B, N, C = x.shape

        def split_heads(t):
            # [B, N, C] -> [B, num_heads, N, head_dim]
            return t.reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)

        q = split_heads(self.query(x))
        k = split_heads(self.key(x))
        v = split_heads(self.value(x))
        # Attention dropout is only active during training.
        x = F.scaled_dot_product_attention(
            q, k, v, dropout_p=self.attn_drop if self.training else 0.0
        )
        # Merge heads back: [B, num_heads, N, head_dim] -> [B, N, C]
        x = x.permute(0, 2, 1, 3).reshape(B, N, C)
        return self.proj_drop(self.out_proj(x))
class MAEBlock(nn.Module):
    """Standard ViT block: pre-norm self-attention + pre-norm FFN.

    Args:
        hidden_size: Channel width of the tokens.
        num_heads: Number of attention heads.
        intermediate_size: Width of the FFN hidden layer.
        hidden_act: Name of the FFN activation. Supported names are
            ``"gelu"``, ``"gelu_pytorch_tanh"``, ``"relu"``, ``"silu"`` and
            ``"swish"`` (an alias for ``"silu"``).
        qkv_bias: Whether the attention q/k/v projections have a bias term.
        layer_norm_eps: Epsilon for both LayerNorm layers.

    Raises:
        ValueError: If ``hidden_act`` is not a supported activation name.
    """

    def __init__(self, hidden_size, num_heads, intermediate_size, hidden_act="gelu",
                 qkv_bias=True, layer_norm_eps=1e-6):
        super().__init__()
        # Resolve the activation first so an unsupported name fails before
        # any layers are allocated. hidden_act used to be ignored (GELU was
        # always used); the default "gelu" keeps that behaviour.
        act_fn = self._build_activation(hidden_act)

        self.layernorm_before = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
        self.attention = MAESelfAttention(hidden_size, num_heads, qkv_bias=qkv_bias)
        self.layernorm_after = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
        self.intermediate = nn.Linear(hidden_size, intermediate_size)
        self.output_proj = nn.Linear(intermediate_size, hidden_size)
        self.act_fn = act_fn

    @staticmethod
    def _build_activation(hidden_act):
        """Return a new activation module for the given activation name."""
        factories = {
            "gelu": nn.GELU,
            "gelu_pytorch_tanh": lambda: nn.GELU(approximate="tanh"),
            "relu": nn.ReLU,
            "silu": nn.SiLU,
            "swish": nn.SiLU,
        }
        try:
            factory = factories[hidden_act]
        except (KeyError, TypeError):
            raise ValueError(
                f"Unsupported hidden_act {hidden_act!r}; expected one of {sorted(factories)}"
            ) from None
        return factory()

    def forward(self, x):
        """Run attention and FFN sub-layers, each with a residual connection.

        Args:
            x: Tensor of shape ``[B, N, hidden_size]``.

        Returns:
            Tensor of the same shape as ``x``.
        """
        # Self-attention with residual
        x = x + self.attention(self.layernorm_before(x))
        # FFN with residual
        h = self.layernorm_after(x)
        h = self.act_fn(self.intermediate(h))
        x = x + self.output_proj(h)
        return x
# ─── Main Pixel Decoder ────────────────────────────────────────────────────
class PixelDecoderMAE(nn.Module):
    """
    ViT-MAE style pixel decoder following RAE.

    Input:  [B, num_patches, input_dim] ViT features (or FAE-reconstructed features)
    Output: [B, num_channels, img_size, img_size] reconstructed images

    With the default arguments (ViT-L sized decoder) num_patches is 576 and
    the output is [B, 3, 384, 384].

    Architecture:
      - Linear projection: input_dim -> decoder_hidden_size, then LayerNorm
      - Trainable CLS token + fixed sincos positional embeddings
      - decoder_num_layers Transformer blocks (24 by default)
      - LayerNorm + linear head -> patch_size^2 * num_channels values per token
      - Unpatchify -> full image

    Args:
        input_dim: Channel width of the incoming features.
        decoder_hidden_size: Token width inside the decoder.
        decoder_num_layers: Number of Transformer blocks.
        decoder_num_heads: Attention heads per block.
        decoder_intermediate_size: FFN hidden width per block.
        patch_size: Side length in pixels of the patch each token predicts.
        img_size: Side length in pixels of the (square) output image; must
            be a multiple of ``patch_size``.
        num_channels: Number of output image channels.
        layer_norm_eps: Epsilon for every LayerNorm.

    Raises:
        ValueError: If ``img_size`` is not a multiple of ``patch_size``.
    """

    def __init__(self, input_dim=1152, decoder_hidden_size=1024,
                 decoder_num_layers=24, decoder_num_heads=16,
                 decoder_intermediate_size=4096, patch_size=16,
                 img_size=384, num_channels=3, layer_norm_eps=1e-6):
        super().__init__()
        if img_size % patch_size != 0:
            # Floor division would silently build a decoder for a smaller
            # image than requested.
            raise ValueError(
                f"img_size ({img_size}) must be a multiple of patch_size ({patch_size})"
            )
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.grid_size = img_size // patch_size  # 24 with the defaults
        self.num_patches = self.grid_size ** 2   # 576 with the defaults

        # Project encoder features to decoder dimension + normalize
        self.decoder_embed = nn.Linear(input_dim, decoder_hidden_size)
        self.embed_norm = nn.LayerNorm(decoder_hidden_size, eps=layer_norm_eps)

        # Trainable CLS token
        self.cls_token = nn.Parameter(torch.zeros(1, 1, decoder_hidden_size))

        # Fixed sincos positional embeddings (num_patches + 1 CLS). Stored as
        # a non-trainable parameter so it is saved in, and moves with, the
        # module's state.
        pos_embed = get_2d_sincos_pos_embed(decoder_hidden_size, self.grid_size, add_cls_token=True)
        self.decoder_pos_embed = nn.Parameter(
            torch.from_numpy(pos_embed).float().unsqueeze(0),
            requires_grad=False
        )

        # Transformer decoder blocks
        self.decoder_layers = nn.ModuleList([
            MAEBlock(
                hidden_size=decoder_hidden_size,
                num_heads=decoder_num_heads,
                intermediate_size=decoder_intermediate_size,
                layer_norm_eps=layer_norm_eps,
            )
            for _ in range(decoder_num_layers)
        ])
        self.decoder_norm = nn.LayerNorm(decoder_hidden_size, eps=layer_norm_eps)

        # Prediction head: project to pixel patches
        self.decoder_pred = nn.Linear(
            decoder_hidden_size, patch_size ** 2 * num_channels
        )
        self._init_weights()

    def _init_weights(self):
        """Initialise the CLS token and the input/output linear layers.

        All other layers keep the initialisation they received at
        construction.
        """
        nn.init.normal_(self.cls_token, std=0.02)
        # Initialize decoder_embed like a linear layer
        nn.init.xavier_uniform_(self.decoder_embed.weight)
        if self.decoder_embed.bias is not None:
            nn.init.zeros_(self.decoder_embed.bias)
        # Initialize decoder_pred
        nn.init.xavier_uniform_(self.decoder_pred.weight)
        if self.decoder_pred.bias is not None:
            nn.init.zeros_(self.decoder_pred.bias)

    def unpatchify(self, x):
        """Rearrange per-token patch predictions into full images.

        Args:
            x: [B, num_patches, patch_size^2 * num_channels]. Each token's
               vector is read as (patch row, patch column, channel), with
               the channel index varying fastest.

        Returns:
            [B, num_channels, grid_size * patch_size, grid_size * patch_size]
        """
        p = self.patch_size
        h = w = self.grid_size
        c = self.num_channels
        x = x.reshape(-1, h, w, p, p, c)
        # Move channels first and interleave grid and in-patch axes so that
        # (h, p) merge into image rows and (w, q) into image columns.
        x = torch.einsum("nhwpqc->nchpwq", x)
        return x.reshape(-1, c, h * p, w * p)

    def forward(self, features, noise_tau=0.0):
        """Decode token features into images.

        Args:
            features: [B, num_patches, input_dim] ViT features.
            noise_tau: Maximum noise level. During training, when this is
                positive, each sample draws sigma ~ U(0, noise_tau) and
                Gaussian noise scaled by sigma is added AFTER the embedding
                LayerNorm (where std is roughly 1). Ignored in eval mode.

        Returns:
            images: [B, num_channels, img_size, img_size]. The head is a
            plain linear layer with no activation or clamping, so values
            are not bounded by this module.
            NOTE(review): the previous docstring described the output as
            lying in [-1, 1]; presumably that is the range of the training
            targets -- confirm against the training code.

        Raises:
            ValueError: If ``features`` is not 3-D or its token count
                differs from ``num_patches``.
        """
        if features.dim() != 3 or features.shape[1] != self.num_patches:
            # A mismatched token count could otherwise broadcast against the
            # positional table and fail later, or not at all.
            raise ValueError(
                f"features must have shape [B, {self.num_patches}, input_dim], "
                f"got {tuple(features.shape)}"
            )

        # Project to decoder dimension and normalize
        x = self.embed_norm(self.decoder_embed(features))  # [B, num_patches, decoder_hidden]

        # Add noise after normalization (features now have std of about 1,
        # so a value such as tau=0.8 is meaningful)
        if noise_tau > 0 and self.training:
            # One sigma per sample, broadcast over tokens and channels.
            # dtype follows x so fp16/bf16 activations are not promoted to
            # float32, which would break the layers that follow.
            noise_sigma = noise_tau * torch.rand(
                (x.size(0),) + (1,) * (len(x.shape) - 1),
                device=x.device, dtype=x.dtype
            )
            x = x + noise_sigma * torch.randn_like(x)

        # Prepend CLS token
        cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat([cls_tokens, x], dim=1)  # [B, num_patches + 1, decoder_hidden]

        # Add positional embeddings
        x = x + self.decoder_pos_embed

        # Transformer blocks
        for layer in self.decoder_layers:
            x = layer(x)
        x = self.decoder_norm(x)

        # Predict pixel patches (remove CLS token)
        x = self.decoder_pred(x[:, 1:, :])  # [B, num_patches, patch_size^2 * num_channels]

        # Unpatchify to full image
        img = self.unpatchify(x)  # [B, num_channels, img_size, img_size]
        return img
class PatchGANDiscriminator(nn.Module):
    """Convolutional PatchGAN discriminator used for the adversarial loss.

    The network is a stack of 4x4 convolutions: three stride-2 stages and
    one stride-1 stage, widening from ``ndf`` to ``ndf * 8`` channels, then
    a final stride-1 convolution down to one channel. Every stage except
    the first applies InstanceNorm before its LeakyReLU(0.2). No activation
    follows the last convolution, so the output is a map of raw scores.

    Args:
        in_channels: Number of channels in the input image.
        ndf: Channel width of the first convolution.
    """

    def __init__(self, in_channels=3, ndf=64):
        super().__init__()
        # (output channels, stride, followed by InstanceNorm?)
        stage_specs = (
            (ndf, 2, False),
            (ndf * 2, 2, True),
            (ndf * 4, 2, True),
            (ndf * 8, 1, True),
        )

        layers = []
        channels = in_channels
        for width, stride, use_norm in stage_specs:
            layers.append(nn.Conv2d(channels, width, 4, stride=stride, padding=1))
            if use_norm:
                layers.append(nn.InstanceNorm2d(width))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            channels = width
        # Final projection to a single-channel score map.
        layers.append(nn.Conv2d(channels, 1, 4, stride=1, padding=1))

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return the per-patch score map for a batch of images ``x``."""
        scores = self.model(x)
        return scores