| | """ |
| | Pixel Decoder: ViT-MAE style decoder following RAE architecture. |
| | Takes 576Γembed_dim ViT features and reconstructs 384Γ384Γ3 images. |
| | Architecture: ViT-L decoder (24 layers, hidden=1024, heads=16, intermediate=4096). |
| | """ |
| |
|
| | import math |
| | import numpy as np |
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |


def get_2d_sincos_pos_embed(embed_dim, grid_size, add_cls_token=False):
    """Fixed 2D sin-cos positional embeddings: [grid_size**2 (+1 for CLS), embed_dim]."""
    grid_h = np.arange(grid_size, dtype=np.float32)
    grid_w = np.arange(grid_size, dtype=np.float32)
    grid = np.meshgrid(grid_w, grid_h)  # w varies fastest, as in the MAE reference code
    grid = np.stack(grid, axis=0).reshape([2, 1, grid_size, grid_size])

    # Half of the channels encode one grid axis, half the other.
    emb_h = get_1d_sincos_pos_embed(embed_dim // 2, grid[0].reshape(-1))
    emb_w = get_1d_sincos_pos_embed(embed_dim // 2, grid[1].reshape(-1))
    emb = np.concatenate([emb_h, emb_w], axis=1)

    if add_cls_token:
        emb = np.concatenate([np.zeros([1, embed_dim]), emb], axis=0)
    return emb


def get_1d_sincos_pos_embed(embed_dim, pos):
    """1D sin-cos embedding of positions `pos`: [len(pos), embed_dim]."""
    omega = np.arange(embed_dim // 2, dtype=float)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000**omega  # frequencies decaying from 1 to 1/10000

    pos = pos.reshape(-1)
    out = np.einsum("m,d->md", pos, omega)  # outer product: [M, D/2]
    return np.concatenate([np.sin(out), np.cos(out)], axis=1)
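

# Shape check for the helpers above (illustrative, not part of the module API):
# with a 384-pixel image and 16-pixel patches the grid is 24×24, so the table
# has 24*24 + 1 = 577 rows when a CLS slot is prepended.
#
#     pos = get_2d_sincos_pos_embed(1024, 24, add_cls_token=True)
#     assert pos.shape == (24 * 24 + 1, 1024)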


class MAESelfAttention(nn.Module):
    """Multi-head self-attention with separate Q/K/V projections (ViT-MAE style)."""

    def __init__(self, hidden_size, num_heads, qkv_bias=True, attn_drop=0.0, proj_drop=0.0):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads

        self.query = nn.Linear(hidden_size, hidden_size, bias=qkv_bias)
        self.key = nn.Linear(hidden_size, hidden_size, bias=qkv_bias)
        self.value = nn.Linear(hidden_size, hidden_size, bias=qkv_bias)
        self.out_proj = nn.Linear(hidden_size, hidden_size)
        self.attn_drop = attn_drop
        self.proj_drop = nn.Dropout(proj_drop)  # was accepted but unused; now applied

    def forward(self, x):
        B, N, C = x.shape
        # [B, N, C] -> [B, num_heads, N, head_dim]
        q = self.query(x).reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        k = self.key(x).reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        v = self.value(x).reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)

        x = F.scaled_dot_product_attention(q, k, v, dropout_p=self.attn_drop if self.training else 0.0)
        x = x.permute(0, 2, 1, 3).reshape(B, N, C)
        return self.proj_drop(self.out_proj(x))


class MAEBlock(nn.Module):
    """Standard ViT block: pre-norm self-attention + pre-norm FFN."""

    def __init__(self, hidden_size, num_heads, intermediate_size, hidden_act="gelu",
                 qkv_bias=True, layer_norm_eps=1e-6):
        super().__init__()
        self.layernorm_before = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
        self.attention = MAESelfAttention(hidden_size, num_heads, qkv_bias=qkv_bias)
        self.layernorm_after = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
        self.intermediate = nn.Linear(hidden_size, intermediate_size)
        self.output_proj = nn.Linear(intermediate_size, hidden_size)
        self.act_fn = nn.GELU()  # `hidden_act` values other than "gelu" are not supported

    def forward(self, x):
        # Attention sub-block with residual connection
        x = x + self.attention(self.layernorm_before(x))
        # FFN sub-block with residual connection
        h = self.layernorm_after(x)
        h = self.act_fn(self.intermediate(h))
        x = x + self.output_proj(h)
        return x
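

# Illustrative shape check: a block is shape-preserving, [B, N, hidden] -> [B, N, hidden].
#
#     block = MAEBlock(hidden_size=1024, num_heads=16, intermediate_size=4096)
#     y = block(torch.randn(2, 577, 1024))  # -> torch.Size([2, 577, 1024])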


class PixelDecoderMAE(nn.Module):
    """
    ViT-MAE style pixel decoder following RAE.

    Input:  [B, 576, input_dim] ViT features (or FAE-reconstructed features)
    Output: [B, 3, 384, 384] reconstructed images

    Architecture (ViT-L):
      - Linear projection: input_dim → decoder_hidden_size
      - Trainable CLS token + fixed sin-cos positional embeddings
      - 24 Transformer blocks
      - LayerNorm + linear head → patch_size² × 3 values per token
      - Unpatchify → full image
    """

    def __init__(self, input_dim=1152, decoder_hidden_size=1024,
                 decoder_num_layers=24, decoder_num_heads=16,
                 decoder_intermediate_size=4096, patch_size=16,
                 img_size=384, num_channels=3, layer_norm_eps=1e-6):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.grid_size = img_size // patch_size
        self.num_patches = self.grid_size ** 2

        # Project encoder features into the decoder width
        self.decoder_embed = nn.Linear(input_dim, decoder_hidden_size)
        self.embed_norm = nn.LayerNorm(decoder_hidden_size, eps=layer_norm_eps)

        # Trainable CLS token
        self.cls_token = nn.Parameter(torch.zeros(1, 1, decoder_hidden_size))

        # Fixed (non-trainable) 2D sin-cos positional embeddings, CLS slot included
        pos_embed = get_2d_sincos_pos_embed(decoder_hidden_size, self.grid_size, add_cls_token=True)
        self.decoder_pos_embed = nn.Parameter(
            torch.from_numpy(pos_embed).float().unsqueeze(0),
            requires_grad=False
        )

        # Transformer trunk
        self.decoder_layers = nn.ModuleList([
            MAEBlock(
                hidden_size=decoder_hidden_size,
                num_heads=decoder_num_heads,
                intermediate_size=decoder_intermediate_size,
                layer_norm_eps=layer_norm_eps,
            )
            for _ in range(decoder_num_layers)
        ])

        self.decoder_norm = nn.LayerNorm(decoder_hidden_size, eps=layer_norm_eps)

        # Per-token pixel head: patch_size² × num_channels values per patch
        self.decoder_pred = nn.Linear(
            decoder_hidden_size, patch_size ** 2 * num_channels
        )

        self._init_weights()

    def _init_weights(self):
        nn.init.normal_(self.cls_token, std=0.02)
        # Xavier-uniform for the projection layers, zero biases
        nn.init.xavier_uniform_(self.decoder_embed.weight)
        if self.decoder_embed.bias is not None:
            nn.init.zeros_(self.decoder_embed.bias)
        nn.init.xavier_uniform_(self.decoder_pred.weight)
        if self.decoder_pred.bias is not None:
            nn.init.zeros_(self.decoder_pred.bias)

    def unpatchify(self, x):
        """
        x: [B, num_patches, patch_size² × 3]
        Returns: [B, 3, H, W]
        """
        p = self.patch_size
        h = w = self.grid_size
        c = self.num_channels

        x = x.reshape(-1, h, w, p, p, c)
        x = torch.einsum("nhwpqc->nchpwq", x)
        return x.reshape(-1, c, h * p, w * p)

    def forward(self, features, noise_tau=0.0):
        """
        Args:
            features: [B, 576, input_dim] ViT features
            noise_tau: max noise level, applied AFTER normalization (where std ≈ 1)
        Returns:
            images: [B, 3, 384, 384] reconstructed images, nominally in [-1, 1]
                (linear head; the range is learned, not enforced)
        """
        # Project into decoder width and normalize
        x = self.embed_norm(self.decoder_embed(features))

        # Noise augmentation: per-sample sigma drawn uniformly from [0, noise_tau)
        if noise_tau > 0 and self.training:
            noise_sigma = noise_tau * torch.rand(
                (x.size(0),) + (1,) * (len(x.shape) - 1), device=x.device
            )
            x = x + noise_sigma * torch.randn_like(x)

        # Prepend CLS token
        cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat([cls_tokens, x], dim=1)

        # Add fixed positional embeddings
        x = x + self.decoder_pos_embed

        # Transformer trunk
        for layer in self.decoder_layers:
            x = layer(x)

        x = self.decoder_norm(x)

        # Predict pixels for every patch token (drop CLS)
        x = self.decoder_pred(x[:, 1:, :])

        # Reassemble patches into the full image
        img = self.unpatchify(x)

        return img
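

# Example usage (illustrative; the default input_dim=1152 is the encoder width
# assumed by this module — pass your own if it differs):
#
#     decoder = PixelDecoderMAE(input_dim=1152)
#     feats = torch.randn(2, 576, 1152)     # [B, num_patches, input_dim]
#     imgs = decoder(feats, noise_tau=0.2)  # -> torch.Size([2, 3, 384, 384])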


class PatchGANDiscriminator(nn.Module):
    """PatchGAN discriminator for adversarial loss."""

    def __init__(self, in_channels=3, ndf=64):
        super().__init__()
        self.model = nn.Sequential(
            nn.Conv2d(in_channels, ndf, 4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf, ndf * 2, 4, stride=2, padding=1),
            nn.InstanceNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf * 2, ndf * 4, 4, stride=2, padding=1),
            nn.InstanceNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf * 4, ndf * 8, 4, stride=1, padding=1),
            nn.InstanceNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf * 8, 1, 4, stride=1, padding=1),  # 1-channel patch logits
        )

    def forward(self, x):
        # Returns a map of per-patch real/fake logits, not a single scalar
        return self.model(x)
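

# Example (illustrative): each output logit scores one receptive-field patch of
# the input; for a 384×384 image the logit map comes out 46×46.
#
#     disc = PatchGANDiscriminator()
#     logits = disc(torch.randn(2, 3, 384, 384))  # -> torch.Size([2, 1, 46, 46])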