| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| """Decoder for Panoptic Recon 3D.""" |
|
|
| from typing import Optional, List |
| import torch |
| from torch import nn |
| import MinkowskiEngine as Me |
| from ..utils.sparse_tensor import sparse_cat_union |
| from ..blocks import BasicBlock3D, SparseBasicBlock3D |
|
|
|
|
class SparseToDense(nn.Module):
    """Convert a MinkowskiEngine sparse tensor into a dense ``torch.Tensor``.

    Voxels whose coordinates fall outside ``[0, input_size)`` on any axis are
    pruned first, so ``Me.SparseTensor.dense`` never receives out-of-bounds
    coordinates.
    """

    def __init__(self, input_size):
        """Initialize the sparse to dense module.

        Args:
            input_size: Sequence of three ints ``(X, Y, Z)`` — the dense grid
                extent at tensor stride 1.
        """
        super().__init__()
        assert len(input_size) == 3
        self.input_size = input_size

    def forward(self, feature: Me.SparseTensor) -> torch.Tensor:
        """Densify ``feature`` into shape ``(B, C, X//s, Y//s, Z//s)``.

        Args:
            feature: Sparse tensor; its ``tensor_stride`` ``s`` determines the
                spatial resolution of the dense output.

        Returns:
            Dense tensor of shape ``(batch, channels, *input_size // stride)``.
        """
        batch_size = len(feature.decomposed_coordinates_and_features[0])
        # BUG FIX: the channel count must come from the feature matrix F, not
        # the coordinate matrix C (C is always N x 4: batch index + xyz, so
        # the old code forced 4 output channels regardless of feature width).
        feat_dim = feature.F.shape[-1]

        # Spatial extent of the dense grid shrinks with the tensor stride.
        out_size = (
            torch.div(
                torch.tensor(self.input_size),
                torch.tensor(feature.tensor_stride),
                rounding_mode="floor"
            )
        ).tolist()
        shape = torch.Size([batch_size, feat_dim, *out_size])
        min_coordinate = torch.IntTensor([0, 0, 0])

        # Keep only voxels inside [0, input_size) on every axis
        # (C column 0 is the batch index, columns 1:4 are xyz).
        mask = (feature.C[:, 1] < self.input_size[0]) & \
            (feature.C[:, 2] < self.input_size[1]) & \
            (feature.C[:, 3] < self.input_size[2])
        mask = mask & (feature.C[:, 1] >= 0) & (feature.C[:, 2] >= 0) & (feature.C[:, 3] >= 0)

        feature = Me.MinkowskiPruning()(feature, mask)
        dense = feature.dense(shape, min_coordinate=min_coordinate)[0]

        return dense
