spad / unet /mv_attention.py
jadechoghari's picture
Update unet/mv_attention.py
efc43f7 verified
raw
history blame
12.1 kB
from inspect import isfunction
import math
import torch
import torch.nn.functional as F
from torch import nn, einsum
from einops import rearrange, repeat
def conv_nd(dims, *args, **kwargs):
    """Return a convolution layer matching the spatial rank ``dims``.

    ``dims`` selects ``nn.Conv1d`` (1), ``nn.Conv2d`` (2) or ``nn.Conv3d`` (3);
    every remaining positional and keyword argument is forwarded unchanged to
    the selected class.

    Raises:
        ValueError: if ``dims`` is not 1, 2 or 3.
    """
    # checked in order, with the same equality comparisons as an if/elif chain
    candidates = ((1, nn.Conv1d), (2, nn.Conv2d), (3, nn.Conv3d))
    for rank, conv_cls in candidates:
        if dims == rank:
            return conv_cls(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")
from .attention import *
# Probe for an optional xformers install. The (misspelled) flag name is kept
# as-is because other modules may import it.
try:
    import xformers
    import xformers.ops

    XFORMERS_IS_AVAILBLE = True
except Exception:
    # Best effort: any failure while importing xformers (missing package or a
    # broken install) disables it. Unlike a bare `except:`, this does not
    # swallow KeyboardInterrupt / SystemExit.
    XFORMERS_IS_AVAILBLE = False
print(f"XFORMERS_IS_AVAILBLE: {XFORMERS_IS_AVAILBLE}")
class SPADAttention(nn.Module):
    """Multi-head (cross-)attention with an optional boolean attention mask.

    Used for epipolar-masked attention between views. Attention is computed
    explicitly with ``torch.matmul`` / ``softmax`` in float32 with CUDA
    autocast disabled. A mask is turned into an additive bias: query/key
    pairs where the mask is False receive ``-1e9`` before the softmax.
    """

    # additive bias for masked-out (False) query/key pairs
    _MASK_INF = 1e9

    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
        """
        Args:
            query_dim: feature size of the query tokens ``x``.
            context_dim: feature size of the key/value tokens; ``None`` means
                ``query_dim`` (self-attention).
            heads: number of attention heads.
            dim_head: feature size of each head.
            dropout: dropout probability applied after the output projection.
        """
        super().__init__()
        inner_dim = dim_head * heads
        # same as `default(context_dim, query_dim)`
        if context_dim is None:
            context_dim = query_dim
        self.heads = heads
        self.dim_head = dim_head
        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
        )
        # kept for API compatibility (xformers op hook); not used by forward
        self.attention_op: Optional[Any] = None

    def _split_heads(self, t):
        """Reshape ``(b, n, heads * dim_head)`` into ``(b, heads, n, dim_head)``."""
        batch, n_tokens, _ = t.shape
        return t.reshape(batch, n_tokens, self.heads, self.dim_head).transpose(1, 2)

    def _mask_to_bias(self, mask, q_len, k_len):
        """Convert a ``(b, N_q, N_k)`` mask into a ``(b, 1, q_len, k_len)`` float bias.

        Entries that are nonzero after a uint8 cast are kept (bias 0); all
        others get ``-_MASK_INF``. The mask is resized with nearest-neighbour
        interpolation when its last two dims differ from ``(q_len, k_len)``
        (e.g. for a downsampled UNet branch).
        """
        if mask.dim() != 3:
            raise ValueError(
                f"expected a 3-D attention mask (b, q, k), got shape {tuple(mask.shape)}"
            )
        mask = mask.to(torch.uint8)
        if tuple(mask.shape[-2:]) != (q_len, k_len):
            mask = F.interpolate(mask.unsqueeze(1), size=(q_len, k_len)).squeeze(1)
        allowed = mask.bool()
        bias = torch.zeros(allowed.shape, dtype=torch.float32, device=allowed.device)
        bias.masked_fill_(~allowed, -self._MASK_INF)
        # broadcast over heads instead of materialising one copy per head
        return bias.unsqueeze(1)

    def forward(self, x, context=None, mask=None, views=None):
        """Attend from ``x`` to ``context`` (or to ``x`` itself).

        Args:
            x: query tokens, ``(b, n_q, query_dim)``.
            context: key/value tokens ``(b, n_k, context_dim)``; ``None`` means
                self-attention over ``x``.
            mask: optional ``(b, N_q, N_k)`` mask; True/nonzero marks allowed
                query/key pairs. Resized to ``(n_q, n_k)`` if needed.
            views: unused; accepted for call-site compatibility.

        Returns:
            ``(b, n_q, query_dim)`` tensor in the dtype of ``x``.

        Raises:
            FloatingPointError: if the attention output contains NaNs.
            ValueError: if ``mask`` is not 3-D.
        """
        if context is None:
            context = x
        q = self.to_q(x)
        k = self.to_k(context)
        v = self.to_v(context)
        batch, q_len, _ = q.shape
        k_len = k.shape[1]

        bias = None if mask is None else self._mask_to_bias(mask, q_len, k_len)

        # compute attention in float32 regardless of any surrounding autocast
        with torch.autocast(enabled=False, device_type="cuda"):
            q, k, v = (self._split_heads(t).float() for t in (q, k, v))
            scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.dim_head)
            if bias is not None:
                scores = scores + bias
            out = torch.matmul(torch.softmax(scores, dim=-1), v)

        # (b, heads, n_q, dim_head) -> (b, n_q, heads * dim_head)
        out = out.transpose(1, 2).reshape(batch, q_len, self.heads * self.dim_head)
        if torch.isnan(out).any():
            raise FloatingPointError("NaN detected in SPADAttention output")
        # cast back so a half-precision model without autocast still works
        return self.to_out(out.to(x.dtype))
