"""
Based on: https://github.com/lucidrains/flamingo-pytorch
"""
import math

from einops import rearrange
from einops_exts import rearrange_many
from torch import einsum, nn


def exists(val):
    return val is not None


class FeedForward(nn.Module):
    """Pre-norm MLP: LayerNorm -> Linear -> GELU -> Linear."""

    def __init__(self, dim, dtype, reduce_factor=1):
        super().__init__()
        mult = 4
        self.norm = nn.LayerNorm(dim, dtype=dtype)
        inner_dim = int(dim * mult) // reduce_factor
        self.fc1 = nn.Linear(dim, inner_dim, dtype=dtype)
        self.fc2 = nn.Linear(inner_dim, dim, dtype=dtype)
        self.act = nn.GELU()

    def forward(self, x):
        x = self.norm(x)
        x = self.fc1(x)
        x = self.act(x)
        x = self.fc2(x)
        return x


# Cross attention: text tokens (queries) attend to visual tokens (keys/values).
class CrossAttention(nn.Module):
    def __init__(
        self,
        *,
        dim_text,
        dim_visual,
        dtype,
        dim_head=64,
        reduce_factor=1
    ):
        super().__init__()
        self.scale = dim_head ** -0.5
        max_dim = max(dim_text, dim_visual)
        assert max_dim % dim_head == 0, (
            f"max(dim_text, dim_visual) = {max_dim} must be divisible by dim_head = {dim_head}"
        )
        self.heads = max_dim // dim_head
        inner_dim = max_dim // reduce_factor
        self.norm = nn.LayerNorm(dim_text, dtype=dtype)
        self.to_q = nn.Linear(dim_text, inner_dim, dtype=dtype)
        self.to_kv = nn.Linear(dim_visual, inner_dim * 2, dtype=dtype)
        self.to_out = nn.Linear(inner_dim, dim_text, dtype=dtype)

    def forward(self, x, media):
        """
        Args:
            x (torch.Tensor): text features,
                shape (B, txt_seq, D_txt).
            media (torch.Tensor): image features,
                shape (B, img_seq, D_img), where img_seq is the number of
                concatenated features from the ViT. For example, a 224x224
                encoder with patch size 14 (256 patches + 1 CLS = 257 tokens
                per crop) processing 896x896 images over 3 levels yields
                (1 + 4 + 16) * 257 = 5397 tokens.
        """
        h = self.heads
        x = self.norm(x)

        # Project text to queries and visual features to keys/values.
        q = self.to_q(x)
        k, v = self.to_kv(media).chunk(2, dim=-1)
        q, k, v = rearrange_many((q, k, v), "b n (h d) -> b h n d", h=h)

        # Scaled dot-product attention over the visual tokens.
        q = q * self.scale
        sim = einsum("... i d, ... j d -> ... i j", q, k)
        attn = sim.softmax(dim=-1)
        out = einsum("... i j, ... j d -> ... i d", attn, v)
        out = rearrange(out, "b h n d -> b n (h d)")
        return self.to_out(out)
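
# A minimal shape sketch (the dimensions below are assumptions for
# illustration, not values from this file):
#
#   import torch
#   attn = CrossAttention(dim_text=4096, dim_visual=1152, dtype=torch.float32)
#   x = torch.randn(2, 128, 4096)       # (B, txt_seq, D_txt)
#   media = torch.randn(2, 5397, 1152)  # (B, img_seq, D_img), 5397 = (1+4+16)*257
#   out = attn(x, media)                # -> (2, 128, 4096), same shape as x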


# Residual cross-attention block: CrossAttention followed by a FeedForward,
# each wrapped in a skip connection.
class CrossAttentionBlock(nn.Module):
    def __init__(
        self,
        *,
        dim_text,
        dim_visual,
        dtype,
        dim_head=64,
        reduce_factor=1,
        layer_idx=0,
        n_decoder_layers=24
    ):
        super().__init__()
        self.attn = CrossAttention(
            dim_text=dim_text,
            dim_visual=dim_visual,
            dim_head=dim_head,
            reduce_factor=reduce_factor,
            dtype=dtype
        )
        self.ff = FeedForward(dim_text, reduce_factor=reduce_factor, dtype=dtype)
        self.layer_idx = layer_idx
        self.n_decoder_layers = n_decoder_layers
        self.apply(self._init_weights)

    def forward(self, x, media):
        x = self.attn(x, media) + x
        x = self.ff(x) + x
        return x

    def _init_weights(self, module):
        """Initialize the weights (applied recursively via self.apply)."""
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=0.01)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        # Rescale the residual output projections (the last Linear of each
        # sub-layer) so activations do not grow with network depth.
        for name, p in module.named_parameters():
            if name in ("fc2.weight", "to_out.weight"):
                p.data.normal_(mean=0.0, std=0.01 / math.sqrt(2 * max(self.n_decoder_layers, 36)))
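

if __name__ == "__main__":
    # Minimal smoke test; all dimensions below are assumptions for
    # illustration, not values from the original file.
    import torch

    block = CrossAttentionBlock(
        dim_text=512,
        dim_visual=768,
        dtype=torch.float32,
        reduce_factor=2,
    )
    x = torch.randn(2, 16, 512)       # text features (B, txt_seq, D_txt)
    media = torch.randn(2, 257, 768)  # visual features (B, img_seq, D_img)
    out = block(x, media)
    assert out.shape == x.shape       # residual block preserves the text shape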