DIY-SC / model_utils /projection_network.py
odunkel's picture
Upload 9 files
79cc514 verified
import numpy as np
import torch
from torch import nn
from model_utils.resnet import ResNet, BottleneckBlock
import torch.nn.functional as F
class DummyAggregationNetwork(nn.Module):
    """Stand-in aggregation network for testing.

    The forward pass returns the input multiplied by a single scalar
    parameter that is initialised to one.
    """

    def __init__(self):
        super().__init__()
        # Dummy scalar (0-dim) parameter, initialised to 1.
        self.dummy = nn.Parameter(torch.ones([]))

    def forward(self, batch, pose=None):
        """Return ``batch`` scaled by the dummy parameter; ``pose`` is not used."""
        scaled = batch * self.dummy
        return scaled
class AggregationNetwork(nn.Module):
    """
    Module for aggregating feature maps across time and space.
    Design inspired by the Feature Extractor from ODISE (Xu et. al., CVPR 2023).
    https://github.com/NVlabs/ODISE/blob/5836c0adfcd8d7fd1f8016ff5604d4a31dd3b145/odise/modeling/backbone/feature_extractor.py

    Each entry of ``feature_dims`` gets its own bottleneck stage that projects
    the corresponding channel slice of the input to ``projection_dim``
    channels. The projected maps are combined with softmax-normalised
    learnable mixing weights, one weight per (timestep, layer) pair.

    Args:
        device: Device the bottleneck layers and mixing weights are moved to.
        feature_dims: Channel count of every feature map packed into the
            input; one bottleneck stage is built per entry.
        projection_dim: Output channels of every bottleneck stage.
        num_norm_groups: Forwarded to ``ResNet.make_stage``.
        save_timestep: Timesteps contained in the input; the bottleneck
            stages are shared across them.
        kernel_size: Forwarded to ``ResNet.make_stage``.
        contrastive_temp: ``self_logit_scale`` starts at ``log(contrastive_temp)``.
        feat_map_dropout: Dropout probability applied to the input while the
            module is in training mode (0 disables it).
        num_blocks: Blocks per bottleneck stage; ``None`` means 1.
        bottleneck_channels: Inner width of the bottleneck blocks; ``None``
            means ``projection_dim // 4``.
    """

    def __init__(
        self,
        device,
        feature_dims=[640,1280,1280,768],
        projection_dim=384,
        num_norm_groups=32,
        save_timestep=[1],
        kernel_size = [1,3,1],
        contrastive_temp = 10,
        feat_map_dropout=0.0,
        num_blocks=None,
        bottleneck_channels=None
    ):
        super().__init__()
        self.skip_connection = True
        self.feat_map_dropout = feat_map_dropout
        # Not populated here. If ``pos_embedding`` is assigned from outside,
        # ``forward`` concatenates it to the input along the channel axis.
        self.azimuth_embedding = None
        self.pos_embedding = None
        self.bottleneck_layers = nn.ModuleList()
        # Copy the list arguments so that mutating the instance attributes can
        # never alter the shared default lists (or the caller's lists).
        self.feature_dims = list(feature_dims)
        self.num_blocks = num_blocks if num_blocks is not None else 1
        self.bottleneck_channels = (
            bottleneck_channels if bottleneck_channels is not None else projection_dim // 4
        )
        # For CLIP symmetric cross entropy loss during training
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
        self.self_logit_scale = nn.Parameter(torch.ones([]) * np.log(contrastive_temp))
        self.device = device
        self.save_timestep = list(save_timestep)

        for feature_dim in self.feature_dims:
            bottleneck_layer = nn.Sequential(
                *ResNet.make_stage(
                    BottleneckBlock,
                    num_blocks=self.num_blocks,
                    in_channels=feature_dim,
                    bottleneck_channels=self.bottleneck_channels,
                    out_channels=projection_dim,
                    norm="GN",
                    num_norm_groups=num_norm_groups,
                    kernel_size=kernel_size
                )
            )
            self.bottleneck_layers.append(bottleneck_layer)

        # One name per mixing weight, in the order ``forward`` indexes them:
        # weight i belongs to layer ``i % len(feature_dims)`` of timestep
        # ``i // len(feature_dims)``. Layers are 1-indexed following prior work.
        self.mixing_weights_names = [
            f"timestep-{t}_layer-{layer_idx + 1}"
            for t in self.save_timestep
            for layer_idx in range(len(self.feature_dims))
        ]

        # Optional head applied after aggregation; left unset here.
        self.last_layer = None
        self.bottleneck_layers = self.bottleneck_layers.to(device)
        mixing_weights = torch.ones(len(self.bottleneck_layers) * len(self.save_timestep))
        self.mixing_weights = nn.Parameter(mixing_weights.to(device))

        # count number of parameters
        num_params = sum(param.numel() for param in self.parameters())
        print(f"AggregationNetwork has {num_params} parameters.")

    def load_pretrained_weights(self, pretrained_dict):
        """Load a checkpoint whose number of mixing weights may differ.

        Entries of ``pretrained_dict`` whose key exists in this module are
        loaded as usual. ``mixing_weights`` is handled separately: the leading
        entries shared by both tensors are copied from the checkpoint and any
        additional entries of this module are set to zero. A checkpoint
        without ``mixing_weights`` leaves the current mixing weights untouched.

        Args:
            pretrained_dict: Mapping from parameter name to tensor.
        """
        custom_dict = self.state_dict()

        own_mix = custom_dict.get('mixing_weights')
        pretrained_mix = pretrained_dict.get('mixing_weights')
        if own_mix is not None and pretrained_mix is not None:
            source = pretrained_mix.detach().reshape(-1).to(own_mix)
            n_shared = min(own_mix.numel(), source.numel())
            with torch.no_grad():
                # Keep the overlapping weights from the pretrained model and
                # zero-initialise the entries that only exist in this model.
                own_mix[:n_shared].copy_(source[:n_shared])
                own_mix[n_shared:].zero_()

        # Load the weights that do match
        matching_keys = {
            k: v for k, v in pretrained_dict.items()
            if k in custom_dict and k != 'mixing_weights'
        }
        custom_dict.update(matching_keys)

        # Now load the updated state_dict
        self.load_state_dict(custom_dict, strict=False)

    def forward(self, batch, pose=None):
        """
        Assumes batch is shape (B, C, H, W) where C is the concatenation of all layer features.

        The channel axis is consumed in consecutive chunks of
        ``feature_dims`` widths, cycling through ``feature_dims`` once per
        timestep. ``pose`` is accepted but not used.

        Returns:
            The weighted sum of the bottlenecked chunks, optionally passed
            through ``last_layer`` (added residually when ``skip_connection``
            is true, used directly otherwise).
        """
        if self.feat_map_dropout > 0 and self.training:
            batch = F.dropout(batch, p=self.feat_map_dropout)

        mixing_weights = F.softmax(self.mixing_weights, dim=0)
        if self.pos_embedding is not None:  # position embedding
            batch = torch.cat((batch, self.pos_embedding), dim=1)

        num_layers = len(self.feature_dims)
        output_feature = None
        start = 0
        for i, weight in enumerate(mixing_weights):
            # Share bottleneck layers across timesteps
            layer_idx = i % num_layers
            # Chunk the batch according to the layer
            # Account for looping if there are multiple timesteps
            end = start + self.feature_dims[layer_idx]
            feats = batch[:, start:end, :, :]
            start = end
            # Downsample the number of channels and weight the layer
            weighted = weight * self.bottleneck_layers[layer_idx](feats)
            output_feature = weighted if output_feature is None else output_feature + weighted

        if self.last_layer is not None:
            output_feature_after = self.last_layer(output_feature)
            if self.skip_connection:
                # skip connection
                output_feature = output_feature + output_feature_after
            else:
                # Without the skip connection the head's output is the result;
                # previously it was computed and then thrown away.
                output_feature = output_feature_after
        return output_feature