class SPADTransformerBlock(nn.Module):
    """Transformer block with joint multi-view self-attention and per-view cross-attention.

    ``attn1`` attends over the tokens of all views of an object at once,
    optionally restricted by ``mask``. ``attn2`` lets each view attend to its
    own ``context`` (self-attention within the view when ``context`` is None).
    ``ff`` is a feed-forward layer. Each sub-layer is pre-normalised and
    residual.
    """

    def __init__(
        self,
        dim,
        n_heads,
        d_head,
        dropout=0.0,
        context_dim=None,
        gated_ff=True,
        checkpoint=True,
        disable_self_attn=False,
    ):
        super().__init__()
        attn_cls = SPADAttention
        self.disable_self_attn = disable_self_attn
        # self-attention unless disable_self_attn is set
        self.attn1 = attn_cls(
            query_dim=dim,
            heads=n_heads,
            dim_head=d_head,
            dropout=dropout,
            context_dim=context_dim if self.disable_self_attn else None,
        )
        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
        # cross-attention to `context`; self-attention if context is None
        self.attn2 = attn_cls(
            query_dim=dim,
            context_dim=context_dim,
            heads=n_heads,
            dim_head=d_head,
            dropout=dropout,
        )
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.norm3 = nn.LayerNorm(dim)
        self.checkpoint = checkpoint

    def forward(self, x, context=None, mask=None):
        """Run ``manystream_forward`` through the ``checkpoint`` helper."""
        return checkpoint(
            self.manystream_forward,
            (x, context, mask),
            self.parameters(),
            self.checkpoint,
        )

    def manystream_forward(self, x, context=None, mask=None):
        """Apply the block to multi-view tokens.

        Args:
            x: ``[n, v, h*w, c]`` tokens (``n`` objects, ``v`` views).
            context: ``[n, v, seq_len, d]`` per-view context, or ``None`` for
                self-attention in ``attn2``.
            mask: optional mask for ``attn1`` over all ``v * h*w`` tokens of
                each object.

        Returns:
            ``[n, v, h*w, c]`` tensor.

        Raises:
            NotImplementedError: if the block was built with
                ``disable_self_attn=True``.
        """
        if self.disable_self_attn:
            raise NotImplementedError(
                "SPADTransformerBlock does not support disable_self_attn=True"
            )
        n_views = x.shape[1]

        # self-attention across all views of an object, with the 3D mask
        x = rearrange(x, "n v hw c -> n (v hw) c")
        x = self.attn1(self.norm1(x), context=None, mask=mask, views=n_views) + x

        # cross-attention, each view to its own context
        x = rearrange(x, "n (v hw) c -> (n v) hw c", v=n_views)
        if context is not None:
            context = rearrange(context, "n v seq d -> (n v) seq d")
        x = self.attn2(self.norm2(x), context=context) + x
        x = self.ff(self.norm3(x)) + x
        return rearrange(x, "(n v) hw c -> n v hw c", v=n_views)
