# NOTE(review): removed stray non-Python residue ("Spaces:" / "Sleeping" lines)
# that appears to be clipboard/extraction garbage; it made the file unparseable.
import torch
class Decoder(torch.nn.Module):
    """Decode per-view image features into 32x32x32 occupancy volumes.

    Four stride-2 transposed-convolution stages upsample a 2048-channel
    2x2x2 feature cube to 8 channels at 32x32x32, and a final 1x1x1
    projection with a sigmoid produces a single occupancy channel.

    Given input of shape [batch, n_views, C, H, W] with C*H*W == 16384
    (e.g. [batch, n_views, 256, 8, 8]), ``forward`` returns:
        raw_features: [batch, n_views, 9, 32, 32, 32]
        gen_volumes:  [batch, n_views, 32, 32, 32]
    """

    def __init__(self, cfg):
        """``cfg`` must expose ``cfg.NETWORK.TCONV_USE_BIAS`` (bool)."""
        super(Decoder, self).__init__()
        self.cfg = cfg
        # Layer definition. Attribute names layer1..layer5 are kept so
        # state_dict keys stay compatible with existing checkpoints.
        self.layer1 = self._upsample_block(2048, 512)
        self.layer2 = self._upsample_block(512, 128)
        self.layer3 = self._upsample_block(128, 32)
        self.layer4 = self._upsample_block(32, 8)
        # Final 1x1x1 projection to one sigmoid-activated occupancy channel.
        self.layer5 = torch.nn.Sequential(
            torch.nn.ConvTranspose3d(8, 1, kernel_size=1,
                                     bias=cfg.NETWORK.TCONV_USE_BIAS),
            torch.nn.Sigmoid()
        )

    def _upsample_block(self, in_channels, out_channels):
        """One ConvTranspose3d(kernel=4, stride=2, padding=1) stage that
        doubles each spatial dimension, followed by BatchNorm3d + ReLU."""
        return torch.nn.Sequential(
            torch.nn.ConvTranspose3d(in_channels, out_channels, kernel_size=4,
                                     stride=2, padding=1,
                                     bias=self.cfg.NETWORK.TCONV_USE_BIAS),
            torch.nn.BatchNorm3d(out_channels),
            torch.nn.ReLU()
        )

    def forward(self, image_features):
        """Decode each view independently and re-stack batch-first.

        Args:
            image_features: [batch, n_views, C, H, W] tensor where
                C*H*W == 2048*2*2*2 (e.g. [batch, n_views, 256, 8, 8]).

        Returns:
            (raw_features, gen_volumes):
                raw_features: [batch, n_views, 9, 32, 32, 32] — layer-4
                    features concatenated with the generated volume.
                gen_volumes:  [batch, n_views, 32, 32, 32]
        """
        # Bring views to the leading dim so we can iterate per view.
        image_features = image_features.permute(1, 0, 2, 3, 4).contiguous()
        image_features = torch.split(image_features, 1, dim=0)
        gen_volumes = []
        raw_features = []
        for features in image_features:
            gen_volume = features.view(-1, 2048, 2, 2, 2)
            gen_volume = self.layer1(gen_volume)  # -> [batch, 512, 4, 4, 4]
            gen_volume = self.layer2(gen_volume)  # -> [batch, 128, 8, 8, 8]
            gen_volume = self.layer3(gen_volume)  # -> [batch, 32, 16, 16, 16]
            gen_volume = self.layer4(gen_volume)  # -> [batch, 8, 32, 32, 32]
            raw_feature = gen_volume
            gen_volume = self.layer5(gen_volume)  # -> [batch, 1, 32, 32, 32]
            # Keep the layer-4 features alongside the occupancy channel:
            # 8 + 1 = 9 channels.
            raw_feature = torch.cat((raw_feature, gen_volume), dim=1)
            gen_volumes.append(torch.squeeze(gen_volume, dim=1))
            raw_features.append(raw_feature)
        # Restore batch-first layout: [batch, n_views, ...].
        gen_volumes = torch.stack(gen_volumes).permute(1, 0, 2, 3, 4).contiguous()
        raw_features = torch.stack(raw_features).permute(1, 0, 2, 3, 4, 5).contiguous()
        return raw_features, gen_volumes
class DummyCfg:
    """Minimal stand-in for the project configuration object.

    Provides only the attribute path the Decoder actually reads:
    ``cfg.NETWORK.TCONV_USE_BIAS``.
    """

    class NETWORK:
        # Run the transposed convolutions without bias terms.
        TCONV_USE_BIAS = False
def _demo():
    """Smoke-test the Decoder with random input and print tensor shapes."""
    cfg = DummyCfg()
    # Instantiate the decoder
    decoder = Decoder(cfg)
    # Simulate input: shape [batch_size, n_views, img_c, img_h, img_w]
    n_views = 1
    batch_size = 64
    img_c, img_h, img_w = 256, 8, 8
    dummy_input = torch.randn(batch_size, n_views, img_c, img_h, img_w)
    # Run the decoder
    print(dummy_input.shape)
    raw_features, gen_volumes = decoder(dummy_input)
    # Output shapes. With n_views = 1 the view dimension is 1 (the original
    # comment claimed 5 views, which did not match the input above).
    print("raw_features shape:", raw_features.shape)  # Expected: [64, 1, 9, 32, 32, 32]
    print("gen_volumes shape:", gen_volumes.shape)    # Expected: [64, 1, 32, 32, 32]


if __name__ == "__main__":
    # Guarded so importing this module does not instantiate the (large)
    # decoder as a side effect.
    _demo()