|
""" |
|
Perceiver code is based on Aurora: https://github.com/microsoft/aurora/blob/main/aurora/model/perceiver.py |
|
|
|
Some conventions for notation: |
|
B - Batch |
|
T - Time |
|
H - Height (pixel space) |
|
W - Width (pixel space) |
|
HT - Height (token space) |
|
WT - Width (token space) |
|
ST - Sequence (token space) |
|
C - Input channels |
|
D - Model (embedding) dimension |
|
""" |
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from einops import rearrange |
|
from timm.models.layers import trunc_normal_ |
|
|
|
|
|
class PatchEmbed3D(nn.Module): |
|
"""Timeseries Image to Patch Embedding""" |
|
|
|
def __init__( |
|
self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, time_dim=2 |
|
): |
|
super().__init__() |
|
self.img_size = img_size |
|
self.patch_size = patch_size |
|
self.embed_dim = embed_dim |
|
self.time_dim = time_dim |
|
|
|
self.proj = nn.Conv2d( |
|
in_chans * time_dim, |
|
embed_dim, |
|
kernel_size=(patch_size, patch_size), |
|
stride=(patch_size, patch_size), |
|
) |
|
|
|
def forward(self, x): |
|
""" |
|
Args: |
|
x: Tensor of shape (B, C, T, H, W) |
|
Returns: |
|
Tensor of shape (B, ST, D) |
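
        Example (a hypothetical shape check with the default arguments):
            >>> pe = PatchEmbed3D()
            >>> pe(torch.randn(2, 3, 2, 224, 224)).shape
            torch.Size([2, 196, 768])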
|
""" |
|
B, C, T, H, W = x.shape |
|
        # Merge the channel and time dimensions so the 2D convolution patchifies all timesteps at once.
        x = self.proj(x.flatten(1, 2))
|
x = rearrange(x, "B D HT WT -> B (HT WT) D") |
|
return x |
|
|
|
|
|
class LinearEmbedding(nn.Module): |
|
def __init__( |
|
self, |
|
img_size=224, |
|
patch_size=16, |
|
in_chans=3, |
|
time_dim=2, |
|
embed_dim=768, |
|
drop_rate=0.0, |
|
): |
|
super().__init__() |
|
|
|
self.num_patches = (img_size // patch_size) ** 2 |
|
|
|
self.patch_embed = PatchEmbed3D( |
|
img_size=img_size, |
|
patch_size=patch_size, |
|
in_chans=in_chans, |
|
embed_dim=embed_dim, |
|
time_dim=time_dim, |
|
) |
|
|
|
self._generate_position_encoding(img_size, patch_size, embed_dim) |
|
|
|
self.pos_drop = nn.Dropout(p=drop_rate) |
|
|
|
def _generate_position_encoding(self, img_size, patch_size, embed_dim): |
|
""" |
|
        Generates a positional encoding signal for the model. The generated

        positional encoding signal is stored as the buffer `self.pos_embed`.
|
|
|
Args: |
|
img_size (int): The size of the input image. |
|
patch_size (int): The size of each patch in the image. |
|
embed_dim (int): The embedding dimension of the model. |
|
|
|
Returns: |
|
None. |
|
""" |
|
|
|
x = torch.linspace(0.0, 1.0, img_size // patch_size) |
|
y = torch.linspace(0.0, 1.0, img_size // patch_size) |
|
x, y = torch.meshgrid(x, y, indexing="xy") |
|
fourier_signal = [] |
|
|
|
        # Each frequency contributes four channels (cos/sin for x and y), so `embed_dim`
        # must be divisible by 4 for the sum with the patch embedding to broadcast correctly.
        frequencies = torch.linspace(1, (img_size // patch_size) / 2.0, embed_dim // 4)
|
|
|
for f in frequencies: |
|
fourier_signal.extend( |
|
[ |
|
torch.cos(2.0 * torch.pi * f * x), |
|
torch.sin(2.0 * torch.pi * f * x), |
|
torch.cos(2.0 * torch.pi * f * y), |
|
torch.sin(2.0 * torch.pi * f * y), |
|
] |
|
) |
|
fourier_signal = torch.stack(fourier_signal, dim=2) |
|
fourier_signal = rearrange(fourier_signal, "h w c -> 1 (h w) c") |
|
self.register_buffer("pos_embed", fourier_signal) |
|
|
|
def forward(self, x, dt): |
|
""" |
|
Args: |
|
x: Tensor of shape (B, C, T, H, W). |
|
            dt: Tensor of shape (B, T). Unused; accepted for interface compatibility.
|
Returns: |
|
Tensor of shape (B, ST, D) |
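
        Example (a hypothetical shape check with the default arguments; `dt` is ignored):
            >>> emb = LinearEmbedding()
            >>> emb(torch.randn(2, 3, 2, 224, 224), torch.zeros(2, 2)).shape
            torch.Size([2, 196, 768])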
|
""" |
|
x = self.patch_embed(x) |
|
x = x + self.pos_embed |
|
x = self.pos_drop(x) |
|
|
|
return x |
|
|
|
|
|
class LinearDecoder(nn.Module): |
|
def __init__( |
|
self, |
|
patch_size: int, |
|
out_chans: int, |
|
embed_dim: int, |
|
): |
|
""" |
|
Args: |
|
patch_size: patch size |
|
            out_chans: number of output channels
|
embed_dim: embedding dimension |
|
""" |
|
super().__init__() |
|
|
|
self.unembed = nn.Sequential( |
|
nn.Conv2d( |
|
in_channels=embed_dim, |
|
out_channels=(patch_size**2) * out_chans, |
|
kernel_size=1, |
|
), |
|
nn.PixelShuffle(patch_size), |
|
) |
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor: |
|
""" |
|
Args: |
|
x: Tensor of shape (B, L, D). For ensembles, we have implicitly B = (B E). |
|
Returns: |
|
            Tensor of shape (B, C, H, W).

            Here

            - C equals out_chans
|
- H == W == sqrt(L) x patch_size |
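
        Example (a hypothetical shape check, assuming a 14 x 14 token grid):
            >>> dec = LinearDecoder(patch_size=16, out_chans=3, embed_dim=768)
            >>> dec(torch.randn(2, 196, 768)).shape
            torch.Size([2, 3, 224, 224])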
|
""" |
|
|
|
        _, L, _ = x.shape

        # Assume a square token grid, i.e. L == H_token * W_token.
        H_token = W_token = int(L**0.5)

        x = rearrange(x, "B (H W) D -> B D H W", H=H_token, W=W_token)

        x = self.unembed(x)
|
|
|
return x |
|
|
|
|
|
class MLP(nn.Module): |
|
"""A simple one-hidden-layer MLP.""" |
|
|
|
def __init__(self, dim: int, hidden_features: int, dropout: float = 0.0) -> None: |
|
"""Initialise. |
|
|
|
Args: |
|
dim (int): Input dimensionality. |
|
hidden_features (int): Width of the hidden layer. |
|
dropout (float, optional): Drop-out rate. Defaults to no drop-out. |
|
""" |
|
super().__init__() |
|
self.net = nn.Sequential( |
|
nn.Linear(dim, hidden_features), |
|
nn.GELU(), |
|
nn.Linear(hidden_features, dim), |
|
nn.Dropout(dropout), |
|
) |
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor: |
|
"""Run the MLP.""" |
|
return self.net(x) |
|
|
|
|
|
class PerceiverAttention(nn.Module): |
|
"""Cross attention module from the Perceiver architecture.""" |
|
|
|
def __init__( |
|
self, |
|
latent_dim: int, |
|
context_dim: int, |
|
head_dim: int = 64, |
|
num_heads: int = 8, |
|
) -> None: |
|
"""Initialise. |
|
|
|
Args: |
|
latent_dim (int): Dimensionality of the latent features given as input. |
|
context_dim (int): Dimensionality of the context features also given as input. |
|
head_dim (int): Attention head dimensionality. |
|
num_heads (int): Number of heads. |
|
""" |
|
super().__init__() |
|
self.num_heads = num_heads |
|
self.head_dim = head_dim |
|
self.inner_dim = head_dim * num_heads |
|
|
|
self.to_q = nn.Linear(latent_dim, self.inner_dim, bias=False) |
|
self.to_kv = nn.Linear(context_dim, self.inner_dim * 2, bias=False) |
|
self.to_out = nn.Linear(self.inner_dim, latent_dim, bias=False) |
|
|
|
def forward(self, latents: torch.Tensor, x: torch.Tensor) -> torch.Tensor: |
|
"""Run the cross-attention module. |
|
|
|
Args: |
|
latents (:class:`torch.Tensor`): Latent features of shape `(B, L1, Latent_D)` |
|
where typically `L1 < L2` and `Latent_D <= Context_D`. `Latent_D` is equal to |
|
`self.latent_dim`. |
|
x (:class:`torch.Tensor`): Context features of shape `(B, L2, Context_D)`. |
|
|
|
Returns: |
|
:class:`torch.Tensor`: Latent values of shape `(B, L1, Latent_D)`. |
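
        Example (a hypothetical shape check with the default head settings):
            >>> attn = PerceiverAttention(latent_dim=256, context_dim=512)
            >>> attn(torch.randn(2, 16, 256), torch.randn(2, 196, 512)).shape
            torch.Size([2, 16, 256])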
|
""" |
|
h = self.num_heads |
|
|
|
q = self.to_q(latents) |
|
k, v = self.to_kv(x).chunk(2, dim=-1) |
|
q, k, v = map(lambda t: rearrange(t, "b l (h d) -> b h l d", h=h), (q, k, v)) |
|
|
|
out = F.scaled_dot_product_attention(q, k, v) |
|
out = rearrange(out, "B H L1 D -> B L1 (H D)") |
|
return self.to_out(out) |
|
|
|
|
|
class PerceiverResampler(nn.Module): |
|
"""Perceiver Resampler module from the Flamingo paper.""" |
|
|
|
def __init__( |
|
self, |
|
latent_dim: int, |
|
context_dim: int, |
|
depth: int = 1, |
|
head_dim: int = 64, |
|
num_heads: int = 16, |
|
mlp_ratio: float = 4.0, |
|
drop: float = 0.0, |
|
residual_latent: bool = True, |
|
ln_eps: float = 1e-5, |
|
) -> None: |
|
"""Initialise. |
|
|
|
Args: |
|
latent_dim (int): Dimensionality of the latent features given as input. |
|
context_dim (int): Dimensionality of the context features also given as input. |
|
depth (int, optional): Number of attention layers. |
|
head_dim (int, optional): Attention head dimensionality. Defaults to `64`. |
|
            num_heads (int, optional): Number of heads. Defaults to `16`.

            mlp_ratio (float, optional): Dimensionality of the hidden layer divided by that of the
|
input for all MLPs. Defaults to `4.0`. |
|
drop (float, optional): Drop-out rate. Defaults to no drop-out. |
|
residual_latent (bool, optional): Use residual attention w.r.t. the latent features. |
|
Defaults to `True`. |
|
ln_eps (float, optional): Epsilon in the layer normalisation layers. Defaults to |
|
`1e-5`. |
|
""" |
|
super().__init__() |
|
|
|
self.residual_latent = residual_latent |
|
self.layers = nn.ModuleList([]) |
|
mlp_hidden_dim = int(latent_dim * mlp_ratio) |
|
for _ in range(depth): |
|
self.layers.append( |
|
nn.ModuleList( |
|
[ |
|
PerceiverAttention( |
|
latent_dim=latent_dim, |
|
context_dim=context_dim, |
|
head_dim=head_dim, |
|
num_heads=num_heads, |
|
), |
|
MLP( |
|
dim=latent_dim, hidden_features=mlp_hidden_dim, dropout=drop |
|
), |
|
nn.LayerNorm(latent_dim, eps=ln_eps), |
|
nn.LayerNorm(latent_dim, eps=ln_eps), |
|
] |
|
) |
|
) |
|
|
|
def forward(self, latents: torch.Tensor, x: torch.Tensor) -> torch.Tensor: |
|
"""Run the module. |
|
|
|
Args: |
|
            latents (:class:`torch.Tensor`): Latent features of shape `(B, L1, D1)`.

            x (:class:`torch.Tensor`): Context features of shape `(B, L2, D2)`.

        Returns:

            :class:`torch.Tensor`: Latent features of shape `(B, L1, D1)`.
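
        Example (a hypothetical shape check):
            >>> resampler = PerceiverResampler(latent_dim=256, context_dim=512)
            >>> resampler(torch.randn(2, 16, 256), torch.randn(2, 196, 512)).shape
            torch.Size([2, 16, 256])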
|
""" |
|
for attn, ff, ln1, ln2 in self.layers: |
|
|
|
|
|
            # Cross-attention block, with an optional residual connection on the latents.
            attn_out = ln1(attn(latents, x))

            latents = attn_out + latents if self.residual_latent else attn_out
|
latents = ln2(ff(latents)) + latents |
|
return latents |
|
|
|
|
|
class PerceiverChannelEmbedding(nn.Module): |
|
def __init__( |
|
self, |
|
in_chans: int, |
|
img_size: int, |
|
patch_size: int, |
|
time_dim: int, |
|
num_queries: int, |
|
embed_dim: int, |
|
drop_rate: float, |
|
): |
|
super().__init__() |
|
|
|
if embed_dim % 2 != 0: |
|
raise ValueError( |
|
f"Temporal embeddings require `embed_dim` to be even. Currently we have {embed_dim}." |
|
) |
|
|
|
self.num_patches = (img_size // patch_size) ** 2 |
|
self.num_queries = num_queries |
|
self.embed_dim = embed_dim |
|
|
|
self.proj = nn.Conv2d( |
|
in_channels=in_chans * time_dim, |
|
out_channels=in_chans * embed_dim, |
|
kernel_size=patch_size, |
|
stride=patch_size, |
|
groups=in_chans, |
|
) |
|
|
|
self.pos_embed = nn.Parameter(torch.zeros(1, embed_dim, self.num_patches)) |
|
trunc_normal_(self.pos_embed, std=0.02) |
|
|
|
self.latent_queries = nn.Parameter(torch.zeros(1, num_queries, embed_dim)) |
|
trunc_normal_(self.latent_queries, std=0.02) |
|
|
|
self.perceiver = PerceiverResampler( |
|
latent_dim=embed_dim, |
|
context_dim=embed_dim, |
|
depth=1, |
|
head_dim=embed_dim // 16, |
|
num_heads=16, |
|
mlp_ratio=4.0, |
|
drop=0.0, |
|
residual_latent=False, |
|
ln_eps=1e-5, |
|
) |
|
|
|
self.latent_aggregation = nn.Linear(num_queries * embed_dim, embed_dim) |
|
|
|
self.pos_drop = nn.Dropout(p=drop_rate) |
|
|
|
def forward(self, x, dt): |
|
""" |
|
Args: |
|
x: Tensor of shape (B, C, T, H, W) |
|
            dt: Tensor of shape (B, T) identifying time deltas. Currently unused.
|
Returns: |
|
Tensor of shape (B, ST, D) |
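
        Example (a hypothetical shape check):
            >>> emb = PerceiverChannelEmbedding(in_chans=3, img_size=224, patch_size=16,
            ...                                 time_dim=2, num_queries=4, embed_dim=256, drop_rate=0.0)
            >>> emb(torch.randn(2, 3, 2, 224, 224), torch.zeros(2, 2)).shape
            torch.Size([2, 196, 256])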
|
""" |
|
B, C, T, H, W = x.shape |
|
x = rearrange(x, "B C T H W -> B (C T) H W") |
|
x = self.proj(x) |
|
x = x.flatten(2, 3) |
|
ST = x.shape[2] |
|
assert ST == self.num_patches |
|
x = rearrange(x, "B (C D) ST -> (B C) D ST", B=B, ST=ST, C=C, D=self.embed_dim) |
|
x = x + self.pos_embed |
|
x = rearrange(x, "(B C) D ST -> (B ST) C D", B=B, ST=ST, C=C, D=self.embed_dim) |
|
|
|
|
|
        # Resample the C channel embeddings down to `num_queries` latents per token.
        x = self.perceiver(self.latent_queries.expand(B * ST, -1, -1), x)
|
x = rearrange( |
|
x, |
|
"(B ST) NQ D -> B ST (NQ D)", |
|
B=B, |
|
ST=self.num_patches, |
|
NQ=self.num_queries, |
|
D=self.embed_dim, |
|
) |
|
x = self.latent_aggregation(x) |
|
|
|
assert x.shape[1] == self.num_patches |
|
assert x.shape[2] == self.embed_dim |
|
|
|
x = self.pos_drop(x) |
|
|
|
return x |
|
|
|
|
|
class PerceiverDecoder(nn.Module): |
|
def __init__( |
|
self, |
|
embed_dim: int, |
|
patch_size: int, |
|
out_chans: int, |
|
): |
|
""" |
|
Args: |
|
embed_dim: embedding dimension |
|
patch_size: patch size |
|
            out_chans: number of output channels. This determines the number of latent queries.
|
""" |
|
super().__init__() |
|
|
|
self.embed_dim = embed_dim |
|
self.patch_size = patch_size |
|
self.out_chans = out_chans |
|
|
|
self.latent_queries = nn.Parameter(torch.zeros(1, out_chans, embed_dim)) |
|
trunc_normal_(self.latent_queries, std=0.02) |
|
|
|
self.perceiver = PerceiverResampler( |
|
latent_dim=embed_dim, |
|
context_dim=embed_dim, |
|
depth=1, |
|
head_dim=embed_dim // 16, |
|
num_heads=16, |
|
mlp_ratio=4.0, |
|
drop=0.0, |
|
residual_latent=False, |
|
ln_eps=1e-5, |
|
) |
|
self.proj = nn.Conv2d( |
|
in_channels=out_chans * embed_dim, |
|
out_channels=out_chans * patch_size**2, |
|
kernel_size=1, |
|
padding=0, |
|
groups=out_chans, |
|
) |
|
self.pixel_shuffle = nn.PixelShuffle(patch_size) |
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor: |
|
""" |
|
Args: |
|
            x: Tensor of shape (B, L, D). For ensembles, we have implicitly B = (B E).
|
Returns: |
|
            Tensor of shape (B, C, H, W).
|
Here |
|
- C equals out_chans |
|
- H == W == sqrt(L) x patch_size |
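
        Example (a hypothetical shape check, assuming a 14 x 14 token grid):
            >>> dec = PerceiverDecoder(embed_dim=256, patch_size=16, out_chans=3)
            >>> dec(torch.randn(2, 196, 256)).shape
            torch.Size([2, 3, 224, 224])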
|
""" |
|
B, L, D = x.shape |
|
        # Assume a square token grid, i.e. L == H_token * W_token.
        H_token = W_token = int(L**0.5)
|
|
|
x = rearrange(x, "B L D -> (B L) 1 D") |
|
|
|
        # Decode `out_chans` per-channel embeddings for every token via cross-attention.
        x = self.perceiver(self.latent_queries.expand(B * L, -1, -1), x)
|
x = rearrange(x, "(B H W) C D -> B (C D) H W", H=H_token, W=W_token) |
|
|
|
x = self.proj(x) |
|
|
|
x = self.pixel_shuffle(x) |
|
|
|
return x |
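

# A minimal end-to-end shape sketch (hypothetical values; not part of the module's API):
#
#     encoder = PerceiverChannelEmbedding(in_chans=3, img_size=224, patch_size=16,
#                                         time_dim=2, num_queries=4, embed_dim=256, drop_rate=0.0)
#     decoder = PerceiverDecoder(embed_dim=256, patch_size=16, out_chans=3)
#     tokens = encoder(torch.randn(2, 3, 2, 224, 224), torch.zeros(2, 2))  # (B, ST, D) = (2, 196, 256)
#     image = decoder(tokens)                                              # (B, C, H, W) = (2, 3, 224, 224)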
|
|