jadechoghari committed
Commit 04c68a2 · Parent(s): 4de0c6c
Create mv_attention.py

unet/mv_attention.py (ADDED, +367 -0)
@@ -0,0 +1,367 @@
from inspect import isfunction
import math
from typing import Any, Optional

import torch
import torch.nn.functional as F
from torch import nn, einsum
from einops import rearrange, repeat

def conv_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D convolution module.
    """
    if dims == 1:
        return nn.Conv1d(*args, **kwargs)
    elif dims == 2:
        return nn.Conv2d(*args, **kwargs)
    elif dims == 3:
        return nn.Conv3d(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")


from .attention import *

try:
    import xformers
    import xformers.ops
    XFORMERS_IS_AVAILBLE = True
except ImportError:
    XFORMERS_IS_AVAILBLE = False
print(f"XFORMERS_IS_AVAILBLE: {XFORMERS_IS_AVAILBLE}")


class SPADAttention(nn.Module):
    """Epipolar-masked cross-attention between views, computed with plain scaled
    dot-product attention (xformers is only probed for availability above)."""

    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
        super().__init__()
        inner_dim = dim_head * heads
        context_dim = default(context_dim, query_dim)

        self.heads = heads
        self.dim_head = dim_head

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
        )
        self.attention_op: Optional[Any] = None

    def forward(self, x, context=None, mask=None, views=None):
        q = self.to_q(x)
        context = default(context, x)
        k = self.to_k(context)
        v = self.to_v(context)

        b, _, _ = q.shape

        # epipolar mask
        if mask is not None:
            mask = mask.unsqueeze(1)
            mask_shape = (q.shape[-2], k.shape[-2])

            # interpolate epipolar mask to match downsampled unet branch
            mask = (
                F.interpolate(mask.to(torch.uint8), size=mask_shape).bool().squeeze(1)
            )

            # repeat mask for each attention head
            mask = (
                mask.unsqueeze(1)
                .repeat(1, self.heads, 1, 1)
                .reshape(b * self.heads, *mask.shape[-2:])
            )

        q, k, v = map(
            lambda t: t.unsqueeze(3)
            .reshape(b, t.shape[1], self.heads, self.dim_head)
            .permute(0, 2, 1, 3)
            .reshape(b * self.heads, t.shape[1], self.dim_head)
            .contiguous(),
            (q, k, v),
        )

        with torch.autocast(enabled=False, device_type="cuda"):
            q, k, v = q.float(), k.float(), v.float()

            mask_inf = 1e9
            fmask = None
            if mask is not None:
                # convert the boolean mask to an additive attention bias
                fmask = mask.float()
                fmask[fmask == 0] = -mask_inf
                fmask[fmask == 1] = 0

            # actually compute the attention, what we cannot get enough of
            # scaled dot-product attention implementation instead of xformers
            attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.dim_head)
            if fmask is not None:
                attn_scores += fmask

            attn_weights = torch.softmax(attn_scores, dim=-1)
            out = torch.matmul(attn_weights, v)

        out = (
            out.unsqueeze(0)
            .reshape(b, self.heads, out.shape[1], self.dim_head)
            .permute(0, 2, 1, 3)
            .reshape(b, out.shape[1], self.heads * self.dim_head)
        )

        # guard against NaNs in the attention output
        if out.isnan().any():
            breakpoint()

        # cleanup
        del q, k, v
        return self.to_out(out)
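

# Minimal usage sketch for SPADAttention (illustrative shapes only; not called anywhere).
# The queries are the flattened tokens of all views in one sequence, and the boolean mask
# marks which query/key token pairs are epipolar-consistent. Assumes a CUDA device, as in
# the module's own __main__ demo below.
def _demo_spad_attention():
    n, v, hw, c = 1, 2, 16, 320                  # batch, views, tokens per view, channels
    attn = SPADAttention(query_dim=c, heads=8, dim_head=40).cuda()
    x = torch.randn(n, v * hw, c).cuda()         # joint token sequence over all views
    # all-True mask = every token pair is epipolar-consistent (no masking effect)
    mask = torch.ones(n, v * hw, v * hw, dtype=torch.bool).cuda()
    out = attn(x, context=None, mask=mask, views=v)
    assert out.shape == x.shape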


class SPADTransformerBlock(nn.Module):
    """Modified SPAD transformer block that enables spatially aware cross-attention."""

    def __init__(
        self,
        dim,
        n_heads,
        d_head,
        dropout=0.0,
        context_dim=None,
        gated_ff=True,
        checkpoint=True,
        disable_self_attn=False,
    ):
        super().__init__()
        attn_cls = SPADAttention
        self.disable_self_attn = disable_self_attn
        self.attn1 = attn_cls(
            query_dim=dim,
            heads=n_heads,
            dim_head=d_head,
            dropout=dropout,
            context_dim=context_dim if self.disable_self_attn else None,
        )  # is a self-attention if not self.disable_self_attn
        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
        self.attn2 = attn_cls(
            query_dim=dim,
            context_dim=context_dim,
            heads=n_heads,
            dim_head=d_head,
            dropout=dropout,
        )  # is self-attn if context is none
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.norm3 = nn.LayerNorm(dim)
        self.checkpoint = checkpoint

    def forward(self, x, context=None, mask=None):
        return checkpoint(
            self.manystream_forward,
            (x, context, mask),
            self.parameters(),
            self.checkpoint,
        )

    def manystream_forward(self, x, context=None, mask=None):
        assert not self.disable_self_attn
        # x: [n, v, h*w, c]
        # context: [n, v, seq_len, d]
        n, v = x.shape[:2]

        # self-attention (between views) with 3d mask
        x = rearrange(x, "n v hw c -> n (v hw) c")
        x = self.attn1(self.norm1(x), context=None, mask=mask, views=v) + x
        x = rearrange(x, "n (v hw) c -> n v hw c", v=v)

        # cross-attention (to individual views)
        x = rearrange(x, "n v hw c -> (n v) hw c")
        context = rearrange(context, "n v seq d -> (n v) seq d")
        x = self.attn2(self.norm2(x), context=context) + x
        x = self.ff(self.norm3(x)) + x
        x = rearrange(x, "(n v) hw c -> n v hw c", v=v)

        return x
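

# Shape sketch of the many-stream pattern above (illustrative only; not called anywhere):
# all views are merged into a single token sequence for the epipolar-masked self-attention,
# then split back per view so each view cross-attends only to its own text embedding.
def _demo_manystream_shapes():
    n, v, hw, c, seq, d = 2, 4, 32 * 32, 320, 77, 768
    x = torch.randn(n, v, hw, c)
    context = torch.randn(n, v, seq, d)
    x_joint = rearrange(x, "n v hw c -> n (v hw) c")               # joint self-attention input
    x_split = rearrange(x_joint, "n (v hw) c -> (n v) hw c", v=v)  # per-view cross-attention input
    ctx_split = rearrange(context, "n v seq d -> (n v) seq d")
    assert x_joint.shape == (n, v * hw, c)
    assert x_split.shape == (n * v, hw, c) and ctx_split.shape == (n * v, seq, d)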


class SPADTransformer(nn.Module):
    """Spatial Transformer block with a post-init step to add cross attention."""

    def __init__(
        self,
        in_channels,
        n_heads,
        d_head,
        depth=1,
        dropout=0.0,
        context_dim=None,
        disable_self_attn=False,
        use_linear=False,  # 2.1 vs 1.5 difference
        use_checkpoint=True,
    ):
        super().__init__()
        if exists(context_dim) and not isinstance(context_dim, list):
            context_dim = [context_dim]
        self.in_channels = in_channels
        inner_dim = n_heads * d_head
        self.norm = Normalize(in_channels)
        if not use_linear:
            self.proj_in = nn.Conv2d(
                in_channels, inner_dim, kernel_size=1, stride=1, padding=0
            )
        else:
            self.proj_in = nn.Linear(in_channels, inner_dim)

        self.transformer_blocks = nn.ModuleList(
            [
                SPADTransformerBlock(
                    inner_dim,
                    n_heads,
                    d_head,
                    dropout=dropout,
                    context_dim=context_dim[d],
                    disable_self_attn=disable_self_attn,
                    checkpoint=use_checkpoint,
                )
                for d in range(depth)
            ]
        )
        if not use_linear:
            self.proj_out = zero_module(
                nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
            )
        else:
            self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
        self.use_linear = use_linear

        # modify conv layers to incorporate plucker coordinates
        self.post_init()

    def post_init(self):
        assert getattr(self, "post_intialized", False) is False, "already modified!"

        # inflate the input conv block to accept concatenated plucker coordinates
        conv_block = self.proj_in
        conv_params = {
            k: getattr(conv_block, k)
            for k in [
                "in_channels",
                "out_channels",
                "kernel_size",
                "stride",
                "padding",
            ]
        }
        conv_params["in_channels"] += 6
        conv_params["dims"] = 2
        conv_params["device"] = conv_block.weight.device

        # copy original weights for the input conv block
        inflated_proj_in = conv_nd(**conv_params)
        inp_weight = conv_block.weight.data
        feat_shape = inp_weight.shape

        # initialize the new weights for plucker coordinates as zeros
        feat_weight = torch.zeros(
            (feat_shape[0], 6, *feat_shape[2:]), device=inp_weight.device
        )

        # assemble new weights and bias
        inflated_proj_in.weight.data.copy_(
            torch.cat([inp_weight, feat_weight], dim=1)
        )
        inflated_proj_in.bias.data.copy_(conv_block.bias.data)
        self.proj_in = inflated_proj_in
        self.post_intialized = True

    def forward(self, x, context=None):
        return self.spad_forward(x, context=context)

    def spad_forward(self, x, context=None):
        """
        x: tensor of shape [n, v, c (4), h (32), w (32)]
        context: list of [text_emb, epipolar_mask, plucker_coords]
            - text_emb: tensor of shape [n, v, seq_len (77), dim (768)]
            - epipolar_mask: bool tensor of shape [n, v, seq_len (32*32), seq_len (32*32)]
            - plucker_coords: tensor of shape [n, v, dim (6), h (32), w (32)]
        """

        n_objects, n_views, c, h, w = x.shape
        x_in = x

        # note: if no context is given, cross-attention defaults to self-attention
        context, plucker = context[:-1], context[-1]
        context = [context]

        x = rearrange(x, "n v c h w -> (n v) c h w")
        x = self.norm(x)
        x = rearrange(x, "(n v) c h w -> n v c h w", v=n_views)

        # run input projection
        if not self.use_linear:
            # interpolate plucker to match x
            plucker = rearrange(plucker, "n v c h w -> (n v) c h w")
            plucker_interpolated = F.interpolate(
                plucker, size=x.shape[-2:], align_corners=False, mode="bilinear"
            )
            plucker_interpolated = rearrange(
                plucker_interpolated, "(n v) c h w -> n v c h w", v=n_views
            )

            # concat plucker to x
            x = torch.cat([x, plucker_interpolated], dim=2)
            x = rearrange(x, "n v c h w -> (n v) c h w")
            x = self.proj_in(x)
            x = rearrange(x, "(n v) c h w -> n v c h w", v=n_views)

        x = rearrange(x, "n v c h w -> n v (h w) c").contiguous()

        if self.use_linear:
            x = rearrange(x, "n v x c -> (n v) x c")
            x = self.proj_in(x)
            x = rearrange(x, "(n v) x c -> n v x c", v=n_views)

        # run the transformer blocks
        for i, block in enumerate(self.transformer_blocks):
            _context = context[i]
            mask = None
            if isinstance(_context, (list, tuple)):
                try:
                    _context, mask = _context
                except ValueError:
                    _context = _context[0]
            x = block(x, context=_context, mask=mask)

        # guard against NaNs after the transformer blocks
        if x.isnan().any():
            breakpoint()

        # run output projection
        if self.use_linear:
            x = rearrange(x, "n v x c -> (n v) x c")
            x = self.proj_out(x)
            x = rearrange(x, "(n v) x c -> n v x c", v=n_views)

        x = rearrange(x, "n v (h w) c -> n v c h w", h=h, w=w).contiguous()

        if not self.use_linear:
            x = rearrange(x, "n v c h w -> (n v) c h w")
            x = self.proj_out(x)
            x = rearrange(x, "(n v) c h w -> n v c h w", v=n_views)

        return x + x_in
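

# Standalone sketch of the weight-inflation idea in post_init (illustrative only; not called
# anywhere): a 1x1 conv is widened by 6 input channels for the plucker coordinates, with the
# new weights zero-initialized, so the inflated projection initially reproduces the original one.
def _demo_inflated_proj_in(in_channels=320, inner_dim=320):
    proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1)
    inflated = nn.Conv2d(in_channels + 6, inner_dim, kernel_size=1)
    with torch.no_grad():
        inflated.weight.zero_()
        inflated.weight[:, :in_channels].copy_(proj_in.weight)  # reuse pretrained weights
        inflated.bias.copy_(proj_in.bias)
    x = torch.randn(1, in_channels, 8, 8)
    plucker = torch.randn(1, 6, 8, 8)
    # zero-initialized extra channels leave the original output (numerically) unchanged
    assert torch.allclose(inflated(torch.cat([x, plucker], dim=1)), proj_in(x), atol=1e-5)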


if __name__ == "__main__":
    spt_post = SPADTransformer(320, 8, 40, depth=1, context_dim=768).cuda()

    n_objects, n_views = 2, 4
    x = torch.randn(2, 4, 320, 32, 32).cuda()
    context = [
        torch.randn(n_objects, n_views, 77, 768).cuda(),
        torch.ones(
            n_objects, n_views * 32 * 32, n_views * 32 * 32, dtype=torch.bool
        ).cuda(),
        torch.randn(n_objects, n_views, 6, 32, 32).cuda(),
    ]
    x_post = spt_post(x, context=context)
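    # quick sanity check: the residual output keeps the input shape [n_objects, n_views, c, h, w]
    # (assumes a CUDA device, as above)
    assert x_post.shape == x.shape
    print("SPADTransformer output shape:", tuple(x_post.shape))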