# This source code is written based on https://github.com/facebookresearch/MCC
# The original code base is licensed under the license found in the LICENSE file in the root directory.
import torch
import torch.nn as nn
import torchvision
from functools import partial
from timm.models.vision_transformer import Block
from utils.pos_embed import get_2d_sincos_pos_embed
from utils.layers import Bottleneck_Conv


class CoordEmb(nn.Module):
    """
    Encode the seen coordinate map into a lower-resolution feature map.
    Achieved with a window-wise attention block by dividing the coordinate map into windows.
    Each window is separately encoded into a single CLS token with self-attention and positional encoding.
    """
    def __init__(self, embed_dim, win_size=8, num_heads=8):
        super().__init__()
        self.embed_dim = embed_dim
        self.win_size = win_size
        self.two_d_pos_embed = nn.Parameter(
            torch.zeros(1, self.win_size*self.win_size + 1, embed_dim), requires_grad=False)  # fixed sin-cos embedding
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Linear(3, embed_dim)
        self.blocks = nn.ModuleList([
            # each block is a residual block with layernorm -> attention -> layernorm -> mlp
            Block(embed_dim, num_heads=num_heads, mlp_ratio=2.0, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6))
            for _ in range(1)
        ])
        self.invalid_coord_token = nn.Parameter(torch.zeros(embed_dim,))

        self.initialize_weights()
    def initialize_weights(self):
        torch.nn.init.normal_(self.cls_token, std=.02)
        two_d_pos_embed = get_2d_sincos_pos_embed(self.two_d_pos_embed.shape[-1], self.win_size, cls_token=True)
        self.two_d_pos_embed.data.copy_(torch.from_numpy(two_d_pos_embed).float().unsqueeze(0))
        torch.nn.init.normal_(self.invalid_coord_token, std=.02)
    def forward(self, coord_obj, mask_obj):
        # [B, H, W, C]
        emb = self.pos_embed(coord_obj)
        # zero out invalid coordinates and replace them with a learned token
        emb[~mask_obj] = 0.0
        emb[~mask_obj] += self.invalid_coord_token

        B, H, W, C = emb.shape
        # [B, H/ws, ws, W/ws, ws, C]
        emb = emb.view(B, H // self.win_size, self.win_size, W // self.win_size, self.win_size, C)
        # [B * H/ws * W/ws, ws*ws, C]
        emb = emb.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, self.win_size * self.win_size, C)
        # [B * H/ws * W/ws, ws*ws, C], add posenc that is local to each window
        emb = emb + self.two_d_pos_embed[:, 1:, :]
        # [1, 1, C]
        cls_token = self.cls_token + self.two_d_pos_embed[:, :1, :]
        # [B * H/ws * W/ws, 1, C]
        cls_tokens = cls_token.expand(emb.shape[0], -1, -1)
        # [B * H/ws * W/ws, ws*ws+1, C]
        emb = torch.cat((cls_tokens, emb), dim=1)

        # transformer (single block) that handles each window separately
        # reasoning is done within each window
        for _, blk in enumerate(self.blocks):
            emb = blk(emb)
        # return the cls token of each window, [B, H/ws*W/ws, C]
        return emb[:, 0].view(B, (H // self.win_size) * (W // self.win_size), -1)
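

# Hedged usage sketch, not part of the original MCC-derived code: it only illustrates the tensor
# shapes CoordEmb expects and produces, assuming a boolean validity mask of shape [B, H, W] and
# H, W divisible by win_size. The function name and the small sizes are hypothetical.
def _coord_emb_shape_example():
    coord = torch.rand(2, 64, 64, 3)                  # seen XYZ coordinate map, [B, H, W, 3]
    mask = torch.ones(2, 64, 64, dtype=torch.bool)    # per-pixel validity mask, [B, H, W]
    emb = CoordEmb(embed_dim=64, win_size=8, num_heads=4)(coord, mask)
    # one CLS token per 8x8 window: [B, (64/8)*(64/8), C] = [2, 64, 64]
    assert emb.shape == (2, (64 // 8) * (64 // 8), 64)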


class CoordEncAtt(nn.Module):
    """
    Seen surface encoder based on transformer.
    """
    def __init__(self,
                 embed_dim=768, n_blocks=12, num_heads=12, win_size=8,
                 mlp_ratio=4., norm_layer=partial(nn.LayerNorm, eps=1e-6), drop_path=0.1):
        super().__init__()
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.coord_embed = CoordEmb(embed_dim, win_size, num_heads)
        self.blocks = nn.ModuleList([
            Block(
                embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer,
                drop_path=drop_path
            ) for _ in range(n_blocks)])
        self.norm = norm_layer(embed_dim)

        self.initialize_weights()

    def initialize_weights(self):
        # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
        torch.nn.init.normal_(self.cls_token, std=.02)
        # initialize nn.Linear and nn.LayerNorm
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            # we use xavier_uniform following official JAX ViT:
            torch.nn.init.xavier_uniform_(m.weight)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
    def forward(self, coord_obj, mask_obj):
        # [B, H/ws*W/ws, C]
        coord_embedding = self.coord_embed(coord_obj, mask_obj)

        # append cls token
        # [1, 1, C]
        cls_token = self.cls_token
        # [B, 1, C]
        cls_tokens = cls_token.expand(coord_embedding.shape[0], -1, -1)
        # [B, H/ws*W/ws+1, C]
        coord_embedding = torch.cat((cls_tokens, coord_embedding), dim=1)

        # apply Transformer blocks
        for blk in self.blocks:
            coord_embedding = blk(coord_embedding)
        coord_embedding = self.norm(coord_embedding)

        # [B, H/ws*W/ws+1, C]
        return coord_embedding
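

# Hedged usage sketch, not part of the original code: CoordEncAtt prepends a global CLS token to
# the window tokens from CoordEmb and runs full self-attention over them. The small sizes below
# (embed_dim=64, n_blocks=2) and the function name are hypothetical, chosen only to keep the
# illustration light.
def _coord_enc_att_shape_example():
    enc = CoordEncAtt(embed_dim=64, n_blocks=2, num_heads=4, win_size=8)
    coord = torch.rand(2, 64, 64, 3)                  # seen XYZ coordinate map, [B, H, W, 3]
    mask = torch.ones(2, 64, 64, dtype=torch.bool)    # per-pixel validity mask, [B, H, W]
    tokens = enc(coord, mask)
    # one CLS token + one token per 8x8 window: [B, 1 + (64/8)*(64/8), C] = [2, 65, 64]
    assert tokens.shape == (2, 1 + (64 // 8) * (64 // 8), 64)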


class CoordEncRes(nn.Module):
    """
    Seen surface encoder based on resnet.
    """
    def __init__(self, opt):
        super().__init__()
        self.encoder = torchvision.models.resnet50(pretrained=True)
        self.encoder.fc = nn.Sequential(
            Bottleneck_Conv(2048),
            Bottleneck_Conv(2048),
            nn.Linear(2048, opt.arch.latent_dim)
        )

        # define hooks
        self.seen_feature = None
        def feature_hook(model, input, output):
            self.seen_feature = output

        # attach hooks
        assert opt.arch.depth.dsp == 1
        if (opt.arch.win_size) == 16:
            self.encoder.layer3.register_forward_hook(feature_hook)
            self.depth_feat_proj = nn.Sequential(
                Bottleneck_Conv(1024),
                Bottleneck_Conv(1024),
                nn.Conv2d(1024, opt.arch.latent_dim, 1)
            )
        elif (opt.arch.win_size) == 32:
            self.encoder.layer4.register_forward_hook(feature_hook)
            self.depth_feat_proj = nn.Sequential(
                Bottleneck_Conv(2048),
                Bottleneck_Conv(2048),
                nn.Conv2d(2048, opt.arch.latent_dim, 1)
            )
        else:
            raise NotImplementedError('Make sure win_size is 16 or 32 when using resnet backbone!')
    def forward(self, coord_obj, mask_obj):
        batch_size = coord_obj.shape[0]
        assert len(coord_obj.shape) == len(mask_obj.shape) == 4
        mask_obj = mask_obj.float()
        # zero out invalid (unseen) coordinates before feeding the resnet
        coord_obj = coord_obj * mask_obj

        # [B, 1, C]
        global_feat = self.encoder(coord_obj).unsqueeze(1)
        # [B, C, H/ws*W/ws]
        local_feat = self.depth_feat_proj(self.seen_feature).view(batch_size, global_feat.shape[-1], -1)
        # [B, H/ws*W/ws, C]
        local_feat = local_feat.permute(0, 2, 1).contiguous()
        # [B, 1+H/ws*W/ws, C]
        seen_embedding = torch.cat([global_feat, local_feat], dim=1)
        return seen_embedding
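

# Hedged usage sketch, not part of the original code: shows how CoordEncRes is driven by the
# project's config object. The `opt` namespace below is a hypothetical stand-in carrying only the
# fields this class reads (arch.latent_dim, arch.win_size, arch.depth.dsp); the real config and
# the exact behavior of Bottleneck_Conv come from the rest of the repository.
def _coord_enc_res_shape_example():
    from types import SimpleNamespace
    opt = SimpleNamespace(arch=SimpleNamespace(latent_dim=768, win_size=32,
                                               depth=SimpleNamespace(dsp=1)))
    enc = CoordEncRes(opt)
    coord = torch.rand(2, 3, 224, 224)   # channel-first coordinate map for the resnet, [B, 3, H, W]
    mask = torch.ones(2, 1, 224, 224)    # validity mask, [B, 1, H, W]
    tokens = enc(coord, mask)
    # expected: one global token + one token per 32x32 window, [2, 1 + (224/32)*(224/32), 768]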