| | import os |
| | import torch |
| | from torch import nn |
| | from einops import rearrange, repeat |
| | from torch import einsum |
| |
|
class PerceiverAttention(nn.Module):
    """Cross-attention from learned latent queries onto media features.

    Queries are computed from the latent ("learns") tokens; keys/values are
    computed over the concatenation of the media features and the latents,
    as in the Flamingo / Perceiver-Resampler attention layer.
    """

    def __init__(
        self,
        *,
        dim,
        dim_head=64,
        heads=8
    ):
        """
        Args:
            dim: feature dimension of both media and latent tokens.
            dim_head: per-head dimension.
            heads: number of attention heads.
        """
        super().__init__()
        self.scale = dim_head ** -0.5  # 1/sqrt(d_head) attention scaling
        self.heads = heads
        inner_dim = dim_head * heads

        self.norm_media = nn.LayerNorm(dim)
        self.norm_learns = nn.LayerNorm(dim)

        # Bias-free projections, matching the reference implementation.
        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)

    def forward(self, x, learns):
        """
        Args:
            x: media features, shape (b, n_media, dim).
            learns: latent tokens, shape (b, n_latents, dim).

        Returns:
            Tensor of shape (b, n_latents, dim).
        """
        x = self.norm_media(x)
        learns = self.norm_learns(learns)

        h = self.heads  # fix: dropped unused `b, n` unpacking from x.shape

        # Queries come from the latents only; keys/values attend over
        # media tokens AND latents (concatenated along the sequence axis).
        q = self.to_q(learns)
        kv_input = torch.cat((x, learns), dim=-2)
        k, v = self.to_kv(kv_input).chunk(2, dim=-1)

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v))

        q = q * self.scale

        sim = einsum('b h i d, b h j d -> b h i j', q, k)
        # Subtract the (detached) row-wise max for numerically stable softmax.
        sim = sim - sim.amax(dim=-1, keepdim=True).detach()
        attn = sim.softmax(dim=-1)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
| |
|
| |
|
class PerceiverResampler(nn.Module):
    """Resample a variable number of media tokens down to a fixed set of latents.

    A stack of `depth` (PerceiverAttention, FeedForward) layer pairs with
    residual connections, applied to learned latent tokens that attend to
    the input features.
    """

    def __init__(
        self,
        *,
        dim,
        depth=2,
        dim_head=64,
        heads=8,
        num_learns=3,
        ff_mult=4,
    ):
        """
        Args:
            dim: feature dimension.
            depth: number of (attention, feed-forward) layer pairs.
            dim_head: per-head dimension for the attention layers.
            heads: number of attention heads.
            num_learns: number of learned latent tokens (output length).
            ff_mult: hidden-layer expansion factor of the feed-forward blocks.
        """
        super().__init__()
        # fix: forward()'s docstring referred to self.num_learns, which was
        # never stored; expose it as an attribute.
        self.num_learns = num_learns
        self.learns = nn.Parameter(torch.randn(num_learns, dim))

        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(
                nn.ModuleList(
                    [
                        PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
                        # NOTE: FeedForward is defined later in this file; that is
                        # fine because this only executes at instantiation time.
                        FeedForward(dim=dim, mult=ff_mult),
                    ]
                )
            )

        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        """
        Args:
            x (torch.Tensor): image features of shape (b, n, dim),
                e.g. (b, 256, 4096).
        Returns:
            torch.Tensor of shape (b, self.num_learns, dim).
        """
        b = x.shape[0]  # fix: `n, d` were unpacked but never used

        # Broadcast the learned latents across the batch dimension.
        learns = repeat(self.learns, "n d -> b n d", b=b)

        # Pre-norm residual blocks: latents attend to x, then pass through the MLP.
        for attn, ff in self.layers:
            learns = attn(x, learns) + learns
            learns = ff(learns) + learns

        return self.norm(learns)
| |
|
| | |
class FeedForward(nn.Module):
    """Position-wise MLP: LayerNorm -> Linear(dim, dim*mult) -> GELU -> Linear back."""

    def __init__(self, dim, mult=4):
        super().__init__()
        hidden = dim * mult
        # Pre-norm, expand to `hidden`, GELU nonlinearity, project back to `dim`.
        # Layer order matches the rest of the file, so state-dict keys are stable.
        layers = [
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden),
            nn.GELU(),
            nn.Linear(hidden, dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the MLP token-wise; output shape equals input shape."""
        return self.net(x)