"""
Based on: https://github.com/lucidrains/flamingo-pytorch
"""
import math

from einops import rearrange
from einops_exts import rearrange_many
from torch import einsum, nn

def exists(val):
return val is not None
class FeedForward(nn.Module):
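    """Pre-norm feed-forward block: LayerNorm -> Linear -> GELU -> Linear."""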
def __init__(self, dim, dtype, reduce_factor = 1):
super().__init__()
mult = 4
self.norm = nn.LayerNorm(dim, dtype=dtype)
inner_dim = int(dim * mult) // reduce_factor
self.fc1 = nn.Linear(dim, inner_dim, dtype=dtype)
self.fc2 = nn.Linear(inner_dim, dim, dtype=dtype)
self.act = nn.GELU()
def forward(self, x):
x = self.norm(x)
x = self.fc1(x)
x = self.act(x)
x = self.fc2(x)
return x
# cross attention
class CrossAttention(nn.Module):
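    """Cross-attention with text-derived queries over visual keys/values."""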
def __init__(
self,
*,
dim_text,
dim_visual,
dtype,
dim_head=64,
reduce_factor=1
):
super().__init__()
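        # Attention runs in the larger of the two modality widths; reduce_factor
        # optionally shrinks the inner projection width to save parameters.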
self.scale = dim_head**-0.5
max_dim = max(dim_text, dim_visual)
self.heads = max_dim // dim_head
        assert max_dim % dim_head == 0, f"max(dim_text, dim_visual)={max_dim} must be divisible by dim_head={dim_head}; got {self.heads} heads"
inner_dim = max_dim // reduce_factor
self.norm = nn.LayerNorm(dim_text, dtype=dtype)
self.to_q = nn.Linear(dim_text, inner_dim, dtype=dtype)
self.to_kv = nn.Linear(dim_visual, inner_dim * 2, dtype=dtype)
#self.to_kv_second = nn.Linear(dim_visual, inner_dim * 2)
self.to_out = nn.Linear(inner_dim, dim_text, dtype=dtype)
#self.g = []
#self.l = []
def forward(self, x, media):
"""
Args:
x (torch.Tensor): text features
shape (B, txt_seq, D_txt)
media (torch.Tensor): image features
shape (B, img_seq, D_img) where img_seq is the number of concatenated features from the ViT. For example:
for an encoder of 224x224 with patch size 14 and processing images of 896x896 (with 3 levels) it will be (1 + 4 + 16) * 257 = 5397
"""
h = self.heads
x = self.norm(x)
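        # Queries come from the normalized text tokens; keys and values are
        # projected from the visual tokens.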
q = self.to_q(x)
k, v = self.to_kv(media).chunk(2, dim=-1)
"""k_s, v_s = self.to_kv(media[:, 257:, :]).chunk(2, dim=-1)
k = torch.cat((k, k_s), 1)
v = torch.cat((v, v_s), 1)"""
q, k, v = rearrange_many((q, k, v), "b n (h d) -> b h n d", h=h)
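        # Scaled dot-product attention: every text token attends over all visual tokens.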
q = q * self.scale
sim = einsum("... i d, ... j d -> ... i j", q, k)
attn = sim.softmax(dim=-1)
#idk = torch.mean(attn.squeeze()[:, 65:, :], (0, 1))
#self.g.append(torch.sum(idk[:257]).item())
#self.l.append(torch.sum(idk[257:]).item())
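        # Weighted sum of the value vectors, then merge the heads back into a
        # single feature dimension before the output projection.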
out = einsum("... i j, ... j d -> ... i d", attn, v)
out = rearrange(out, "b h n d -> b n (h d)")
return self.to_out(out)
# cross attention block
class CrossAttentionBlock(nn.Module):
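    """Cross-attention followed by a feed-forward block, each with a residual connection."""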
def __init__(
self,
*,
dim_text,
dim_visual,
dtype,
dim_head=64,
reduce_factor = 1,
layer_idx=0,
n_decoder_layers = 24
):
super().__init__()
self.attn = CrossAttention(
dim_text=dim_text,
dim_visual=dim_visual,
dim_head=dim_head,
reduce_factor=reduce_factor,
dtype=dtype
)
self.ff = FeedForward(dim_text, reduce_factor=reduce_factor, dtype=dtype)
self.layer_idx = layer_idx
self.n_decoder_layers = n_decoder_layers
self.apply(self._init_weights)
def forward(
self,
x,
media
):
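        # Residual cross-attention into the visual tokens, then a residual feed-forward.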
x = (
self.attn(
x,
media
)
+ x
)
x = self.ff(x) + x
return x
def _init_weights(self, module):
"""Initialize the weights."""
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=0.01)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=0.02)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
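        # Output projections that feed a residual stream (ff.fc2 and attn.to_out)
        # get a smaller init, scaled down by sqrt(2 * max(n_decoder_layers, 36)).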
for name, p in module.named_parameters():
if name == "fc2.weight" or name == "to_out.weight":
p.data.normal_(mean=0.0, std=(0.01 / math.sqrt(2 * max(self.n_decoder_layers, 36))))
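

# ---------------------------------------------------------------------------
# Minimal smoke-test sketch (not part of the original module). The dimensions
# below (dim_text=512, dim_visual=768, batch/sequence sizes) are arbitrary
# assumptions chosen only to exercise the shapes described in the docstrings.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import torch

    block = CrossAttentionBlock(
        dim_text=512,        # hypothetical decoder hidden size
        dim_visual=768,      # hypothetical ViT hidden size
        dtype=torch.float32,
    )
    text = torch.randn(2, 16, 512)     # (B, txt_seq, D_txt)
    vision = torch.randn(2, 257, 768)  # (B, img_seq, D_img), e.g. one 257-token crop
    out = block(text, vision)
    print(out.shape)  # torch.Size([2, 16, 512])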