|
|
|
|
class FrustumDecoder(nn.Module):
    """Frustum decoder module.

    Coarse-to-fine 3D U-Net-style decoder over a view-frustum voxel grid.

    The pyramid has five levels.  The two finest encoder/decoder levels run on
    MinkowskiEngine sparse tensors; the three deepest levels run on dense
    ``torch.Tensor`` volumes (the ``is_sparse=False`` arguments below).
    ``enc_level_conversion`` densifies the level-1 encoder output before it
    feeds level 2.

    During decoding, the three finest levels each predict a per-voxel
    occupancy logit (``level_occupancy_heads``) — used both as an output and
    as a >0.5-sigmoid pruning mask for the next finer level — and per-query
    segmentation logits (dot products between projected queries and voxel
    embeddings).  The finest level additionally regresses geometry values
    clamped to ``truncation`` through ``geometry_head``.
    """

    def __init__(self, cfg) -> None:
        """Initialize the frustum decoder module.

        Args:
            cfg: Project config node; reads ``cfg.model.frustum3d.*``,
                ``cfg.model.projection.sign_channel``,
                ``cfg.model.sem_seg_head.*`` and ``cfg.dataset.name``.
        """
        super().__init__()

        num_output_features = cfg.model.frustum3d.unet_output_channels
        num_features = cfg.model.frustum3d.unet_features
        sign_channel = cfg.model.projection.sign_channel
        mask_dim = cfg.model.sem_seg_head.mask_dim
        depth_dim = cfg.model.sem_seg_head.depth_dim
        num_classes = cfg.model.sem_seg_head.num_classes
        frustum_dims = cfg.model.frustum3d.grid_dimensions
        # Cubic frustum grid: the same extent along all three axes.
        frustum_dims = [frustum_dims] * 3

        self.use_ms_features = cfg.model.frustum3d.use_multi_scale
        self.truncation = cfg.model.frustum3d.truncation

        # NOTE(review): for matterport the multi-scale features apparently
        # come without the extra class/sign channels appended — confirm
        # against the backbone that produces ``ms_features``.
        if cfg.dataset.name == 'matterport':
            ms_feature_channels = cfg.model.sem_seg_head.convs_dim
        else:
            ms_feature_channels = cfg.model.sem_seg_head.convs_dim + \
                cfg.model.sem_seg_head.num_classes + cfg.model.frustum3d.signed_channel

        # The input sparse tensor carries three concatenated channel groups
        # (sign, mask+depth, class scores); forward() slices them apart,
        # encodes each group separately, then re-concatenates the results.
        self.input_dims = [2 if sign_channel else 1, mask_dim + depth_dim, num_classes]
        self.input_encoders = nn.ModuleList()
        for input_dim in self.input_dims:
            # 1x1 projection used as the residual shortcut when the block's
            # in/out widths differ.
            downsample = nn.Sequential(
                Me.MinkowskiConvolution(
                    input_dim, num_features,
                    kernel_size=1, stride=1,
                    bias=True, dimension=3
                ),
                Me.MinkowskiInstanceNorm(num_features),
            )
            self.input_encoders.append(
                SparseBasicBlock3D(
                    input_dim, num_features,
                    downsample=downsample
                )
            )

        # Encoder pyramid: two sparse levels followed by three dense levels.
        # Level 0 consumes the concatenation of all input-group encodings.
        self.level_encoders = nn.ModuleList([
            self.make_encoder(len(self.input_encoders) * num_features, num_features),
            self.make_encoder(num_features, num_features * 2),
            self.make_encoder(num_features * 2, num_features * 4, is_sparse=False),
            self.make_encoder(num_features * 4, num_features * 8, is_sparse=False),
            self.make_encoder(num_features * 8, num_features * 8, is_sparse=False),
        ])

        # Shared sparse->dense converter: used both as the level-1 encoder
        # conversion and inside the level-2 multi-scale feature adapter.
        sparse_to_dense = SparseToDense(frustum_dims)

        if self.use_ms_features:
            # Adapters projecting backbone multi-scale features to each
            # encoder level's channel width (level 2 is also densified).
            self.feature_adapters = nn.ModuleList([
                self.make_adapter(ms_feature_channels, num_features),
                self.make_adapter(ms_feature_channels, num_features * 2),
                self.make_adapter(ms_feature_channels, num_features * 4, [sparse_to_dense]),
            ])
        else:
            self.feature_adapters = None

        # Conversion applied to each encoder output before it feeds the next
        # level; only the level-1 output is densified (sparse -> dense
        # hand-off between level 1 and level 2).
        self.enc_level_conversion = nn.ModuleList([
            nn.Identity(),
            sparse_to_dense,
            nn.Identity(),
            nn.Identity(),
        ])

        # Decoders, finest (index 0) to coarsest (index 4).  Input widths
        # account for the skip-connection concatenations done in forward().
        self.level_decoders = nn.ModuleList([
            self.make_decoder(num_features * 3, num_output_features),
            self.make_decoder(
                num_features * 6, num_features * 2,
                extra_layers=[SparseBasicBlock3D(num_features * 2, num_features * 2)]
            ),
            self.make_decoder(num_features * 8, num_features * 2, is_sparse=False),
            self.make_decoder(num_features * 16, num_features * 4, is_sparse=False),
            self.make_decoder(num_features * 8, num_features * 8, is_sparse=False),
        ])

        # Occupancy heads for decoder levels 0 (sparse conv head), 1 (sparse
        # linear) and 2 (dense linear, applied channels-last in forward()).
        self.level_occupancy_heads = nn.ModuleList([
            nn.Sequential(
                Me.MinkowskiInstanceNorm(num_output_features),
                Me.MinkowskiReLU(inplace=True),
                SparseBasicBlock3D(num_output_features, num_output_features),
                Me.MinkowskiConvolution(num_output_features, 1, kernel_size=3, bias=True, dimension=3),
            ),
            Me.MinkowskiLinear(num_features * 2, 1),
            nn.Linear(num_features * 4, 1),
        ])

        # Per-level voxel embedding heads and matching query projections for
        # segmentation (levels 0..2; levels 0-1 sparse, level 2 dense).
        self.level_segm_embeddings = nn.ModuleList([
            nn.Sequential(
                Me.MinkowskiInstanceNorm(num_output_features),
                Me.MinkowskiReLU(inplace=True),
                SparseBasicBlock3D(num_output_features, num_output_features),
            ),
            SparseBasicBlock3D(num_features * 3, num_features * 3),
            nn.Sequential(
                BasicBlock3D(num_features * 4, num_features * 4),
                BasicBlock3D(num_features * 4, num_features * 4),
            )
        ])
        self.level_segm_query_projection = nn.ModuleList([
            nn.Linear(mask_dim, num_output_features),
            nn.Linear(mask_dim, num_features * 3),
            nn.Linear(mask_dim, num_features * 4),
        ])

        # Final geometry regression head (finest level only); its output is
        # clamped to [0, truncation] in forward().
        self.geometry_head = nn.Sequential(
            Me.MinkowskiInstanceNorm(num_output_features),
            Me.MinkowskiReLU(inplace=True),
            SparseBasicBlock3D(num_output_features, num_output_features),
            Me.MinkowskiConvolution(num_output_features, 1, kernel_size=3, bias=True, dimension=3),
        )

        # Non-persistent buffer: derived from config each run, used only for
        # coordinate bounds checks in forward().
        self.register_buffer("frustum_dimensions", torch.tensor(frustum_dims), persistent=False)

    @staticmethod
    def forward_sparse_segm(segm_features, queries):
        """Per-voxel segmentation logits via query/feature dot products.

        Args:
            segm_features: Sparse voxel embeddings.
            queries: Per-sample query matrices; ``queries[b]`` is
                ``(num_queries, channels)`` so ``features @ queries.T`` gives
                one logit per (voxel, query) pair.

        Returns:
            Sparse tensor of logits on the same coordinate map as
            ``segm_features``.
        """
        features = segm_features.decomposed_features
        segms = torch.cat(
            [torch.mm(features[idx], queries[idx].T) for idx in range(len(features))], dim=0
        )
        return Me.SparseTensor(
            segms,
            coordinate_manager=segm_features.coordinate_manager,
            coordinate_map_key=segm_features.coordinate_map_key,
        )

    @staticmethod
    def make_encoder(input_dim, output_dim, is_sparse=True):
        """Build one stride-2 encoder level (two residual blocks).

        Args:
            input_dim: Input channel width.
            output_dim: Output channel width.
            is_sparse: MinkowskiEngine blocks when True, dense Conv3d blocks
                otherwise.

        Returns:
            ``nn.Sequential`` that halves spatial resolution.
        """
        if is_sparse:
            # Strided conv shortcut so the residual add matches the stride-2
            # main path.
            downsample = nn.Sequential(
                Me.MinkowskiConvolution(
                    input_dim, output_dim, kernel_size=4, stride=2, bias=True, dimension=3
                ),
                Me.MinkowskiInstanceNorm(output_dim),
            )
            module = nn.Sequential(
                SparseBasicBlock3D(input_dim, output_dim, stride=2, downsample=downsample),
                SparseBasicBlock3D(output_dim, output_dim),
            )
        else:
            downsample = nn.Conv3d(
                input_dim, output_dim,
                kernel_size=4, stride=2,
                padding=1, bias=False
            )
            module = nn.Sequential(
                BasicBlock3D(input_dim, output_dim, stride=2, downsample=downsample),
                BasicBlock3D(output_dim, output_dim),
            )
        return module

    @staticmethod
    def make_decoder(input_dim, output_dim, is_sparse=True, extra_layers: Optional[List] = None):
        """Build one stride-2 upsampling decoder level.

        Args:
            input_dim: Input channel width.
            output_dim: Output channel width.
            is_sparse: Minkowski transposed conv (with coordinate expansion)
                when True, dense ConvTranspose3d otherwise.
            extra_layers: Optional extra modules appended after the
                norm/activation.

        Returns:
            ``nn.Sequential`` that doubles spatial resolution.
        """
        if extra_layers is None:
            extra_layers = []
        if is_sparse:
            return nn.Sequential(
                # expand_coordinates lets the decoder generate new voxels
                # beyond the pruned input set; out-of-frustum ones are pruned
                # again in forward().
                Me.MinkowskiConvolutionTranspose(
                    input_dim, output_dim, kernel_size=4,
                    stride=2, bias=False, dimension=3, expand_coordinates=True
                ),
                Me.MinkowskiInstanceNorm(output_dim),
                Me.MinkowskiReLU(inplace=True),
                *extra_layers,
            )
        else:
            return nn.Sequential(
                nn.ConvTranspose3d(input_dim, output_dim, kernel_size=4, stride=2, padding=1, bias=False),
                nn.InstanceNorm3d(output_dim),
                nn.ReLU(inplace=True),
                *extra_layers,
            )

    @staticmethod
    def make_adapter(input_dim, output_dim, extra_layers: Optional[List] = None):
        """Build a 1x1 sparse residual adapter (multi-scale feature fusion).

        Args:
            input_dim: Backbone feature width.
            output_dim: Target encoder-level width.
            extra_layers: Optional trailing modules (e.g. a SparseToDense for
                the dense levels).

        Returns:
            ``nn.Sequential`` adapter.
        """
        if extra_layers is None:
            extra_layers = []
        downsample = nn.Sequential(
            Me.MinkowskiConvolution(input_dim, output_dim, kernel_size=1, stride=1, bias=True, dimension=3),
            Me.MinkowskiInstanceNorm(output_dim),
        )
        return nn.Sequential(
            SparseBasicBlock3D(input_dim, output_dim, downsample=downsample),
            *extra_layers,
        )

    def forward(
        self, ms_features: List[Me.SparseTensor],
        features: Me.SparseTensor, segm_queries, frustum_mask
    ):
        """Forward pass.

        Args:
            ms_features: Multi-scale backbone features, one sparse tensor per
                adapter level (only used when ``use_ms_features``).
            features: Input sparse tensor whose channels are the concatenation
                of the ``self.input_dims`` groups.
            segm_queries: Query embeddings, presumably shaped
                ``(batch, num_queries, mask_dim)`` — inferred from the
                ``"bqc,bchwd->bqhwd"`` einsum below; confirm against caller.
            frustum_mask: Dense boolean mask of valid frustum voxels,
                presumably ``(batch, 1, X, Y, Z)`` (it is ``squeeze(1)``-ed
                before use) — TODO confirm.

        Returns:
            Dict with ``pred_geometry`` (sparse, truncated), and the
            coarse-to-fine lists ``pred_occupancies`` and ``pred_segms``.
        """
        # --- Split the input channels into groups and encode each group. ---
        start_dim = 0
        encoded_inputs = []
        cm = features.coordinate_manager
        key = features.coordinate_map_key
        for dim, encoder in zip(self.input_dims, self.input_encoders):
            encoded_inputs.append(
                encoder(Me.SparseTensor(
                    features.F[:, start_dim:start_dim + dim], coordinate_manager=cm, coordinate_map_key=key
                ))
            )
            start_dim += dim
        encoded_inputs = Me.cat(*encoded_inputs)

        lvls = len(self.level_encoders)

        # --- Encoder pyramid with optional multi-scale feature fusion. ---
        encoder_outputs = []
        encoder_inputs = [encoded_inputs]

        for idx in range(len(self.level_encoders)):
            encoded = self.level_encoders[idx](encoder_inputs[idx])
            if self.use_ms_features and idx < len(self.feature_adapters):
                feat = self.feature_adapters[idx](ms_features[idx])

                if isinstance(encoded, torch.Tensor):
                    # Dense level: plain elementwise add.
                    encoded = encoded + feat
                else:
                    # Sparse level: rebuild the adapted features on the
                    # encoder output's coordinate manager so the sparse
                    # addition is well-defined.
                    feat = Me.SparseTensor(
                        feat.F, coordinates=feat.C,
                        tensor_stride=feat.tensor_stride,
                        coordinate_manager=encoded.coordinate_manager
                    )
                    encoded = encoded + feat

            encoder_outputs.append(encoded)

            if idx < lvls - 1:
                # enc_level_conversion[1] densifies the level-1 output; the
                # other conversions are identity.
                encoder_inputs.append(self.enc_level_conversion[idx](encoded))

        # --- Coarse-to-fine decoding with occupancy-driven pruning. ---
        decoder_outputs = []
        decoder_inputs = [encoder_outputs[-1]]
        pred_occupancies = []
        pred_segms = []
        pred_geometry = None

        for idx in reversed(range(lvls)):
            decoded = self.level_decoders[idx](decoder_inputs[lvls - 1 - idx])
            decoder_outputs.append(decoded)

            if idx <= 1:
                # Sparse levels: predict occupancy, prune out-of-bounds and
                # unoccupied voxels, then either continue upward (idx == 1)
                # or emit the final geometry (idx == 0).
                occupancy = self.level_occupancy_heads[idx](decoded)
                # Transposed convs expand coordinates beyond the frustum;
                # keep only voxels inside [0, frustum_dimensions).
                valid_mask = (
                    (occupancy.C[:, 1:] >= 0) & (occupancy.C[:, 1:] < self.frustum_dimensions)
                ).all(-1)
                pred_occupancies.append(Me.MinkowskiPruning()(occupancy, valid_mask))
                pruning_mask = (Me.MinkowskiSigmoid()(occupancy).F.squeeze(-1) > 0.5) & valid_mask
                sparse_out = Me.MinkowskiPruning()(decoded, pruning_mask)

                if idx > 0:
                    # Skip connection: union with the matching encoder output,
                    # then re-prune to the frustum bounds.
                    sparse_out = sparse_cat_union(encoder_outputs[idx - 1], sparse_out)
                    valid_mask = (
                        (sparse_out.C[:, 1:] >= 0) & (sparse_out.C[:, 1:] < self.frustum_dimensions)
                    ).all(-1)
                    decoder_inputs.append(Me.MinkowskiPruning()(sparse_out, valid_mask))
                else:
                    # Finest level: regress geometry, clamped to the TSDF-style
                    # truncation value, and prune out-of-frustum voxels.
                    pred_geometry = self.geometry_head(sparse_out)
                    predicted_values = pred_geometry.F
                    predicted_values = torch.clamp(predicted_values, 0.0, self.truncation)
                    pred_geometry = Me.SparseTensor(
                        predicted_values,
                        coordinate_manager=pred_geometry.coordinate_manager,
                        coordinate_map_key=pred_geometry.coordinate_map_key,
                    )
                    valid_mask = (
                        (pred_geometry.C[:, 1:] >= 0) & (pred_geometry.C[:, 1:] < self.frustum_dimensions)
                    ).all(-1)
                    pred_geometry = Me.MinkowskiPruning()(pred_geometry, valid_mask)

                # Per-query segmentation logits at this level.
                queries = self.level_segm_query_projection[idx](segm_queries)
                segm_features = self.level_segm_embeddings[idx](sparse_out)
                pred_segm = self.forward_sparse_segm(segm_features, queries)
                valid_mask = (
                    (pred_segm.C[:, 1:] >= 0) & (pred_segm.C[:, 1:] < self.frustum_dimensions)
                ).all(-1)
                pred_segms.append(Me.MinkowskiPruning()(pred_segm, valid_mask))

            elif idx == 2:
                # Dense->sparse transition level: concatenate the dense skip
                # connection, predict occupancy/segmentation densely, then
                # convert the occupied in-frustum voxels to a sparse tensor.
                decoded = torch.cat([encoder_inputs[idx], decoded], dim=1)
                # Dense occupancy head is a Linear, so move channels last.
                occupancy = self.level_occupancy_heads[idx](decoded.permute(0, 2, 3, 4, 1)).squeeze(-1)
                # Voxels outside the frustum are forced to -inf logits.
                pred_occupancies.append(occupancy.masked_fill(~frustum_mask.squeeze(1), -torch.inf))

                queries = self.level_segm_query_projection[idx](segm_queries)
                segm_features = self.level_segm_embeddings[idx](decoded)
                pred_segm = torch.einsum("bqc,bchwd->bqhwd", queries, segm_features)
                pred_segms.append(pred_segm.masked_fill(~frustum_mask, -torch.inf))

                # Gather features of occupied, in-frustum voxels.
                pruning_mask = (occupancy.sigmoid() > 0.5) & frustum_mask.squeeze(1)
                coords = pruning_mask.nonzero()
                sparse_out = decoded[coords[:, 0], :, coords[:, 1], coords[:, 2], coords[:, 3]]
                encoded = encoder_outputs[idx - 1]
                stride = encoded.tensor_stride
                coords = coords.clone()
                # Scale dense voxel indices back to stride-1 coordinates so
                # they line up with the sparse encoder outputs.
                coords[:, 1:] *= torch.tensor(stride, device=coords.device)
                sparse_out = Me.SparseTensor(
                    sparse_out, coordinates=coords.int().contiguous(),
                    tensor_stride=stride, coordinate_manager=cm
                )
                decoder_inputs.append(sparse_cat_union(encoded, sparse_out))
            else:
                # Deep dense levels: plain dense skip-connection concat.
                decoder_inputs.append(torch.cat([encoder_inputs[idx], decoded], dim=1))

        return {
            "pred_geometry": pred_geometry,
            "pred_occupancies": pred_occupancies,
            "pred_segms": pred_segms,
        }
|
|