# ------------------------------------------------------------------------------
# Copyright and License Information
# Adapted from
# https://github.com/microsoft/voxelpose-pytorch/blob/main/lib/models/v2v_net.py
# Original License: MIT License
# ------------------------------------------------------------------------------

import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule

from ..builder import BACKBONES
from .base_backbone import BaseBackbone


class Basic3DBlock(nn.Module):
    """A basic 3D convolutional block.

    Args:
        in_channels (int): Input channels of this block.
        out_channels (int): Output channels of this block.
        kernel_size (int): Kernel size of the convolution operation.
        conv_cfg (dict): Dictionary to construct and config conv layer.
            Default: dict(type='Conv3d')
        norm_cfg (dict): Dictionary to construct and config norm layer.
            Default: dict(type='BN3d')
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 conv_cfg=dict(type='Conv3d'),
                 norm_cfg=dict(type='BN3d')):
        super(Basic3DBlock, self).__init__()
        self.block = ConvModule(
            in_channels,
            out_channels,
            kernel_size,
            stride=1,
            padding=((kernel_size - 1) // 2),
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            bias=True)

    def forward(self, x):
        """Forward function."""
        return self.block(x)


class Res3DBlock(nn.Module):
    """A residual 3D convolutional block.

    Args:
        in_channels (int): Input channels of this block.
        out_channels (int): Output channels of this block.
        kernel_size (int): Kernel size of the convolution operation.
            Default: 3
        conv_cfg (dict): Dictionary to construct and config conv layer.
            Default: dict(type='Conv3d')
        norm_cfg (dict): Dictionary to construct and config norm layer.
            Default: dict(type='BN3d')
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=3,
                 conv_cfg=dict(type='Conv3d'),
                 norm_cfg=dict(type='BN3d')):
        super(Res3DBlock, self).__init__()
        self.res_branch = nn.Sequential(
            ConvModule(
                in_channels,
                out_channels,
                kernel_size,
                stride=1,
                padding=((kernel_size - 1) // 2),
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                bias=True),
            ConvModule(
                out_channels,
                out_channels,
                kernel_size,
                stride=1,
                padding=((kernel_size - 1) // 2),
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=None,
                bias=True))

        if in_channels == out_channels:
            self.skip_con = nn.Sequential()
        else:
            self.skip_con = ConvModule(
                in_channels,
                out_channels,
                1,
                stride=1,
                padding=0,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=None,
                bias=True)

    def forward(self, x):
        """Forward function."""
        res = self.res_branch(x)
        skip = self.skip_con(x)
        return F.relu(res + skip, True)


class Pool3DBlock(nn.Module):
    """A 3D max-pool block.

    Args:
        pool_size (int): Pool size of the 3D max-pool layer.
    """

    def __init__(self, pool_size):
        super(Pool3DBlock, self).__init__()
        self.pool_size = pool_size

    def forward(self, x):
        """Forward function."""
        return F.max_pool3d(
            x, kernel_size=self.pool_size, stride=self.pool_size)
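

# A minimal shape sketch (illustrative only, not part of the upstream module):
# with 'same' padding, Basic3DBlock and Res3DBlock preserve the spatial size of
# a (N, C, D, H, W) voxel volume, while Pool3DBlock(2) halves every spatial
# dimension. For example:
#
#   >>> import torch
#   >>> feat = torch.randn(1, 32, 64, 64, 64)
#   >>> Res3DBlock(32, 64)(feat).shape
#   torch.Size([1, 64, 64, 64, 64])
#   >>> Pool3DBlock(2)(feat).shape
#   torch.Size([1, 32, 32, 32, 32])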


class Upsample3DBlock(nn.Module):
    """A 3D upsample block.

    Args:
        in_channels (int): Input channels of this block.
        out_channels (int): Output channels of this block.
        kernel_size (int): Kernel size of the transposed convolution operation.
            Default: 2
        stride (int): Stride of the transposed convolution operation.
            Default: 2
    """

    def __init__(self, in_channels, out_channels, kernel_size=2, stride=2):
        super(Upsample3DBlock, self).__init__()
        assert kernel_size == 2
        assert stride == 2
        self.block = nn.Sequential(
            nn.ConvTranspose3d(
                in_channels,
                out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=0,
                output_padding=0),
            nn.BatchNorm3d(out_channels),
            nn.ReLU(True))

    def forward(self, x):
        """Forward function."""
        return self.block(x)


class EncoderDecorder(nn.Module):
    """An encoder-decoder block.

    Args:
        in_channels (int): Input channels of this block. Default: 32
    """

    def __init__(self, in_channels=32):
        super(EncoderDecorder, self).__init__()

        self.encoder_pool1 = Pool3DBlock(2)
        self.encoder_res1 = Res3DBlock(in_channels, in_channels * 2)
        self.encoder_pool2 = Pool3DBlock(2)
        self.encoder_res2 = Res3DBlock(in_channels * 2, in_channels * 4)

        self.mid_res = Res3DBlock(in_channels * 4, in_channels * 4)

        self.decoder_res2 = Res3DBlock(in_channels * 4, in_channels * 4)
        self.decoder_upsample2 = Upsample3DBlock(in_channels * 4,
                                                 in_channels * 2, 2, 2)
        self.decoder_res1 = Res3DBlock(in_channels * 2, in_channels * 2)
        self.decoder_upsample1 = Upsample3DBlock(in_channels * 2, in_channels,
                                                 2, 2)

        self.skip_res1 = Res3DBlock(in_channels, in_channels)
        self.skip_res2 = Res3DBlock(in_channels * 2, in_channels * 2)

    def forward(self, x):
        """Forward function."""
        skip_x1 = self.skip_res1(x)
        x = self.encoder_pool1(x)
        x = self.encoder_res1(x)
        skip_x2 = self.skip_res2(x)
        x = self.encoder_pool2(x)
        x = self.encoder_res2(x)

        x = self.mid_res(x)

        x = self.decoder_res2(x)
        x = self.decoder_upsample2(x)
        x = x + skip_x2
        x = self.decoder_res1(x)
        x = self.decoder_upsample1(x)
        x = x + skip_x1

        return x


@BACKBONES.register_module()
class V2VNet(BaseBackbone):
    """V2VNet.

    Please refer to the `paper` for details.

    Args:
        input_channels (int): Number of channels of the input feature volume.
        output_channels (int): Number of channels of the output volume.
        mid_channels (int): Input and output channels of the encoder-decoder
            block. Default: 32
    """

    def __init__(self, input_channels, output_channels, mid_channels=32):
        super(V2VNet, self).__init__()

        self.front_layers = nn.Sequential(
            Basic3DBlock(input_channels, mid_channels // 2, 7),
            Res3DBlock(mid_channels // 2, mid_channels),
        )

        self.encoder_decoder = EncoderDecorder(in_channels=mid_channels)

        self.output_layer = nn.Conv3d(
            mid_channels, output_channels, kernel_size=1, stride=1, padding=0)

        self._initialize_weights()

    def forward(self, x):
        """Forward function."""
        x = self.front_layers(x)
        x = self.encoder_decoder(x)
        x = self.output_layer(x)

        return x

    def _initialize_weights(self):
        """Initialize conv and deconv weights with a small Gaussian and zero
        biases."""
        for m in self.modules():
            if isinstance(m, nn.Conv3d):
                nn.init.normal_(m.weight, 0, 0.001)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.ConvTranspose3d):
                nn.init.normal_(m.weight, 0, 0.001)
                nn.init.constant_(m.bias, 0)
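

if __name__ == '__main__':
    # Minimal smoke-test sketch, not part of the upstream module. The channel
    # count (15) and cube resolution (64) below are illustrative assumptions;
    # real configs derive them from the dataset and the voxel grid. Because of
    # the relative imports above, run this with `python -m <package path to
    # this module>` rather than as a standalone script.
    import torch

    net = V2VNet(input_channels=15, output_channels=15)
    voxel_feat = torch.randn(1, 15, 64, 64, 64)
    with torch.no_grad():
        heatmaps = net(voxel_feat)
    # Two pooling stages and two upsampling stages cancel out, so the output
    # keeps the input spatial resolution: torch.Size([1, 15, 64, 64, 64]).
    print(heatmaps.shape)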