|
from inspect import isfunction |
|
import math |
|
import torch |
|
import torch.nn.functional as F |
|
from torch import nn, einsum |
|
from einops import rearrange, repeat |
|
|
|
def conv_nd(dims, *args, **kwargs):
    """
    Build a convolution module for 1, 2 or 3 spatial dimensions.

    ``dims`` selects ``nn.Conv1d`` / ``nn.Conv2d`` / ``nn.Conv3d``; every other
    positional and keyword argument is forwarded unchanged to that class.
    Any other ``dims`` value raises ``ValueError``.
    """
    for n_spatial, conv_cls in ((1, nn.Conv1d), (2, nn.Conv2d), (3, nn.Conv3d)):
        if dims == n_spatial:
            return conv_cls(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")
|
|
|
|
|
from .attention import * |
|
|
|
# Probe for the optional xformers dependency. The flag name (including its
# misspelling) is kept unchanged in case other modules import it.
try:
    import xformers
    import xformers.ops

    XFORMERS_IS_AVAILBLE = True
except Exception:
    # Not installed, or a broken/incompatible install that raises at import
    # time. Unlike a bare ``except:``, this does not swallow
    # KeyboardInterrupt / SystemExit.
    XFORMERS_IS_AVAILBLE = False

print(f"XFORMERS_IS_AVAILBLE: {XFORMERS_IS_AVAILBLE}")
|
|
|
|
|
class SPADAttention(nn.Module):
    """Multi-head scaled dot-product attention with an optional boolean mask.

    Used for (epipolar-)masked attention between views. The attention itself
    is a plain matmul/softmax computed in float32 with CUDA autocast disabled;
    xformers is not used here.
    """

    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
        super().__init__()
        inner_dim = dim_head * heads
        # Without an explicit context width, keys/values share the query width.
        context_dim = query_dim if context_dim is None else context_dim

        self.heads = heads
        self.dim_head = dim_head

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
        )
        # Placeholder attribute (e.g. for an xformers op); not read by forward().
        self.attention_op = None

    def _split_heads(self, t):
        """Reshape (b, n, heads*dim_head) -> (b*heads, n, dim_head), batch-major."""
        b, n, _ = t.shape
        return (
            t.reshape(b, n, self.heads, self.dim_head)
            .permute(0, 2, 1, 3)
            .reshape(b * self.heads, n, self.dim_head)
            .contiguous()
        )

    def forward(self, x, context=None, mask=None, views=None):
        """Attend from ``x`` to ``context`` (self-attention when context is None).

        Args:
            x: query tokens of shape (b, n_q, query_dim).
            context: key/value tokens of shape (b, n_k, context_dim); defaults to ``x``.
            mask: optional 3-D boolean tensor (b, Hm, Wm); True means "may attend".
                It is resized with nearest-neighbour interpolation to (n_q, n_k).
            views: unused; accepted for call-site compatibility.

        Returns:
            Tensor of shape (b, n_q, query_dim) in the dtype of the q projection.
        """
        context = x if context is None else context
        q = self.to_q(x)
        k = self.to_k(context)
        v = self.to_v(context)

        b, n_q, _ = q.shape
        n_k = k.shape[1]
        # Remember the projection dtype so the float32 result can be cast back;
        # otherwise a half/bf16 module run without autocast fails in to_out.
        out_dtype = q.dtype

        bias = None
        if mask is not None:
            # (b, Hm, Wm) -> (b, 1, n_q, n_k) via nearest-neighbour resize.
            mask = F.interpolate(
                mask.unsqueeze(1).to(torch.uint8), size=(n_q, n_k)
            ).bool()
            # Repeat per head, batch-major to match _split_heads ordering.
            mask = mask.expand(b, self.heads, n_q, n_k).reshape(
                b * self.heads, n_q, n_k
            )
            # Additive bias: 0 where attention is allowed, -1e9 where masked.
            bias = torch.zeros(
                mask.shape, dtype=torch.float32, device=mask.device
            ).masked_fill_(~mask, -1e9)

        # Compute attention in float32 for numerical stability.
        with torch.autocast(enabled=False, device_type="cuda"):
            q, k, v = (self._split_heads(t).float() for t in (q, k, v))
            scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.dim_head)
            if bias is not None:
                scores = scores + bias
            weights = torch.softmax(scores, dim=-1)
            out = torch.matmul(weights, v)

        # (b*heads, n_q, dim_head) -> (b, n_q, heads*dim_head)
        out = (
            out.reshape(b, self.heads, n_q, self.dim_head)
            .permute(0, 2, 1, 3)
            .reshape(b, n_q, self.heads * self.dim_head)
        )
        return self.to_out(out.to(out_dtype))
