simpleParadox committed
Commit 3a7fb46
1 Parent(s): 3d7108b
Upload flamingo_pytorch.py
flamingo_pytorch.py
ADDED
@@ -0,0 +1,220 @@
import torch
from torch import nn, einsum
import torch.nn.functional as F

from einops import rearrange, repeat
from einops_exts import rearrange_many, repeat_many
import pdb


def exists(val):
    return val is not None

def FeedForward(dim, mult = 4):
    inner_dim = int(dim * mult)
    return nn.Sequential(
        nn.LayerNorm(dim),
        nn.Linear(dim, inner_dim, bias = False),
        nn.GELU(),
        nn.Linear(inner_dim, dim, bias = False)
    )

class PerceiverAttention(nn.Module):
    def __init__(
        self,
        *,
        dim,
        dim_head = 64,
        heads = 8
    ):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        inner_dim = dim_head * heads

        self.norm_media = nn.LayerNorm(dim)
        self.norm_latents = nn.LayerNorm(dim)

        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
        self.to_out = nn.Linear(inner_dim, dim, bias = False)

    def forward(self, x, latents):
        """
        einstein notation
        b - batch
        t - time
        n - sequence
        d - dimension
        """
        x = self.norm_media(x)
        latents = self.norm_latents(latents)

        b, m, h = *x.shape[:2], self.heads

        q = self.to_q(latents)

        # the paper differs from Perceiver in that the key / values derived from the latents are also concatenated and attended to
        kv_input = torch.cat((x, latents), dim = -2)
        k, v = self.to_kv(kv_input).chunk(2, dim = -1)

        q, k, v = rearrange_many((q, k, v), 'b t n (h d) -> b h t n d', h = h)

        q = q * self.scale

        # attention

        sim = einsum('... i d, ... j d -> ... i j', q, k)

        sim = sim - sim.amax(dim = -1, keepdim = True).detach()
        attn = sim.softmax(dim = -1)

        out = einsum('... i j, ... j d -> ... i d', attn, v)
        out = rearrange(out, 'b h t n d -> b t n (h d)', h = h)
        return self.to_out(out)

class PerceiverResampler(nn.Module):
    def __init__(
        self,
        *,
        dim,
        depth,
        dim_head = 64,
        heads = 8,
        num_latents = 64,
        num_time_embeds = 4,
        ff_mult = 4,
        inp_dim=None,
    ):
        super().__init__()
        self.latents = nn.Parameter(torch.randn(num_latents, dim))
        self.time_pos_emb = nn.Parameter(torch.randn(num_time_embeds, 1, dim))
        if inp_dim is not None:
            self.inp_linear = nn.Linear(inp_dim, dim, bias=False)
        else:
            self.inp_linear = None

        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PerceiverAttention(dim = dim, dim_head = dim_head, heads = heads),
                FeedForward(dim = dim, mult = ff_mult)
            ]))

        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        if x.ndim == 3:
            x = rearrange(x, 'b n d -> b 1 n d')

        if self.inp_linear is not None:
            x = self.inp_linear(x)

        times = x.shape[1]
        x = x + self.time_pos_emb[:times]

        latents = repeat(self.latents, 'n d -> b m n d', b = x.shape[0], m = x.shape[1])

        for attn, ff in self.layers:
            latents = attn(x, latents) + latents
            latents = ff(latents) + latents

        return self.norm(latents)

# gated cross attention

class MaskedCrossAttention(nn.Module):
    def __init__(
        self,
        *,
        dim,
        dim_head = 64,
        heads = 8,
        only_attend_immediate_media = True
    ):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        inner_dim = dim_head * heads

        self.norm = nn.LayerNorm(dim)

        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
        self.to_out = nn.Linear(inner_dim, dim, bias = False)

        # whether for text to only attend to immediate preceding image, or all images

        self.only_attend_immediate_media = only_attend_immediate_media

    def forward(
        self,
        x,
        media,
        media_locations = None
    ):
        b, t, m = media.shape[:3]
        h = self.heads

        x = self.norm(x)

        q = self.to_q(x)
        media = rearrange(media, 'b t n d -> b (t n) d')

        k, v = self.to_kv(media).chunk(2, dim = -1)
        q, k, v = rearrange_many((q, k, v), 'b n (h d) -> b h n d', h = h)

        q = q * self.scale

        sim = einsum('... i d, ... j d -> ... i j', q, k)

        if exists(media_locations):
            text_time = media_locations.cumsum(dim = -1) # at each boolean of True, increment the time counter (relative to media time)
            media_time = torch.arange(t, device = x.device) + 1

            # text time must equal media time if only attending to most immediate image
            # otherwise, as long as text time is greater than media time (if attending to all previous images / media)
            mask_op = torch.eq if self.only_attend_immediate_media else torch.ge

            text_to_media_mask = mask_op(rearrange(text_time, 'b i -> b 1 i 1'), repeat(media_time, 'j -> 1 1 1 (j m)', m = m))
            sim = sim.masked_fill(~text_to_media_mask, -torch.finfo(sim.dtype).max)

        sim = sim - sim.amax(dim = -1, keepdim = True).detach()
        attn = sim.softmax(dim = -1)

        if exists(media_locations) and self.only_attend_immediate_media:
            # any text without a preceding media needs to have attention zeroed out
            text_without_media_mask = text_time == 0
            text_without_media_mask = rearrange(text_without_media_mask, 'b i -> b 1 i 1')
            attn = attn.masked_fill(text_without_media_mask, 0.)  # masked_fill is out-of-place, so keep the returned tensor

        out = einsum('... i j, ... j d -> ... i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

class GatedCrossAttentionBlock(nn.Module):
    def __init__(
        self,
        *,
        dim,
        dim_head = 64,
        heads = 8,
        ff_mult = 4,
        only_attend_immediate_media = True
    ):
        super().__init__()
        self.attn = MaskedCrossAttention(dim = dim, dim_head = dim_head, heads = heads, only_attend_immediate_media = only_attend_immediate_media)
        self.attn_gate = nn.Parameter(torch.tensor([0.]))

        self.ff = FeedForward(dim, mult = ff_mult)
        self.ff_gate = nn.Parameter(torch.tensor([0.]))

    def forward(
        self,
        x,
        media, # media tensor, encoded by the perceiver resampler - (batch, time, latents, dim)
        media_locations = None # boolean tensor indicating positions of media - (batch, sequence)
    ):
        x = self.attn(x, media, media_locations = media_locations) * self.attn_gate.tanh() + x
        x = self.ff(x) * self.ff_gate.tanh() + x
        return x
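
A minimal usage sketch of the two public modules (the feature dimensions, tensor shapes, and media positions below are illustrative assumptions, not part of the uploaded file): PerceiverResampler compresses a variable number of per-image visual tokens into a fixed set of latents, and GatedCrossAttentionBlock then lets text tokens cross-attend to those latents through zero-initialized tanh gates.

# hypothetical example; all sizes here are made up for illustration
perceive = PerceiverResampler(
    dim = 1024,            # width of the latents / language model
    depth = 2,
    num_latents = 64,
    num_time_embeds = 4,   # upper bound on images per sample
    inp_dim = 768          # project e.g. 768-dim vision features up to dim
)

cross_attn = GatedCrossAttentionBlock(dim = 1024)

vision_feats = torch.randn(2, 3, 256, 768)    # (batch, images, patches, inp_dim)
text_tokens = torch.randn(2, 512, 1024)       # (batch, sequence, dim)

# True at the positions in the text sequence where an image token sits
media_locations = torch.zeros(2, 512, dtype = torch.bool)
media_locations[:, [0, 100, 300]] = True

media = perceive(vision_feats)                # (2, 3, 64, 1024)
out = cross_attn(text_tokens, media, media_locations = media_locations)  # (2, 512, 1024)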