# Source file: ZeroShape/model/shape/implicit.py
# Web file-viewer residue kept for provenance:
#   "zxhuang1698's picture" | "initial commit" | 414b431 | "raw" | "history blame" | 12 kB
import numpy as np
import torch
import torch.nn as nn
from functools import partial
from utils.layers import get_embedder
from utils.layers import LayerScale
from timm.models.vision_transformer import Mlp, DropPath
from utils.pos_embed import get_2d_sincos_pos_embed
class ImplFuncAttention(nn.Module):
    """Multi-head attention over ``[latent tokens | point tokens]``.

    The input sequence holds the latent tokens first and the query-point
    tokens last. The connectivity is restricted:

    * every point token attends to all latent tokens and to itself only
      (never to other points);
    * every latent token attends to latent tokens only.

    Args:
        dim: channel size ``C`` of the tokens; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
        qkv_bias: whether the fused q/k/v projection has a bias.
        attn_drop: dropout probability applied to the attention weights.
        proj_drop: dropout probability applied after the output projection.
        last_layer: if True, only the point tokens are projected and returned.
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0., last_layer=False):
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.last_layer = last_layer

    def forward(self, x, N_points):
        """Apply the restricted attention.

        Args:
            x: tensor of shape ``[B, N, C]`` with the ``N_points`` point
                tokens stored in the last positions along dim 1.
            N_points: number of trailing point tokens, ``0 <= N_points <= N``.

        Returns:
            A tuple ``(output, attn)``. ``output`` is ``[B, N_points, C]``
            when ``last_layer`` is set and ``[B, N, C]`` otherwise. ``attn``
            is the point-to-latent attention (self column dropped, taken
            after attention dropout) averaged over heads,
            ``[B, N_points, N_latent]``.

        Raises:
            ValueError: if ``N_points`` is outside ``[0, N]``.
        """
        B, N, C = x.shape
        if not 0 <= N_points <= N:
            raise ValueError(f'N_points must be in [0, {N}], got {N_points}')
        N_latent = N - N_points

        # [3, B, num_heads, N, C/num_heads]
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        # [B, num_heads, N, C/num_heads]
        q, k, v = qkv.unbind(0)

        # Split on N_latent rather than on -N_points: a negative-zero slice
        # (":-0" / "-0:") would swap the two groups when there are no points.
        # [B, num_heads, N_latent, C/num_heads]
        q_latent, k_latent, v_latent = q[:, :, :N_latent], k[:, :, :N_latent], v[:, :, :N_latent]
        # [B, num_heads, N_points, C/num_heads]
        q_points, k_points, v_points = q[:, :, N_latent:], k[:, :, N_latent:], v[:, :, N_latent:]

        # attention weight for each point, it's only connected to the latent and itself
        # cross attention, [B, num_heads, N_points, N_latent]
        attn_cross = (q_points @ k_latent.transpose(-2, -1)) * self.scale
        # attention to the point's own feature, [B, num_heads, N_points, 1]
        attn_self = torch.sum(q_points * k_points, dim=-1, keepdim=True) * self.scale
        # jointly normalized attention, [B, num_heads, N_points, N_latent+1]
        attn_joint = torch.cat([attn_cross, attn_self], dim=-1)
        attn_joint = attn_joint.softmax(dim=-1)
        attn_joint = self.attn_drop(attn_joint)

        # break it down to weigh and sum the values
        # [B, num_heads, N_points, N_latent] @ [B, num_heads, N_latent, C/num_heads]
        # -> [B, num_heads, N_points, C/num_heads] -> [B, N_points, C]
        sum_cross = (attn_joint[:, :, :, :N_latent] @ v_latent).transpose(1, 2).reshape(B, N_points, C)
        # [B, num_heads, N_points, 1] * [B, num_heads, N_points, C/num_heads]
        # -> [B, num_heads, N_points, C/num_heads] -> [B, N_points, C]
        sum_self = (attn_joint[:, :, :, N_latent:] * v_points).transpose(1, 2).reshape(B, N_points, C)
        # [B, N_points, C]
        output_points = sum_cross + sum_self

        # head-averaged point-to-latent weights, [B, N_points, N_latent]
        attn_vis = attn_joint[..., :N_latent].mean(dim=1)

        if self.last_layer:
            # the latent tokens are not needed after the last layer
            output = self.proj(output_points)
            output = self.proj_drop(output)
            # [B, N_points, C], [B, N_points, N_latent]
            return output, attn_vis

        # attention weight for the latent vec, it's not connected to the points
        # [B, num_heads, N_latent, N_latent]
        attn_latent = (q_latent @ k_latent.transpose(-2, -1)) * self.scale
        attn_latent = attn_latent.softmax(dim=-1)
        attn_latent = self.attn_drop(attn_latent)
        # output latent, [B, N_latent, C]
        output_latent = (attn_latent @ v_latent).transpose(1, 2).reshape(B, N_latent, C)

        # concatenate in the same [latent | points] order as the input
        output = torch.cat([output_latent, output_points], dim=1)
        output = self.proj(output)
        output = self.proj_drop(output)
        # [B, N, C], [B, N_points, N_latent]
        return output, attn_vis
class ImplFuncBlock(nn.Module):
    """Pre-norm residual block: ``ImplFuncAttention`` followed by an MLP.

    Each branch output passes through an optional ``LayerScale`` (enabled
    when ``init_values`` is truthy) and an optional ``DropPath`` (enabled
    when ``drop_path > 0``) before being added to the residual stream.
    With ``last_layer=True`` only the trailing ``unseen_size`` tokens are
    kept as the residual and returned.
    """

    def __init__(
            self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., init_values=None,
            drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, last_layer=False):
        super().__init__()
        self.last_layer = last_layer

        def _branch_wrappers():
            # (layer scale, stochastic depth) pair applied to one residual branch
            scale = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
            stochastic_depth = DropPath(drop_path) if drop_path > 0. else nn.Identity()
            return scale, stochastic_depth

        # attention branch
        self.norm1 = norm_layer(dim)
        self.attn = ImplFuncAttention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
            last_layer=last_layer,
        )
        self.ls1, self.drop_path1 = _branch_wrappers()

        # feed-forward branch
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=drop,
        )
        self.ls2, self.drop_path2 = _branch_wrappers()

    def forward(self, x, unseen_size):
        """Run the block.

        Args:
            x: token tensor whose last ``unseen_size`` entries along dim 1
                are the query-point tokens.
            unseen_size: number of trailing point tokens.

        Returns:
            ``(tokens, attn_vis)`` where ``attn_vis`` is whatever the
            attention module returns as its second output.
        """
        attn_out, attn_vis = self.attn(self.norm1(x), unseen_size)
        # the last layer's attention only returns the point tokens, so the
        # residual has to be cropped to the same trailing slice
        shortcut = x[:, -unseen_size:] if self.last_layer else x
        hidden = shortcut + self.drop_path1(self.ls1(attn_out))
        hidden = hidden + self.drop_path2(self.ls2(self.mlp(self.norm2(hidden))))
        return hidden, attn_vis
class LinearProj3D(nn.Module):
    """
    Embed 3D points with a single linear layer.

    When ``posenc_res > 0`` the points are first passed through the
    embedder returned by ``get_embedder`` and the linear layer consumes the
    embedder's output width; otherwise it consumes the raw 3 coordinates.
    """

    def __init__(self, embed_dim, posenc_res=0):
        super().__init__()
        self.embed_dim = embed_dim

        # optional positional embedder and the channel count it produces
        if posenc_res > 0:
            self.embed_fn, in_channels = get_embedder(posenc_res, input_dims=3)
        else:
            self.embed_fn, in_channels = None, 3

        # linear proj layer
        self.proj = nn.Linear(in_channels, embed_dim)

    def forward(self, points_3D):
        """Return ``proj(embed_fn(points_3D))``, skipping ``embed_fn`` when it is None."""
        encoded = points_3D if self.embed_fn is None else self.embed_fn(points_3D)
        return self.proj(encoded)
class MLPBlocks(nn.Module):
    """MLP that maps (3D point, per-point latent) to a single scalar.

    The layer widths are ``[3 + latent_dim] + [n_channels] * num_hidden_layers + [1]``
    (the first width grows accordingly when a positional embedder is used).
    Layers whose index is in ``skip_in`` additionally receive the network
    input concatenated to their input.

    Args:
        num_hidden_layers: number of hidden layers of width ``n_channels``.
        n_channels: hidden width.
        latent_dim: channel size of the per-point latent fed to ``forward``.
        skip_in: iterable of layer indices that get a skip connection from
            the input; ``None`` means no skip connections. A private copy is
            stored so instances never share (or mutate) a default list.
        posenc_res: if > 0, points are encoded with ``get_embedder`` first.
    """

    def __init__(self, num_hidden_layers, n_channels, latent_dim,
                 skip_in=None, posenc_res=0):
        super().__init__()

        # projection to the same number of channels
        self.dims = [3 + latent_dim] + [n_channels] * num_hidden_layers + [1]
        self.num_layers = len(self.dims)
        self.skip_in = [] if skip_in is None else list(skip_in)

        # define positional embedder
        self.embed_fn = None
        if posenc_res > 0:
            embed_fn, input_ch = get_embedder(posenc_res, input_dims=3)
            self.embed_fn = embed_fn
            # the raw xyz (3 channels) is replaced by the embedder output
            self.dims[0] += (input_ch - 3)

        self.layers = nn.ModuleList([])
        for l in range(self.num_layers - 1):
            in_dim = self.dims[l]
            if l in self.skip_in:
                # skip layers see their regular input plus the network input
                in_dim += self.dims[0]
            self.layers.append(nn.Linear(in_dim, self.dims[l + 1]))

        # register for param init
        self.posenc_res = posenc_res

        # activation
        self.softplus = nn.Softplus(beta=100)

    def forward(self, points, proj_latent):
        """Evaluate the MLP.

        Args:
            points: point coordinates; concatenated with ``proj_latent``
                along the last dim (after optional positional encoding).
            proj_latent: per-point latent features with matching leading dims.

        Returns:
            Output of the final linear layer (last dim of size 1); no
            activation is applied after it.
        """
        # positional encoding
        if self.embed_fn is not None:
            points = self.embed_fn(points)

        # [B, N, posenc+C]
        inputs = torch.cat([points, proj_latent], dim=-1)

        x = inputs
        final_idx = self.num_layers - 2
        for l, layer in enumerate(self.layers):
            if l in self.skip_in:
                # rescale so the concatenated input keeps a comparable magnitude
                x = torch.cat([x, inputs], -1) / np.sqrt(2)
            x = layer(x)
            if l < final_idx:
                x = self.softplus(x)
        return x
class Implicit(nn.Module):
    """
    Implicit function conditioned on depth encodings.

    Latent tokens are projected to ``n_channels``, concatenated with linearly
    projected query points, and passed through ``n_blocks_attn`` attention
    blocks (the last one returns point tokens only). The point features are
    then turned into one logit per point, either by ``MLPBlocks``
    (``n_layers_mlp > 0``) or by a single linear head.

    Args:
        num_patches: number of patch tokens; must be a perfect square because
            the fixed 2D sin-cos embedding is built on a square grid. The
            latent sequence has ``num_patches + 1`` tokens.
        latent_dim: channel size of the latent fed to ``latent_proj`` (of the
            concatenated latent when ``semantic`` is True).
        semantic: if True, ``latent_semantic`` is concatenated to
            ``latent_depth`` along the channel dim.
        n_channels: working channel size.
        n_blocks_attn: number of attention blocks.
        n_layers_mlp: hidden layers of the implicit MLP; ``<= 0`` selects the
            linear prediction head instead.
        num_heads: attention heads per block.
        posenc_3D: positional-encoding resolution forwarded to ``MLPBlocks``.
        mlp_ratio: hidden-width ratio of the block MLPs.
        norm_layer: normalization layer factory.
        drop_path: stochastic-depth rate of the blocks.
        skip_in: layer indices of ``MLPBlocks`` receiving a skip connection;
            ``None`` means none.
        pos_perlayer: add the positional embedding before every block
            instead of only before the first one.
    """

    def __init__(self,
                 num_patches, latent_dim=768, semantic=False, n_channels=512,
                 n_blocks_attn=2, n_layers_mlp=6, num_heads=16, posenc_3D=0,
                 mlp_ratio=4., norm_layer=partial(nn.LayerNorm, eps=1e-6), drop_path=0.1,
                 skip_in=None, pos_perlayer=True):
        super().__init__()
        self.num_patches = num_patches
        self.pos_perlayer = pos_perlayer
        self.semantic = semantic
        # fresh list per instance; a shared mutable default would leak
        # between models through the MLP that stores it
        skip_in = [] if skip_in is None else skip_in

        # projection to the same number of channels, no posenc
        self.point_proj = LinearProj3D(n_channels)
        self.latent_proj = nn.Linear(latent_dim, n_channels, bias=True)

        # positional embedding for the depth latent codes
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, n_channels), requires_grad=False)  # fixed sin-cos embedding

        # multi-head attention blocks; only the final one is a "last layer"
        self.blocks_attn = nn.ModuleList([
            ImplFuncBlock(
                n_channels, num_heads, mlp_ratio,
                qkv_bias=True, norm_layer=norm_layer, drop_path=drop_path
            ) for _ in range(n_blocks_attn - 1)])
        self.blocks_attn.append(
            ImplFuncBlock(
                n_channels, num_heads, mlp_ratio,
                qkv_bias=True, norm_layer=norm_layer, drop_path=drop_path, last_layer=True
            )
        )
        self.norm = norm_layer(n_channels)

        self.impl_mlp = None
        # define the impl MLP
        if n_layers_mlp > 0:
            self.impl_mlp = MLPBlocks(n_layers_mlp, n_channels, n_channels,
                                      skip_in=skip_in, posenc_res=posenc_3D)
        else:
            # occ and color prediction
            self.pred_head = nn.Linear(n_channels, 1, bias=True)

        self.initialize_weights()

    def initialize_weights(self):
        """Fill the fixed positional embedding and initialize Linear/LayerNorm weights.

        Raises:
            ValueError: if ``num_patches`` is not a perfect square.
        """
        # initialize the positional embedding for the depth latent codes
        grid_size = int(round(self.num_patches ** .5))
        if grid_size * grid_size != self.num_patches:
            # fail early with a clear message instead of a shape mismatch in copy_
            raise ValueError(f'num_patches must be a perfect square, got {self.num_patches}')
        pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], grid_size, cls_token=True)
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

        # initialize nn.Linear and nn.LayerNorm
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Xavier-uniform for Linear weights, zero biases, unit LayerNorm weights."""
        if isinstance(m, nn.Linear):
            # we use xavier_uniform following official JAX ViT:
            torch.nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, latent_depth, latent_semantic, points_3D):
        """Predict one logit per query point.

        Args:
            latent_depth: latent tokens; dim 1 must match ``pos_embed``.
            latent_semantic: extra latent, only read when ``semantic`` is True.
            points_3D: query points fed to ``point_proj`` (and to the
                implicit MLP when present).

        Returns:
            ``(logits, attn_vis)``: the prediction with its trailing
            singleton dim squeezed, and the per-block attention maps
            averaged over blocks.
        """
        # concatenate latent codes if semantic is used
        latent = torch.cat([latent_depth, latent_semantic], dim=-1) if self.semantic else latent_depth

        # project latent code
        # [B, 1+n_patches, C]
        latent = self.latent_proj(latent)
        N_latent = latent.shape[1]

        # project query points
        # [B, n_points, C_dec]
        points_feat = self.point_proj(points_3D)
        N_points = points_feat.shape[1]

        # concat point feat with latent
        # [B, 1+n_patches+n_points, C_dec]
        output = torch.cat([latent, points_feat], dim=1)

        # apply multi-head attention blocks
        attn_vis = []
        for l, blk in enumerate(self.blocks_attn):
            if self.pos_perlayer or l == 0:
                # positional embedding goes on the latent tokens only
                output[:, :N_latent] = output[:, :N_latent] + self.pos_embed
            output, attn = blk(output, N_points)
            attn_vis.append(attn)
        output = self.norm(output)

        # average of attention weights across layers, [B, N_points, N_latent]
        attn_vis = torch.stack(attn_vis, dim=-1).mean(dim=-1)

        if self.impl_mlp is not None:
            # apply mlp blocks
            output = self.impl_mlp(points_3D, output)
        else:
            # predictor projection
            # [B, n_points, 1]
            output = self.pred_head(output)

        # return the occ logit of shape [B, n_points] and the attention weights if needed
        return output.squeeze(-1), attn_vis