|
|
|
|
|
class SPADTransformerBlock(nn.Module):
    """Transformer block with joint multi-view self-attention.

    ``attn1`` self-attends over the tokens of all views at once (optionally
    masked), while ``attn2`` (cross-attention to the text context) and ``ff``
    operate on each view independently. Each sub-layer is pre-normalised and
    residual.
    """

    def __init__(
        self,
        dim,
        n_heads,
        d_head,
        dropout=0.0,
        context_dim=None,
        gated_ff=True,
        checkpoint=True,
        disable_self_attn=False,
    ):
        super().__init__()
        self.disable_self_attn = disable_self_attn
        # attn1 only gets an external context width when self-attention is disabled.
        attn1_context_dim = context_dim if disable_self_attn else None
        self.attn1 = SPADAttention(
            query_dim=dim,
            heads=n_heads,
            dim_head=d_head,
            dropout=dropout,
            context_dim=attn1_context_dim,
        )
        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
        self.attn2 = SPADAttention(
            query_dim=dim,
            context_dim=context_dim,
            heads=n_heads,
            dim_head=d_head,
            dropout=dropout,
        )
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.norm3 = nn.LayerNorm(dim)
        self.checkpoint = checkpoint

    def forward(self, x, context=None, mask=None):
        # `checkpoint` comes from the star-import; presumably the ldm-style
        # gradient-checkpoint helper, enabled by the `self.checkpoint` flag.
        return checkpoint(
            self.manystream_forward,
            (x, context, mask),
            self.parameters(),
            self.checkpoint,
        )

    def manystream_forward(self, x, context=None, mask=None):
        """x: (n, v, hw, c) tokens per view; context: (n, v, seq, d) text tokens."""
        assert not self.disable_self_attn

        num_views = x.shape[1]

        # Joint self-attention over the concatenated tokens of every view.
        joint = rearrange(x, "n v hw c -> n (v hw) c")
        joint = self.attn1(self.norm1(joint), context=None, mask=mask, views=num_views) + joint

        # Per-view cross-attention and feed-forward: fold views into the batch.
        per_view = rearrange(joint, "n (v hw) c -> (n v) hw c", v=num_views)
        text = rearrange(context, "n v seq d -> (n v) seq d")
        per_view = self.attn2(self.norm2(per_view), context=text) + per_view
        per_view = self.ff(self.norm3(per_view)) + per_view

        return rearrange(per_view, "(n v) hw c -> n v hw c", v=num_views)