def conv1x1(in_planes, out_planes, stride=1):
    """Build a bias-free 1x1 ``nn.Conv2d`` with no padding."""
    layer = nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=1,
        stride=stride,
        padding=0,
        bias=False,
    )
    return layer
def conv3x3(in_planes, out_planes, stride=1):
    """Build a bias-free 3x3 ``nn.Conv2d`` with a padding of one pixel."""
    layer = nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )
    return layer
class BasicBlock(nn.Module):
    """Residual block made of two 3x3 convolutions, each followed by batch norm.

    The shortcut branch is the unmodified input when ``stride == 1``; for any
    other stride it is a strided 1x1 convolution followed by batch norm. No
    channel projection is added for ``stride == 1``, so in that case the input
    must already be addable to a ``planes``-channel tensor.
    """

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        self.conv1 = conv3x3(in_planes, planes, stride)
        self.conv2 = conv3x3(planes, planes)
        self.bn1 = nn.BatchNorm2d(planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)

        shortcut = None
        if stride != 1:
            # Match the spatial size (and channel count) of the main branch.
            shortcut = nn.Sequential(
                conv1x1(in_planes, planes, stride=stride),
                nn.BatchNorm2d(planes),
            )
        self.downsample = shortcut

    def forward(self, x):
        """Return ``relu(shortcut(x) + main_branch(x))``."""
        residual = self.conv1(x)
        residual = self.relu(self.bn1(residual))
        residual = self.bn2(self.conv2(residual))

        identity = x if self.downsample is None else self.downsample(x)
        return self.relu(identity + residual)