# visinject / decoder.py
# Author: jeffliulab
# Commit e1887f1: Initial Space deployment: Stage 2 fusion demo (CPU, free tier)
"""
AnyAttack Decoder Network.
Takes a CLIP embedding (512-dim for ViT-B/32) and generates an adversarial
noise image (3 x 224 x 224). The noise is clamped externally to [-eps, eps].
Architecture:
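FC(512 -> 256*14*14) -> 4x(ResBlock + UpBlock) -> ResBlock -> Conv(16->3)
Channels shrink 256 -> 128 -> 64 -> 32 -> 16 while the 14x14 grid is upsampled to 224x224.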
ResBlocks include EfficientAttention for spatial self-attention.
Adapted from: https://github.com/jiamingzhang94/AnyAttack/blob/master/models/model.py
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
class EfficientAttention(nn.Module):
    """Multi-head spatial self-attention with cost linear in the pixel count.

    Keys are softmax-normalised over pixels and queries over channels, so each
    head builds a small (key_ch x val_ch) context matrix instead of an
    (HW x HW) attention map: O(N*C^2) instead of O(N^2*C).

    The result is projected back to ``in_channels`` by a 1x1 conv and added
    to the input.

    NOTE(review): each head gets ``key_channels // head_count`` key channels
    and ``value_channels // head_count`` value channels. Channels past
    ``head_count * per_head`` are never read. If ``key_channels < head_count``,
    every head has zero key channels. The attention term is then identically
    zero, and the module reduces to ``x + reprojection.bias``.
    """

    def __init__(self, in_channels: int, key_channels: int,
                 head_count: int, value_channels: int):
        super().__init__()
        self.key_channels = key_channels
        self.head_count = head_count
        self.value_channels = value_channels
        # 1x1 projections. The creation order fixes both parameter
        # initialisation and state_dict ordering, so keep it stable.
        self.keys = nn.Conv2d(in_channels, key_channels, 1)
        self.queries = nn.Conv2d(in_channels, key_channels, 1)
        self.values = nn.Conv2d(in_channels, value_channels, 1)
        self.reprojection = nn.Conv2d(value_channels, in_channels, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch, _, height, width = x.size()
        pixels = height * width

        # Flatten the spatial grid: (B, C, H, W) -> (B, C, H*W).
        key_map = self.keys(x).reshape(batch, self.key_channels, pixels)
        query_map = self.queries(x).reshape(batch, self.key_channels, pixels)
        value_map = self.values(x).reshape(batch, self.value_channels, pixels)

        per_head_k = self.key_channels // self.head_count
        per_head_v = self.value_channels // self.head_count

        def attend(head: int) -> torch.Tensor:
            k_start = head * per_head_k
            v_start = head * per_head_v
            key = F.softmax(key_map[:, k_start:k_start + per_head_k, :], dim=2)
            query = F.softmax(query_map[:, k_start:k_start + per_head_k, :], dim=1)
            value = value_map[:, v_start:v_start + per_head_v, :]
            # (B, dk, HW) @ (B, HW, dv) -> (B, dk, dv) global context per head.
            context = key @ value.transpose(1, 2)
            mixed = context.transpose(1, 2) @ query
            return mixed.reshape(batch, per_head_v, height, width)

        merged = torch.cat([attend(head) for head in range(self.head_count)], dim=1)
        return self.reprojection(merged) + x
class ResBlock(nn.Module):
    """Two 3x3 conv+BN layers followed by EfficientAttention, with a shortcut.

    Main path: conv3x3 -> BN -> LeakyReLU(0.2) -> conv3x3 -> BN -> attention.

    Shortcut: 1x1 conv when ``in_ch != out_ch``, otherwise identity.

    Output: LeakyReLU(0.2) of (main path + shortcut). Spatial size is
    preserved; channels go from ``in_ch`` to ``out_ch``.
    """

    def __init__(self, in_ch: int, out_ch: int,
                 key_ch: int, head_count: int, val_ch: int):
        super().__init__()
        # Submodule names and creation order are part of the checkpoint
        # format (state_dict keys) and of the seeded init, so keep them fixed.
        self.conv1 = nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(out_ch)
        self.conv2 = nn.Conv2d(out_ch, out_ch, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(out_ch)
        self.act = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        self.attention = EfficientAttention(out_ch, key_ch, head_count, val_ch)
        if in_ch == out_ch:
            self.skip = nn.Identity()
        else:
            # Project the input so it can be added to the main path.
            self.skip = nn.Conv2d(in_ch, out_ch, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        shortcut = self.skip(x)
        hidden = self.act(self.bn1(self.conv1(x)))
        hidden = self.attention(self.bn2(self.conv2(hidden)))
        # The in-place activation acts on the freshly allocated sum, so it
        # never mutates x.
        return self.act(hidden + shortcut)
class UpBlock(nn.Module):
    """Double the spatial resolution, then conv3x3 -> BatchNorm -> LeakyReLU(0.2).

    Upsampling is nearest-neighbour (each pixel becomes a 2x2 patch).
    Channels change from ``in_ch`` to ``out_ch``.
    """

    def __init__(self, in_ch: int, out_ch: int):
        super().__init__()
        # Attribute names define the state_dict keys; keep them stable.
        self.up = nn.Upsample(scale_factor=2, mode="nearest")
        self.conv = nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1)
        self.bn = nn.BatchNorm2d(out_ch)
        self.act = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        enlarged = self.up(x)
        normalized = self.bn(self.conv(enlarged))
        return self.act(normalized)
class Decoder(nn.Module):
    """
    AnyAttack noise generator: CLIP embedding -> adversarial noise image.

    Args:
        embed_dim: Input embedding dimension (512 for ViT-B/32, 1024 for ViT-L/14).
        img_channels: Output image channels (3 for RGB).
        img_size: Output spatial resolution (224). Must be a positive multiple
            of 16. The decoder starts from an ``img_size // 16`` grid and
            doubles it four times.

    Raises:
        ValueError: If ``img_size`` is not a positive multiple of 16.
    """

    def __init__(self, embed_dim: int = 512, img_channels: int = 3, img_size: int = 224):
        super().__init__()
        # Without this check, a non-multiple of 16 would silently produce a
        # (img_size // 16) * 16 output instead of the requested img_size.
        if img_size < 16 or img_size % 16 != 0:
            raise ValueError(
                f"img_size must be a positive multiple of 16, got {img_size}"
            )
        self.init_size = img_size // 16  # 14 for 224
        self.fc = nn.Sequential(
            nn.Linear(embed_dim, 256 * self.init_size ** 2)
        )
        # ResBlock args: (in_ch, out_ch, key_ch, head_count, val_ch).
        # Submodule names and layout must stay fixed to load existing checkpoints.
        self.blocks = nn.ModuleList([
            ResBlock(256, 256, 64, 8, 256),
            UpBlock(256, 128),
            ResBlock(128, 128, 32, 8, 128),
            UpBlock(128, 64),
            ResBlock(64, 64, 16, 8, 64),
            UpBlock(64, 32),
            ResBlock(32, 32, 8, 8, 32),
            UpBlock(32, 16),
            # NOTE(review): key_ch=4 < head_count=8. EfficientAttention splits
            # keys as key_ch // head_count == 0 channels per head, so this
            # attention contributes only its reprojection bias. Kept as-is for
            # checkpoint compatibility.
            ResBlock(16, 16, 4, 8, 16),
        ])
        self.head = nn.Conv2d(16, img_channels, 3, 1, 1)

    def forward(self, embedding: torch.Tensor) -> torch.Tensor:
        """
        Generate noise from CLIP embedding.

        Args:
            embedding: (B, embed_dim) CLIP embedding. Extra trailing dims such
                as (B, 1, embed_dim) are flattened. Any dtype is accepted and
                cast to the decoder's parameter dtype.

        Returns:
            (B, img_channels, img_size, img_size) raw noise
            (NOT clamped to [-eps, eps]).
        """
        # Cast to the decoder's own dtype rather than hard-coding float32.
        # CLIP often emits fp16 embeddings, while the decoder may itself have
        # been converted with .half()/.double(). For the default fp32 model
        # this matches the previous .float() behaviour.
        param_dtype = self.fc[0].weight.dtype
        # flatten(1) copies when needed, so unlike .view it also handles
        # non-contiguous inputs and empty batches.
        flat = embedding.flatten(1).to(param_dtype)
        out = self.fc(flat)
        out = out.view(out.size(0), 256, self.init_size, self.init_size)
        for block in self.blocks:
            out = block(out)
        return self.head(out)