|
|
|
|
|
class SPADTransformer(nn.Module):
    """Multi-view spatial transformer conditioned on per-pixel Plücker coordinates.

    ``post_init`` (called from ``__init__``) widens ``proj_in`` by six
    zero-initialised input channels, so concatenating Plücker coordinates to the
    normalised features leaves the projection's initial output unchanged.
    """

    # Number of Plücker-coordinate channels concatenated to the features.
    _PLUCKER_CHANNELS = 6

    def __init__(
        self,
        in_channels,
        n_heads,
        d_head,
        depth=1,
        dropout=0.0,
        context_dim=None,
        disable_self_attn=False,
        use_linear=False,
        use_checkpoint=True,
    ):
        super().__init__()
        # One context width per block; a single value (or None) is shared by
        # all blocks. Previously None or depth > 1 raised at construction.
        if not isinstance(context_dim, list):
            context_dim = [context_dim] * depth
        self.in_channels = in_channels
        inner_dim = n_heads * d_head
        self.norm = Normalize(in_channels)
        if use_linear:
            self.proj_in = nn.Linear(in_channels, inner_dim)
        else:
            self.proj_in = nn.Conv2d(
                in_channels, inner_dim, kernel_size=1, stride=1, padding=0
            )

        self.transformer_blocks = nn.ModuleList(
            [
                SPADTransformerBlock(
                    inner_dim,
                    n_heads,
                    d_head,
                    dropout=dropout,
                    context_dim=context_dim[d],
                    disable_self_attn=disable_self_attn,
                    checkpoint=use_checkpoint,
                )
                for d in range(depth)
            ]
        )
        if use_linear:
            # Maps inner_dim back to in_channels (the arguments were previously swapped).
            self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
        else:
            self.proj_out = zero_module(
                nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
            )
        self.use_linear = use_linear

        self.post_init()

    def post_init(self):
        """Widen ``proj_in`` with zero-initialised input channels for Plücker coordinates.

        Supports both the 1x1 ``nn.Conv2d`` and the ``nn.Linear`` projection.
        May only be called once.
        """
        assert getattr(self, "post_intialized", False) is False, "already modified!"

        old = self.proj_in
        old_weight = old.weight.data
        extra = self._PLUCKER_CHANNELS
        factory = {"device": old_weight.device, "dtype": old_weight.dtype}

        if isinstance(old, nn.Linear):
            new = nn.Linear(old.in_features + extra, old.out_features, **factory)
        else:
            new = conv_nd(
                2,
                old.in_channels + extra,
                old.out_channels,
                kernel_size=old.kernel_size,
                stride=old.stride,
                padding=old.padding,
                **factory,
            )

        # The new input channels get zero weights, so the initial output is unchanged.
        zeros = torch.zeros(
            (old_weight.shape[0], extra, *old_weight.shape[2:]), **factory
        )
        new.weight.data.copy_(torch.cat([old_weight, zeros], dim=1))
        new.bias.data.copy_(old.bias.data)

        self.proj_in = new
        self.post_intialized = True  # attribute name (sic) kept for compatibility

    @staticmethod
    def _per_view(fn, t):
        """Apply ``fn`` to ``t`` with its leading (n, v) dims folded into one batch dim."""
        n_views = t.shape[1]
        out = fn(rearrange(t, "n v ... -> (n v) ..."))
        return rearrange(out, "(n v) ... -> n v ...", v=n_views)

    @staticmethod
    def _split_context(block_context):
        """Return ``(context, mask)`` from a ``[text_emb, mask]``-style entry.

        An entry of any length other than 2 yields ``(entry[0], None)``; a
        non-sequence entry is used as the context with no mask.
        """
        if not isinstance(block_context, (list, tuple)):
            return block_context, None
        if len(block_context) == 2:
            return block_context[0], block_context[1]
        return block_context[0], None

    def forward(self, x, context=None):
        return self.spad_forward(x, context=context)

    def spad_forward(self, x, context=None):
        """Run the transformer blocks on multi-view features.

        Args:
            x: tensor of shape [n, v, c, h, w] with c == in_channels.
            context: sequence [text_emb, epipolar_mask, plucker_coords]:
                - text_emb: [n, v, seq_len, dim]; cross-attended per view.
                - epipolar_mask: boolean mask over the tokens of all views,
                  used by each block's self-attention. NOTE(review): the
                  attention layer unsqueezes it once and resizes its last two
                  dims to (v*h*w, v*h*w), so a 3-D [n, L, L] mask is expected.
                  If omitted (context of length 2), attention is unmasked.
                - plucker_coords: [n, v, 6, hp, wp]; bilinearly resized to (h, w).

        Returns:
            Tensor of shape [n, v, c, h, w]: ``x`` plus the transformer residual.
        """
        _, _, _, h, w = x.shape
        x_in = x

        # Every block shares the same (text_emb, mask) pair.
        text_emb, mask = self._split_context(context[:-1])
        plucker = context[-1]

        x = self._per_view(self.norm, x)
        plucker = self._per_view(
            lambda p: F.interpolate(
                p, size=(h, w), align_corners=False, mode="bilinear"
            ),
            plucker,
        )
        x = torch.cat([x, plucker], dim=2)

        if self.use_linear:
            x = rearrange(x, "n v c h w -> n v (h w) c")
            x = self._per_view(self.proj_in, x)
        else:
            x = self._per_view(self.proj_in, x)
            x = rearrange(x, "n v c h w -> n v (h w) c")
        x = x.contiguous()

        for block in self.transformer_blocks:
            x = block(x, context=text_emb, mask=mask)

        if self.use_linear:
            x = self._per_view(self.proj_out, x)
            x = rearrange(x, "n v (h w) c -> n v c h w", h=h, w=w).contiguous()
        else:
            x = rearrange(x, "n v (h w) c -> n v c h w", h=h, w=w).contiguous()
            x = self._per_view(self.proj_out, x)

        return x + x_in
|
|
|
|
|
if __name__ == "__main__":
    # Smoke test: build a single-block SPADTransformer and run one forward pass.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    n_objects, n_views, channels, size = 2, 4, 320, 32
    n_tokens = n_views * size * size

    spt_post = SPADTransformer(channels, 8, 40, depth=1, context_dim=768).to(device)

    x = torch.randn(n_objects, n_views, channels, size, size, device=device)
    context = [
        torch.randn(n_objects, n_views, 77, 768, device=device),
        torch.ones(n_objects, n_tokens, n_tokens, dtype=torch.bool, device=device),
        torch.randn(n_objects, n_views, 6, size, size, device=device),
    ]

    # No gradients needed for a shape check; keeps memory down.
    with torch.no_grad():
        x_post = spt_post(x, context=context)

    assert x_post.shape == x.shape, tuple(x_post.shape)
    print(f"output shape: {tuple(x_post.shape)}")