# This source code is written based on https://github.com/facebookresearch/MCC
# The original code base is licensed under the license found in the LICENSE file in the root directory.

import torch
import torch.nn as nn
import torchvision

from functools import partial

from timm.models.vision_transformer import Block, PatchEmbed

from utils.pos_embed import get_2d_sincos_pos_embed
from utils.layers import Bottleneck_Conv


class RGBEncAtt(nn.Module):
    """
    Seen surface encoder based on transformer.
    """

    def __init__(self, img_size=224, embed_dim=768, n_blocks=12, num_heads=12, win_size=16,
                 mlp_ratio=4., norm_layer=partial(nn.LayerNorm, eps=1e-6), drop_path=0.1):
        super().__init__()

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))

        self.rgb_embed = PatchEmbed(img_size, win_size, 3, embed_dim)

        num_patches = self.rgb_embed.num_patches
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim), requires_grad=False)

        self.blocks = nn.ModuleList([
            Block(
                embed_dim, num_heads, mlp_ratio, qkv_bias=True,
                norm_layer=norm_layer, drop_path=drop_path
            ) for _ in range(n_blocks)])

        self.norm = norm_layer(embed_dim)

        self.initialize_weights()

    def initialize_weights(self):
        # initialize the pos enc with fixed cos-sin pattern
        pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1],
                                            int(self.rgb_embed.num_patches**.5),
                                            cls_token=True)
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

        # initialize rgb patch_embed like nn.Linear (instead of nn.Conv2d)
        w = self.rgb_embed.proj.weight.data
        torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))

        # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
        torch.nn.init.normal_(self.cls_token, std=.02)

        # initialize nn.Linear and nn.LayerNorm
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            # we use xavier_uniform following official JAX ViT:
            torch.nn.init.xavier_uniform_(m.weight)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, rgb_obj):
        # [B, H/ws*W/ws, C]
        rgb_embedding = self.rgb_embed(rgb_obj)
        rgb_embedding = rgb_embedding + self.pos_embed[:, 1:, :]

        # append cls token
        # [1, 1, C]
        cls_token = self.cls_token + self.pos_embed[:, :1, :]
        # [B, 1, C]
        cls_tokens = cls_token.expand(rgb_embedding.shape[0], -1, -1)
        # [B, H/ws*W/ws+1, C]
        rgb_embedding = torch.cat((cls_tokens, rgb_embedding), dim=1)

        # apply Transformer blocks
        for blk in self.blocks:
            rgb_embedding = blk(rgb_embedding)
        rgb_embedding = self.norm(rgb_embedding)

        # [B, H/ws*W/ws+1, C]
        return rgb_embedding


class RGBEncRes(nn.Module):
    """
    RGB encoder based on resnet.
    """

    def __init__(self, opt):
        super().__init__()
        self.encoder = torchvision.models.resnet50(pretrained=True)
        self.encoder.fc = nn.Sequential(
            Bottleneck_Conv(2048),
            Bottleneck_Conv(2048),
            nn.Linear(2048, opt.arch.latent_dim)
        )

        # define hooks
        self.rgb_feature = None

        def feature_hook(model, input, output):
            self.rgb_feature = output

        # attach hooks
        if opt.arch.win_size == 16:
            self.encoder.layer3.register_forward_hook(feature_hook)
            self.rgb_feat_proj = nn.Sequential(
                Bottleneck_Conv(1024),
                Bottleneck_Conv(1024),
                nn.Conv2d(1024, opt.arch.latent_dim, 1)
            )
        elif opt.arch.win_size == 32:
            self.encoder.layer4.register_forward_hook(feature_hook)
            self.rgb_feat_proj = nn.Sequential(
                Bottleneck_Conv(2048),
                Bottleneck_Conv(2048),
                nn.Conv2d(2048, opt.arch.latent_dim, 1)
            )
        else:
            print('Make sure win_size is 16 or 32 when using resnet backbone!')
            raise NotImplementedError

    def forward(self, rgb_obj):
        batch_size = rgb_obj.shape[0]
        assert len(rgb_obj.shape) == 4

        # [B, 1, C]
        global_feat = self.encoder(rgb_obj).unsqueeze(1)
        # [B, C, H/ws*W/ws]
        local_feat = self.rgb_feat_proj(self.rgb_feature).view(batch_size, global_feat.shape[-1], -1)
        # [B, H/ws*W/ws, C]
        local_feat = local_feat.permute(0, 2, 1).contiguous()
        # [B, 1+H/ws*W/ws, C]
        rgb_embedding = torch.cat([global_feat, local_feat], dim=1)

        return rgb_embedding