class SPADTransformer(nn.Module):
    """Multi-view spatial transformer whose input projection also sees plucker coordinates.

    Operates on ``[n, v, c, h, w]`` feature maps (``n`` objects, ``v`` views).
    At construction, ``post_init`` widens ``proj_in`` by 6 zero-initialised
    input channels. The plucker coordinates passed as the last element of
    ``context`` are concatenated to the normalised features along those
    channels.
    """

    # number of extra input channels carrying plucker coordinates
    _PLUCKER_CHANNELS = 6

    def __init__(
        self,
        in_channels,
        n_heads,
        d_head,
        depth=1,
        dropout=0.0,
        context_dim=None,
        disable_self_attn=False,
        use_linear=False,  # 2.1 vs 1.5 difference
        use_checkpoint=True,
    ):
        """
        Args:
            in_channels: channels of the input feature maps.
            n_heads, d_head: attention heads and per-head size (inner dim is
                ``n_heads * d_head``).
            depth: number of ``SPADTransformerBlock``s.
            dropout: dropout probability inside the blocks.
            context_dim: cross-attention context size, either one value shared by
                all blocks or a list with one entry per block.
            disable_self_attn: forwarded to each block.
            use_linear: use ``nn.Linear`` instead of 1x1 convolutions for the
                input/output projections.
            use_checkpoint: forwarded to each block as ``checkpoint``.
        """
        super().__init__()
        # one context dim per block; a single (or missing) value is shared
        if not isinstance(context_dim, list):
            context_dim = [context_dim] * depth
        self.in_channels = in_channels
        inner_dim = n_heads * d_head
        self.norm = Normalize(in_channels)
        if use_linear:
            self.proj_in = nn.Linear(in_channels, inner_dim)
        else:
            self.proj_in = nn.Conv2d(
                in_channels, inner_dim, kernel_size=1, stride=1, padding=0
            )
        self.transformer_blocks = nn.ModuleList(
            [
                SPADTransformerBlock(
                    inner_dim,
                    n_heads,
                    d_head,
                    dropout=dropout,
                    context_dim=context_dim[d],
                    disable_self_attn=disable_self_attn,
                    checkpoint=use_checkpoint,
                )
                for d in range(depth)
            ]
        )
        if use_linear:
            # maps inner_dim back to in_channels (previously the dims were swapped)
            self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
        else:
            self.proj_out = zero_module(
                nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
            )
        self.use_linear = use_linear
        # widen proj_in so it also accepts plucker coordinates
        self.post_init()

    def post_init(self):
        """Widen ``proj_in`` by ``_PLUCKER_CHANNELS`` input channels (call once).

        The original weights and bias are copied, and the weights for the new
        channels are zero. Right after this call, the layer's output does not
        depend on the appended plucker channels.
        """
        assert getattr(self, "post_intialized", False) is False, "already modified!"
        extra = self._PLUCKER_CHANNELS
        old_proj = self.proj_in
        old_weight = old_proj.weight.data
        device = old_weight.device
        if isinstance(old_proj, nn.Linear):
            new_proj = nn.Linear(
                old_proj.in_features + extra, old_proj.out_features, device=device
            )
        else:
            new_proj = conv_nd(
                dims=2,
                in_channels=old_proj.in_channels + extra,
                out_channels=old_proj.out_channels,
                kernel_size=old_proj.kernel_size,
                stride=old_proj.stride,
                padding=old_proj.padding,
                device=device,
            )
        # weight is (out, in, *kernel) for conv and (out, in) for linear;
        # either way the new input channels go at the end of dim 1
        zero_weight = torch.zeros(
            (old_weight.shape[0], extra, *old_weight.shape[2:]), device=device
        )
        new_proj.weight.data.copy_(torch.cat([old_weight, zero_weight], dim=1))
        new_proj.bias.data.copy_(old_proj.bias.data)
        self.proj_in = new_proj
        self.post_intialized = True

    def forward(self, x, context=None):
        """Alias for ``spad_forward``."""
        return self.spad_forward(x, context=context)

    @staticmethod
    def _split_conditioning(cond):
        """Return ``(text_emb, mask)`` from the non-plucker part of ``context``."""
        if isinstance(cond, (list, tuple)):
            if len(cond) == 2:
                return cond[0], cond[1]
            return cond[0], None
        return cond, None

    def spad_forward(self, x, context=None):
        """Run the multi-view transformer with plucker-conditioned input projection.

        Args:
            x: ``[n, v, c, h, w]`` features with ``c == in_channels``.
            context: sequence ``[text_emb, epipolar_mask, plucker_coords]``; the
                mask may be omitted (``[text_emb, plucker_coords]``).

                - text_emb: ``[n, v, seq_len, dim]``, cross-attended per view.
                - epipolar_mask: bool mask over all tokens of all views of an
                  object, e.g. ``[n, v*h*w, v*h*w]``. It is passed as ``mask``
                  to every block.
                - plucker_coords: ``[n, v, 6, h', w']``, bilinearly resized to
                  ``(h, w)``.

        Returns:
            ``[n, v, c, h, w]`` tensor: ``x`` plus the transformer residual.

        Raises:
            TypeError: if ``context`` is None (plucker coordinates are required).
            FloatingPointError: if a transformer block produces NaNs.
        """
        if context is None:
            raise TypeError(
                "SPADTransformer requires context=[text_emb, (epipolar_mask,) plucker_coords]"
            )
        _, n_views, _, h, w = x.shape
        x_in = x
        text_emb, mask = self._split_conditioning(context[:-1])
        plucker = context[-1]

        # normalise each view, then append resized plucker coords as channels
        x = self.norm(rearrange(x, "n v c h w -> (n v) c h w"))
        plucker = rearrange(plucker, "n v c h w -> (n v) c h w")
        plucker = F.interpolate(
            plucker, size=x.shape[-2:], align_corners=False, mode="bilinear"
        )
        x = torch.cat([x, plucker], dim=1)

        # input projection -> tokens [n, v, h*w, inner_dim]
        if self.use_linear:
            x = self.proj_in(rearrange(x, "b c h w -> b (h w) c").contiguous())
        else:
            x = rearrange(self.proj_in(x), "b c h w -> b (h w) c").contiguous()
        x = rearrange(x, "(n v) hw c -> n v hw c", v=n_views)

        # every block sees the same text embedding and epipolar mask
        for block in self.transformer_blocks:
            x = block(x, context=text_emb, mask=mask)
            if torch.isnan(x).any():
                raise FloatingPointError("NaN detected in SPADTransformer block output")

        # output projection back to [n, v, c, h, w]
        x = rearrange(x, "n v hw c -> (n v) hw c")
        if self.use_linear:
            x = rearrange(self.proj_out(x), "b (h w) c -> b c h w", h=h, w=w).contiguous()
        else:
            x = self.proj_out(rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous())
        x = rearrange(x, "(n v) c h w -> n v c h w", v=n_views)
        return x + x_in
if __name__ == "__main__":
    # Smoke test: one forward pass through a single-block SPADTransformer.
    # NOTE: this module uses a relative import (`from .attention import *`),
    # so run it as a module (python -m <package>.mv_attention), not a script.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    n_objects, n_views, h, w = 2, 4, 32, 32
    n_tokens = n_views * h * w
    spt_post = SPADTransformer(320, 8, 40, depth=1, context_dim=768).to(device)
    x = torch.randn(n_objects, n_views, 320, h, w, device=device)
    context = [
        # text embeddings
        torch.randn(n_objects, n_views, 77, 768, device=device),
        # epipolar mask over all tokens of all views of each object
        torch.ones(n_objects, n_tokens, n_tokens, dtype=torch.bool, device=device),
        # plucker coordinates
        torch.randn(n_objects, n_views, 6, h, w, device=device),
    ]
    with torch.no_grad():
        x_post = spt_post(x, context=context)
    assert x_post.shape == x.shape, (tuple(x_post.shape), tuple(x.shape))