import numpy as np import torch import torch.nn as nn from functools import partial from utils.layers import get_embedder from utils.layers import LayerScale from timm.models.vision_transformer import Mlp, DropPath from utils.pos_embed import get_2d_sincos_pos_embed class ImplFuncAttention(nn.Module): def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0., last_layer=False): super().__init__() assert dim % num_heads == 0, 'dim should be divisible by num_heads' self.num_heads = num_heads head_dim = dim // num_heads self.scale = head_dim ** -0.5 self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) self.attn_drop = nn.Dropout(attn_drop) self.proj = nn.Linear(dim, dim) self.proj_drop = nn.Dropout(proj_drop) self.last_layer = last_layer def forward(self, x, N_points): B, N, C = x.shape N_latent = N - N_points # [3, B, num_heads, N, C/num_heads] qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) # [B, num_heads, N, C/num_heads] q, k, v = qkv.unbind(0) # [B, num_heads, N_latent, C/num_heads] q_latent, k_latent, v_latent = q[:, :, :-N_points], k[:, :, :-N_points], v[:, :, :-N_points] # [B, num_heads, N_points, C/num_heads] q_points, k_points, v_points = q[:, :, -N_points:], k[:, :, -N_points:], v[:, :, -N_points:] # attention weight for each point, it's only connected to the latent and itself # [B, num_heads, N_points, N_latent+1] # get the cross attention, [B, num_heads, N_points, N_latent] attn_cross = (q_points @ k_latent.transpose(-2, -1)) * self.scale # get the attention to self feature, [B, num_heads, N_points, 1] attn_self = torch.sum(q_points * k_points, dim=-1, keepdim=True) * self.scale # get the normalized attention, [B, num_heads, N_points, N_latent+1] attn_joint = torch.cat([attn_cross, attn_self], dim=-1) attn_joint = attn_joint.softmax(dim=-1) attn_joint = self.attn_drop(attn_joint) # break it down to weigh and sum the values # [B, num_heads, N_points, N_latent] @ [B, num_heads, N_latent, C/num_heads] # -> [B, num_heads, N_points, C/num_heads] -> [B, N_points, C] sum_cross = (attn_joint[:, :, :, :N_latent] @ v_latent).transpose(1, 2).reshape(B, N_points, C) # [B, num_heads, N_points, 1] * [B, num_heads, N_points, C/num_heads] # -> [B, num_heads, N_points, C/num_heads] -> [B, N_points, C] sum_self = (attn_joint[:, :, :, N_latent:] * v_points).transpose(1, 2).reshape(B, N_points, C) # [B, N_points, C] output_points = sum_cross + sum_self if self.last_layer: output = self.proj(output_points) output = self.proj_drop(output) # [B, N_points, C], [B, N_points, N_latent] return output, attn_joint[..., :-1].mean(dim=1) # attention weight for the latent vec, it's not connected to the points # [B, num_heads, N_latent, N_latent] attn_latent = (q_latent @ k_latent.transpose(-2, -1)) * self.scale attn_latent = attn_latent.softmax(dim=-1) attn_latent = self.attn_drop(attn_latent) # get the output latent, [B, N_latent, C] output_latent = (attn_latent @ v_latent).transpose(1, 2).reshape(B, N_latent, C) # concatenate the output and return output = torch.cat([output_latent, output_points], dim=1) output = self.proj(output) output = self.proj_drop(output) # [B, N, C], [B, N_points, N_latent+1] return output, attn_joint[..., :-1].mean(dim=1) class ImplFuncBlock(nn.Module): def __init__( self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., init_values=None, drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, last_layer=False): super().__init__() self.last_layer = last_layer self.norm1 = norm_layer(dim) self.attn = ImplFuncAttention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop, last_layer=last_layer) self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() self.norm2 = norm_layer(dim) mlp_hidden_dim = int(dim * mlp_ratio) self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity() def forward(self, x, unseen_size): if self.last_layer: attn_out, attn_vis = self.attn(self.norm1(x), unseen_size) output = x[:, -unseen_size:] + self.drop_path1(self.ls1(attn_out)) output = output + self.drop_path2(self.ls2(self.mlp(self.norm2(output)))) return output, attn_vis else: attn_out, attn_vis = self.attn(self.norm1(x), unseen_size) x = x + self.drop_path1(self.ls1(attn_out)) x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x)))) return x, attn_vis class LinearProj3D(nn.Module): """ Linear projection of 3D point into embedding space """ def __init__(self, embed_dim, posenc_res=0): super().__init__() self.embed_dim = embed_dim # define positional embedder self.embed_fn = None input_ch = 3 if posenc_res > 0: self.embed_fn, input_ch = get_embedder(posenc_res, input_dims=3) # linear proj layer self.proj = nn.Linear(input_ch, embed_dim) def forward(self, points_3D): if self.embed_fn is not None: points_3D = self.embed_fn(points_3D) return self.proj(points_3D) class MLPBlocks(nn.Module): def __init__(self, num_hidden_layers, n_channels, latent_dim, skip_in=[], posenc_res=0): super().__init__() # projection to the same number of channels self.dims = [3 + latent_dim] + [n_channels] * num_hidden_layers + [1] self.num_layers = len(self.dims) self.skip_in = skip_in # define positional embedder self.embed_fn = None if posenc_res > 0: embed_fn, input_ch = get_embedder(posenc_res, input_dims=3) self.embed_fn = embed_fn self.dims[0] += (input_ch - 3) self.layers = nn.ModuleList([]) for l in range(0, self.num_layers - 1): out_dim = self.dims[l + 1] if l in self.skip_in: in_dim = self.dims[l] + self.dims[0] else: in_dim = self.dims[l] lin = nn.Linear(in_dim, out_dim) self.layers.append(lin) # register for param init self.posenc_res = posenc_res # activation self.softplus = nn.Softplus(beta=100) def forward(self, points, proj_latent): # positional encoding if self.embed_fn is not None: points = self.embed_fn(points) # forward by layer # [B, N, posenc+C] inputs = torch.cat([points, proj_latent], dim=-1) x = inputs for l in range(0, self.num_layers - 1): if l in self.skip_in: x = torch.cat([x, inputs], -1) / np.sqrt(2) x = self.layers[l](x) if l < self.num_layers - 2: x = self.softplus(x) return x class Implicit(nn.Module): """ Implicit function conditioned on depth encodings """ def __init__(self, num_patches, latent_dim=768, semantic=False, n_channels=512, n_blocks_attn=2, n_layers_mlp=6, num_heads=16, posenc_3D=0, mlp_ratio=4., norm_layer=partial(nn.LayerNorm, eps=1e-6), drop_path=0.1, skip_in=[], pos_perlayer=True): super().__init__() self.num_patches = num_patches self.pos_perlayer = pos_perlayer self.semantic = semantic # projection to the same number of channels, no posenc self.point_proj = LinearProj3D(n_channels) self.latent_proj = nn.Linear(latent_dim, n_channels, bias=True) # positional embedding for the depth latent codes self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, n_channels), requires_grad=False) # fixed sin-cos embedding # multi-head attention blocks self.blocks_attn = nn.ModuleList([ ImplFuncBlock( n_channels, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer, drop_path=drop_path ) for _ in range(n_blocks_attn-1)]) self.blocks_attn.append( ImplFuncBlock( n_channels, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer, drop_path=drop_path, last_layer=True ) ) self.norm = norm_layer(n_channels) self.impl_mlp = None # define the impl MLP if n_layers_mlp > 0: self.impl_mlp = MLPBlocks(n_layers_mlp, n_channels, n_channels, skip_in=skip_in, posenc_res=posenc_3D) else: # occ and color prediction self.pred_head = nn.Linear(n_channels, 1, bias=True) self.initialize_weights() def initialize_weights(self): # initialize the positional embedding for the depth latent codes pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.num_patches**.5), cls_token=True) self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) # initialize nn.Linear and nn.LayerNorm self.apply(self._init_weights) def _init_weights(self, m): if isinstance(m, nn.Linear): # we use xavier_uniform following official JAX ViT: torch.nn.init.xavier_uniform_(m.weight) if isinstance(m, nn.Linear) and m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.LayerNorm): nn.init.constant_(m.bias, 0) nn.init.constant_(m.weight, 1.0) def forward(self, latent_depth, latent_semantic, points_3D): # concatenate latent codes if semantic is used latent = torch.cat([latent_depth, latent_semantic], dim=-1) if self.semantic else latent_depth # project latent code and add posenc # [B, 1+n_patches, C] latent = self.latent_proj(latent) N_latent = latent.shape[1] # project query points # [B, n_points, C_dec] points_feat = self.point_proj(points_3D) # concat point feat with latent # [B, 1+n_patches+n_points, C_dec] output = torch.cat([latent, points_feat], dim=1) # apply multi-head attention blocks attn_vis = [] for l, blk in enumerate(self.blocks_attn): if self.pos_perlayer or l == 0: output[:, :N_latent] = output[:, :N_latent] + self.pos_embed output, attn = blk(output, points_feat.shape[1]) attn_vis.append(attn) output = self.norm(output) # average of attention weights across layers, [B, N_points, N_latent+1] attn_vis = torch.stack(attn_vis, dim=-1).mean(dim=-1) if self.impl_mlp: # apply mlp blocks output = self.impl_mlp(points_3D, output) else: # predictor projection # [B, n_points, 1] output = self.pred_head(output) # return the occ logit of shape [B, n_points] and the attention weights if needed return output.squeeze(-1), attn_vis