# snap2scene/models/encoder.py
# -*- coding: utf-8 -*-
#
# Developed by Haozhe Xie <cshzxie@gmail.com>
#
# References:
# - https://github.com/shawnxu1318/MVCNN-Multi-View-Convolutional-Neural-Networks/blob/master/mvcnn.py
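#
# This module implements the multi-view image encoder: each rendered view is
# encoded by a frozen VGG16-BN backbone followed by three small convolutional
# blocks into a 256 x 8 x 8 feature map.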
import torch
import torchvision.models
class Encoder(torch.nn.Module):
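    """Multi-view image encoder.

    Each 224x224 rendered view is passed through a frozen VGG16-BN backbone and
    three additional convolutional blocks, producing a 256 x 8 x 8 feature map
    per view; the per-view maps are returned stacked along a views dimension.
    """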
    def __init__(self, cfg):
        super(Encoder, self).__init__()
        self.cfg = cfg

        # Layer Definition
        # Frozen VGG16-BN backbone: keep the first 27 feature layers (through
        # conv4_1 + BN + ReLU), so a 224x224 view comes out as 512 x 28 x 28.
        vgg16_bn = torchvision.models.vgg16_bn(pretrained=True)
        self.vgg = torch.nn.Sequential(*list(vgg16_bn.features.children()))[:27]
        self.layer1 = torch.nn.Sequential(
            torch.nn.Conv2d(512, 512, kernel_size=3),
            torch.nn.BatchNorm2d(512),
            torch.nn.ELU(),
        )
        self.layer2 = torch.nn.Sequential(
            torch.nn.Conv2d(512, 512, kernel_size=3),
            torch.nn.BatchNorm2d(512),
            torch.nn.ELU(),
            torch.nn.MaxPool2d(kernel_size=3)
        )
        self.layer3 = torch.nn.Sequential(
            torch.nn.Conv2d(512, 256, kernel_size=1),
            torch.nn.BatchNorm2d(256),
            torch.nn.ELU()
        )

        # Don't update params in VGG16
        for param in vgg16_bn.parameters():
            param.requires_grad = False

    def forward(self, rendering_images):
        # print(rendering_images.size())  # torch.Size([batch_size, n_views, img_c, img_h, img_w])
        # Process one view at a time: move the view axis first, then split along it.
        rendering_images = rendering_images.permute(1, 0, 2, 3, 4).contiguous()
        rendering_images = torch.split(rendering_images, 1, dim=0)
        image_features = []

        for img in rendering_images:
            features = self.vgg(img.squeeze(dim=0))
            # print(features.size())  # torch.Size([batch_size, 512, 28, 28])
            features = self.layer1(features)
            # print(features.size())  # torch.Size([batch_size, 512, 26, 26])
            features = self.layer2(features)
            # print(features.size())  # torch.Size([batch_size, 512, 8, 8])
            features = self.layer3(features)
            # print(features.size())  # torch.Size([batch_size, 256, 8, 8])
            image_features.append(features)

        # Re-assemble the per-view features into [batch_size, n_views, 256, 8, 8]
        image_features = torch.stack(image_features).permute(1, 0, 2, 3, 4).contiguous()
        # print(image_features.size())  # torch.Size([batch_size, n_views, 256, 8, 8])
        return image_features
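

# Quick smoke test: build the encoder with a minimal placeholder config, feed it
# random multi-view input, and print the output feature shape. It runs only when
# this file is executed directly.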
if __name__ == "__main__":
    # Minimal stand-in config: the encoder only stores cfg, so any object works here.
    class DummyCfg:
        class NETWORK:
            TCONV_USE_BIAS = False

    cfg = DummyCfg()

    # Instantiate the encoder
    encoder = Encoder(cfg)

    # Simulate input: shape [batch_size, n_views, img_c, img_h, img_w]
    batch_size = 64
    n_views = 5
    img_c, img_h, img_w = 3, 224, 224
    dummy_input = torch.randn(batch_size, n_views, img_c, img_h, img_w)
    print(dummy_input.shape)

    # Run the encoder (no gradients needed for a shape check)
    with torch.no_grad():
        image_features = encoder(dummy_input)
    print("image_features shape:", image_features.shape)  # Expected: [64, 5, 256, 8, 8]