jadechoghari committed
Commit
04c68a2
1 Parent(s): 4de0c6c

Create mv_attention.py

Files changed (1)
  1. unet/mv_attention.py +367 -0
unet/mv_attention.py ADDED
@@ -0,0 +1,367 @@
from inspect import isfunction
import math
import torch
import torch.nn.functional as F
from torch import nn, einsum
from einops import rearrange, repeat
from typing import Any, Optional


def conv_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D convolution module.
    """
    if dims == 1:
        return nn.Conv1d(*args, **kwargs)
    elif dims == 2:
        return nn.Conv2d(*args, **kwargs)
    elif dims == 3:
        return nn.Conv3d(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")


from .attention import *

try:
    import xformers
    import xformers.ops
    XFORMERS_IS_AVAILABLE = True
except ImportError:
    XFORMERS_IS_AVAILABLE = False
print(f"XFORMERS_IS_AVAILABLE: {XFORMERS_IS_AVAILABLE}")


class SPADAttention(nn.Module):
    """Cross-attention between views with epipolar masking, applied as an additive
    bias inside scaled dot-product attention (in place of the xformers kernel)."""

    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
        super().__init__()
        inner_dim = dim_head * heads
        context_dim = default(context_dim, query_dim)

        self.heads = heads
        self.dim_head = dim_head

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
        )
        self.attention_op: Optional[Any] = None

    def forward(self, x, context=None, mask=None, views=None):
        q = self.to_q(x)
        context = default(context, x)
        k = self.to_k(context)
        v = self.to_v(context)

        b, _, _ = q.shape

        # epipolar mask
        if mask is not None:
            mask = mask.unsqueeze(1)
            mask_shape = (q.shape[-2], k.shape[-2])

            # interpolate epipolar mask to match downsampled unet branch
            mask = (
                F.interpolate(mask.to(torch.uint8), size=mask_shape).bool().squeeze(1)
            )

            # repeat mask for each attention head
            mask = (
                mask.unsqueeze(1)
                .repeat(1, self.heads, 1, 1)
                .reshape(b * self.heads, *mask.shape[-2:])
            )

        # split heads: [b, seq, heads * dim_head] -> [b * heads, seq, dim_head]
        q, k, v = map(
            lambda t: t.unsqueeze(3)
            .reshape(b, t.shape[1], self.heads, self.dim_head)
            .permute(0, 2, 1, 3)
            .reshape(b * self.heads, t.shape[1], self.dim_head)
            .contiguous(),
            (q, k, v),
        )

        with torch.autocast(enabled=False, device_type="cuda"):
            q, k, v = q.float(), k.float(), v.float()

            mask_inf = 1e9
            fmask = None
            if mask is not None:
                # convert the boolean mask to an additive attention bias
                fmask = mask.float()
                fmask[fmask == 0] = -mask_inf
                fmask[fmask == 1] = 0

            # compute scaled dot-product attention (instead of xformers
            # memory-efficient attention)
            attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.dim_head)
            if fmask is not None:
                attn_scores += fmask

            attn_weights = torch.softmax(attn_scores, dim=-1)
            out = torch.matmul(attn_weights, v)

        # merge heads: [b * heads, seq, dim_head] -> [b, seq, heads * dim_head]
        out = (
            out.unsqueeze(0)
            .reshape(b, self.heads, out.shape[1], self.dim_head)
            .permute(0, 2, 1, 3)
            .reshape(b, out.shape[1], self.heads * self.dim_head)
        )

        # guard against NaNs in the attention output
        if out.isnan().any():
            breakpoint()

        # cleanup
        del q, k, v
        return self.to_out(out)


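# Minimal illustration (hypothetical helper, not part of the SPAD module above):
# how SPADAttention turns a boolean visibility mask into an additive bias for
# scaled dot-product attention. Query/key pairs blocked by the mask receive a
# large negative score, so softmax assigns them ~zero weight.
def _demo_mask_to_bias(q_len=8, k_len=8, dim=16):
    q = torch.randn(1, q_len, dim)
    k = torch.randn(1, k_len, dim)
    v = torch.randn(1, k_len, dim)
    visible = torch.rand(1, q_len, k_len) > 0.5  # True = may attend, False = blocked
    bias = torch.zeros(visible.shape)
    bias[~visible] = -1e9
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(dim) + bias
    return torch.softmax(scores, dim=-1) @ v  # [1, q_len, dim]

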
class SPADTransformerBlock(nn.Module):
    """Modified SPAD transformer block that enables spatially aware cross-attention."""

    def __init__(
        self,
        dim,
        n_heads,
        d_head,
        dropout=0.0,
        context_dim=None,
        gated_ff=True,
        checkpoint=True,
        disable_self_attn=False,
    ):
        super().__init__()
        attn_cls = SPADAttention
        self.disable_self_attn = disable_self_attn
        self.attn1 = attn_cls(
            query_dim=dim,
            heads=n_heads,
            dim_head=d_head,
            dropout=dropout,
            context_dim=context_dim if self.disable_self_attn else None,
        )  # is a self-attention if not self.disable_self_attn
        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
        self.attn2 = attn_cls(
            query_dim=dim,
            context_dim=context_dim,
            heads=n_heads,
            dim_head=d_head,
            dropout=dropout,
        )  # is self-attention if context is None
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.norm3 = nn.LayerNorm(dim)
        self.checkpoint = checkpoint

    def forward(self, x, context=None, mask=None):
        return checkpoint(
            self.manystream_forward,
            (x, context, mask),
            self.parameters(),
            self.checkpoint,
        )

    def manystream_forward(self, x, context=None, mask=None):
        assert not self.disable_self_attn
        # x: [n, v, h*w, c]
        # context: [n, v, seq_len, d]
        n, v = x.shape[:2]

        # self-attention (between views) with the 3d epipolar mask
        x = rearrange(x, "n v hw c -> n (v hw) c")
        x = self.attn1(self.norm1(x), context=None, mask=mask, views=v) + x
        x = rearrange(x, "n (v hw) c -> n v hw c", v=v)

        # cross-attention (each view attends to its own context)
        x = rearrange(x, "n v hw c -> (n v) hw c")
        context = rearrange(context, "n v seq d -> (n v) seq d")
        x = self.attn2(self.norm2(x), context=context) + x
        x = self.ff(self.norm3(x)) + x
        x = rearrange(x, "(n v) hw c -> n v hw c", v=v)

        return x


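# Minimal illustration (hypothetical helper, not part of the SPAD module above):
# the two rearrange patterns used in SPADTransformerBlock.manystream_forward.
# Self-attention flattens all views into one long token sequence per object;
# cross-attention folds views into the batch so each view attends to its own
# text embedding.
def _demo_view_rearranges(n=2, v=4, hw=16, c=8):
    x = torch.randn(n, v, hw, c)
    joint = rearrange(x, "n v hw c -> n (v hw) c")     # tokens of all views together
    per_view = rearrange(x, "n v hw c -> (n v) hw c")  # each view as its own batch item
    return joint.shape, per_view.shape                 # (n, v*hw, c), (n*v, hw, c)

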
class SPADTransformer(nn.Module):
    """Spatial Transformer block with post init to add cross attn."""

    def __init__(
        self,
        in_channels,
        n_heads,
        d_head,
        depth=1,
        dropout=0.0,
        context_dim=None,
        disable_self_attn=False,
        use_linear=False,  # 2.1 vs 1.5 difference
        use_checkpoint=True,
    ):
        super().__init__()
        if exists(context_dim) and not isinstance(context_dim, list):
            context_dim = [context_dim]
        self.in_channels = in_channels
        inner_dim = n_heads * d_head
        self.norm = Normalize(in_channels)
        if not use_linear:
            self.proj_in = nn.Conv2d(
                in_channels, inner_dim, kernel_size=1, stride=1, padding=0
            )
        else:
            self.proj_in = nn.Linear(in_channels, inner_dim)

        self.transformer_blocks = nn.ModuleList(
            [
                SPADTransformerBlock(
                    inner_dim,
                    n_heads,
                    d_head,
                    dropout=dropout,
                    context_dim=context_dim[d],
                    disable_self_attn=disable_self_attn,
                    checkpoint=use_checkpoint,
                )
                for d in range(depth)
            ]
        )
        if not use_linear:
            self.proj_out = zero_module(
                nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
            )
        else:
            self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
        self.use_linear = use_linear

        # modify the input conv layer to incorporate plucker coordinates
        self.post_init()

    def post_init(self):
        assert getattr(self, "post_initialized", False) is False, "already modified!"

        # inflate input conv block to attach plucker coordinates
        conv_block = self.proj_in
        conv_params = {
            k: getattr(conv_block, k)
            for k in [
                "in_channels",
                "out_channels",
                "kernel_size",
                "stride",
                "padding",
            ]
        }
        conv_params["in_channels"] += 6
        conv_params["dims"] = 2
        conv_params["device"] = conv_block.weight.device

        # copy original weights for input conv block
        inflated_proj_in = conv_nd(**conv_params)
        inp_weight = conv_block.weight.data
        feat_shape = inp_weight.shape

        # initialize the new weights for plucker coordinates as zeros
        feat_weight = torch.zeros(
            (feat_shape[0], 6, *feat_shape[2:]), device=inp_weight.device
        )

        # assemble new weights and bias
        inflated_proj_in.weight.data.copy_(
            torch.cat([inp_weight, feat_weight], dim=1)
        )
        inflated_proj_in.bias.data.copy_(conv_block.bias.data)
        self.proj_in = inflated_proj_in
        self.post_initialized = True

    def forward(self, x, context=None):
        return self.spad_forward(x, context=context)

    def spad_forward(self, x, context=None):
        """
        x: tensor of shape [n, v, c (4), h (32), w (32)]
        context: list of [text_emb, epipolar_mask, plucker_coords]
            - text_emb: tensor of shape [n, v, seq_len (77), dim (768)]
            - epipolar_mask: bool tensor of shape [n, v*32*32, v*32*32]
            - plucker_coords: tensor of shape [n, v, dim (6), h (32), w (32)]
        """

        n_objects, n_views, c, h, w = x.shape
        x_in = x

        # note: if no context is given, cross-attention defaults to self-attention
        context, plucker = context[:-1], context[-1]
        context = [context]

        x = rearrange(x, "n v c h w -> (n v) c h w")
        x = self.norm(x)
        x = rearrange(x, "(n v) c h w -> n v c h w", v=n_views)

        # run input projection
        if not self.use_linear:
            # interpolate plucker coordinates to match x
            plucker = rearrange(plucker, "n v c h w -> (n v) c h w")
            plucker_interpolated = F.interpolate(
                plucker, size=x.shape[-2:], align_corners=False, mode="bilinear"
            )
            plucker_interpolated = rearrange(
                plucker_interpolated, "(n v) c h w -> n v c h w", v=n_views
            )

            # concat plucker coordinates to x along the channel dim
            x = torch.cat([x, plucker_interpolated], dim=2)
            x = rearrange(x, "n v c h w -> (n v) c h w")
            x = self.proj_in(x)
            x = rearrange(x, "(n v) c h w -> n v c h w", v=n_views)

        x = rearrange(x, "n v c h w -> n v (h w) c").contiguous()

        if self.use_linear:
            x = rearrange(x, "n v x c -> (n v) x c")
            x = self.proj_in(x)
            x = rearrange(x, "(n v) x c -> n v x c", v=n_views)

        # run the transformer blocks
        for i, block in enumerate(self.transformer_blocks):
            _context = context[i]
            mask = None
            if isinstance(_context, (list, tuple)):
                try:
                    _context, mask = _context
                except ValueError:
                    _context = _context[0]
            x = block(x, context=_context, mask=mask)

        # guard against NaNs before the output projection
        if x.isnan().any():
            breakpoint()

        # run output projection
        if self.use_linear:
            x = rearrange(x, "n v x c -> (n v) x c")
            x = self.proj_out(x)
            x = rearrange(x, "(n v) x c -> n v x c", v=n_views)

        x = rearrange(x, "n v (h w) c -> n v c h w", h=h, w=w).contiguous()

        if not self.use_linear:
            x = rearrange(x, "n v c h w -> (n v) c h w")
            x = self.proj_out(x)
            x = rearrange(x, "(n v) c h w -> n v c h w", v=n_views)

        return x + x_in


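# Minimal illustration (hypothetical helper, not part of the SPAD module above):
# why SPADTransformer.post_init is function-preserving at initialization. The six
# extra input channels reserved for plucker coordinates get zero weights, so the
# inflated conv initially matches the original conv on the first `cin` channels.
def _demo_conv_inflation(cin=4, cout=8, extra=6):
    base = nn.Conv2d(cin, cout, kernel_size=1)
    inflated = nn.Conv2d(cin + extra, cout, kernel_size=1)
    with torch.no_grad():
        inflated.weight.zero_()
        inflated.weight[:, :cin].copy_(base.weight)
        inflated.bias.copy_(base.bias)
    x = torch.randn(1, cin, 32, 32)
    plucker = torch.randn(1, extra, 32, 32)
    same = torch.allclose(base(x), inflated(torch.cat([x, plucker], dim=1)), atol=1e-6)
    return same  # True: the plucker channels contribute nothing until trained

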
if __name__ == "__main__":
    spt_post = SPADTransformer(320, 8, 40, depth=1, context_dim=768).cuda()

    n_objects, n_views = 2, 4
    x = torch.randn(n_objects, n_views, 320, 32, 32).cuda()
    context = [
        torch.randn(n_objects, n_views, 77, 768).cuda(),
        torch.ones(
            n_objects, n_views * 32 * 32, n_views * 32 * 32, dtype=torch.bool
        ).cuda(),
        torch.randn(n_objects, n_views, 6, 32, 32).cuda(),
    ]
    x_post = spt_post(x, context=context)
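    # x_post has the same shape as x: [n_objects, n_views, 320, 32, 32]
    print(x_post.shape)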