simpleParadox committed on
Commit
3a7fb46
1 Parent(s): 3d7108b

Upload flamingo_pytorch.py

Files changed (1)
  1. flamingo_pytorch.py +220 -0
flamingo_pytorch.py ADDED
@@ -0,0 +1,220 @@
import torch
from torch import nn, einsum
import torch.nn.functional as F

from einops import rearrange, repeat
from einops_exts import rearrange_many, repeat_many
import pdb


def exists(val):
    return val is not None

def FeedForward(dim, mult = 4):
    inner_dim = int(dim * mult)
    return nn.Sequential(
        nn.LayerNorm(dim),
        nn.Linear(dim, inner_dim, bias = False),
        nn.GELU(),
        nn.Linear(inner_dim, dim, bias = False)
    )

class PerceiverAttention(nn.Module):
    def __init__(
        self,
        *,
        dim,
        dim_head = 64,
        heads = 8
    ):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        inner_dim = dim_head * heads

        self.norm_media = nn.LayerNorm(dim)
        self.norm_latents = nn.LayerNorm(dim)

        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
        self.to_out = nn.Linear(inner_dim, dim, bias = False)

    def forward(self, x, latents):
        """
        einstein notation
        b - batch
        t - time
        n - sequence
        d - dimension
        """
        x = self.norm_media(x)
        latents = self.norm_latents(latents)

        b, m, h = *x.shape[:2], self.heads

        q = self.to_q(latents)

        # the paper differs from Perceiver in that the key / values derived from the latents are also concatenated to the media being attended to
        kv_input = torch.cat((x, latents), dim = -2)
        k, v = self.to_kv(kv_input).chunk(2, dim = -1)

        q, k, v = rearrange_many((q, k, v), 'b t n (h d) -> b h t n d', h = h)

        q = q * self.scale

        # attention

        sim = einsum('... i d, ... j d -> ... i j', q, k)

        sim = sim - sim.amax(dim = -1, keepdim = True).detach()
        attn = sim.softmax(dim = -1)

        out = einsum('... i j, ... j d -> ... i d', attn, v)
        out = rearrange(out, 'b h t n d -> b t n (h d)', h = h)
        return self.to_out(out)

class PerceiverResampler(nn.Module):
    def __init__(
        self,
        *,
        dim,
        depth,
        dim_head = 64,
        heads = 8,
        num_latents = 64,
        num_time_embeds = 4,
        ff_mult = 4,
        inp_dim = None,
    ):
        super().__init__()
        self.latents = nn.Parameter(torch.randn(num_latents, dim))
        self.time_pos_emb = nn.Parameter(torch.randn(num_time_embeds, 1, dim))

        # optional projection from the raw feature dimension (inp_dim) to the resampler dimension (dim)
        if inp_dim is not None:
            self.inp_linear = nn.Linear(inp_dim, dim, bias = False)
        else:
            self.inp_linear = None

        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PerceiverAttention(dim = dim, dim_head = dim_head, heads = heads),
                FeedForward(dim = dim, mult = ff_mult)
            ]))

        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        if x.ndim == 3:
            x = rearrange(x, 'b n d -> b 1 n d')

        if self.inp_linear is not None:
            x = self.inp_linear(x)

        times = x.shape[1]
        x = x + self.time_pos_emb[:times]

        latents = repeat(self.latents, 'n d -> b m n d', b = x.shape[0], m = x.shape[1])

        for attn, ff in self.layers:
            latents = attn(x, latents) + latents
            latents = ff(latents) + latents

        return self.norm(latents)

# gated cross attention

class MaskedCrossAttention(nn.Module):
    def __init__(
        self,
        *,
        dim,
        dim_head = 64,
        heads = 8,
        only_attend_immediate_media = True
    ):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        inner_dim = dim_head * heads

        self.norm = nn.LayerNorm(dim)

        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
        self.to_out = nn.Linear(inner_dim, dim, bias = False)

        # whether text should attend only to the immediately preceding image, or to all preceding images

        self.only_attend_immediate_media = only_attend_immediate_media

    def forward(
        self,
        x,
        media,
        media_locations = None
    ):
        b, t, m = media.shape[:3]
        h = self.heads

        x = self.norm(x)

        q = self.to_q(x)
        media = rearrange(media, 'b t n d -> b (t n) d')

        k, v = self.to_kv(media).chunk(2, dim = -1)
        q, k, v = rearrange_many((q, k, v), 'b n (h d) -> b h n d', h = h)

        q = q * self.scale

        sim = einsum('... i d, ... j d -> ... i j', q, k)

        if exists(media_locations):
            text_time = media_locations.cumsum(dim = -1) # at each boolean of True, increment the time counter (relative to media time)
            media_time = torch.arange(t, device = x.device) + 1

            # text time must equal media time if only attending to the most immediate image
            # otherwise, text time must be greater than or equal to media time (if attending to all previous images / media)
            mask_op = torch.eq if self.only_attend_immediate_media else torch.ge

            text_to_media_mask = mask_op(rearrange(text_time, 'b i -> b 1 i 1'), repeat(media_time, 'j -> 1 1 1 (j m)', m = m))
            sim = sim.masked_fill(~text_to_media_mask, -torch.finfo(sim.dtype).max)

        sim = sim - sim.amax(dim = -1, keepdim = True).detach()
        attn = sim.softmax(dim = -1)

        if exists(media_locations) and self.only_attend_immediate_media:
            # any text without a preceding media needs to have its attention zeroed out
            text_without_media_mask = text_time == 0
            text_without_media_mask = rearrange(text_without_media_mask, 'b i -> b 1 i 1')
            attn = attn.masked_fill(text_without_media_mask, 0.)

        out = einsum('... i j, ... j d -> ... i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

class GatedCrossAttentionBlock(nn.Module):
    def __init__(
        self,
        *,
        dim,
        dim_head = 64,
        heads = 8,
        ff_mult = 4,
        only_attend_immediate_media = True
    ):
        super().__init__()
        self.attn = MaskedCrossAttention(dim = dim, dim_head = dim_head, heads = heads, only_attend_immediate_media = only_attend_immediate_media)
        self.attn_gate = nn.Parameter(torch.tensor([0.]))

        self.ff = FeedForward(dim, mult = ff_mult)
        self.ff_gate = nn.Parameter(torch.tensor([0.]))

    def forward(
        self,
        x,                          # text token embeddings - (batch, sequence, dim)
        media,                      # media tensor, encoded by the perceiver resampler - (batch, time, latents, dim)
        media_locations = None      # boolean tensor indicating positions of media - (batch, sequence)
    ):
        # tanh gates are initialized at zero, so the block starts out as an identity mapping on the text stream
        x = self.attn(x, media, media_locations = media_locations) * self.attn_gate.tanh() + x
        x = self.ff(x) * self.ff_gate.tanh() + x
        return x
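
A minimal usage sketch, not part of the uploaded file: it wires a PerceiverResampler into a GatedCrossAttentionBlock the way the module signatures above suggest. The dimensions, the 1024-dim visual features motivating inp_dim, and the image-token positions are illustrative assumptions.

import torch
from flamingo_pytorch import PerceiverResampler, GatedCrossAttentionBlock  # assumes this file is importable as flamingo_pytorch

# resample per-image visual features (assumed 1024-dim here) into 64 latents in a 512-dim text space
resampler = PerceiverResampler(
    dim = 512,
    depth = 2,
    num_latents = 64,
    num_time_embeds = 4,
    inp_dim = 1024        # projection added in this upload: maps 1024-dim features to dim
)

visual_feats = torch.randn(1, 3, 256, 1024)    # (batch, images/time, patches, feature dim)
media = resampler(visual_feats)                # (1, 3, 64, 512)

# gated cross attention from text tokens into the resampled media latents
cross_attn = GatedCrossAttentionBlock(dim = 512)

text = torch.randn(1, 128, 512)                # (batch, sequence, dim) token embeddings
media_locations = torch.zeros(1, 128, dtype = torch.bool)
media_locations[:, [0, 40, 90]] = True         # positions of the 3 images in the text sequence

out = cross_attn(text, media, media_locations = media_locations)   # (1, 128, 512)

Because both tanh gates start at zero, out initially equals text; the cross-attention contribution is learned gradually during training, as in the Flamingo paper.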