mm3dtest / projects /CENet /cenet /range_image_segmentor.py
giantmonkeyTC
2344
34d1f8b
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List
from torch import Tensor
from mmdet3d.models import EncoderDecoder3D
from mmdet3d.registry import MODELS
from mmdet3d.structures import PointData
from mmdet3d.structures.det3d_data_sample import OptSampleList, SampleList
@MODELS.register_module()
class RangeImageSegmentor(EncoderDecoder3D):
    """Encoder-decoder 3D segmentor that operates on image-like inputs.

    Every entry point (``loss``, ``predict``, ``_forward``) reads only
    ``batch_inputs_dict['imgs']`` and passes it to :meth:`extract_feat`; the
    ``'points'`` entry is never read by this class. Judging by the class
    name, ``imgs`` is presumably the range-view projection of each point
    cloud (TODO confirm against the data pipeline).
    """

    def loss(self, batch_inputs_dict: dict,
             batch_data_samples: SampleList) -> Dict[str, Tensor]:
        """Calculate losses from a batch of inputs and data samples.

        Args:
            batch_inputs_dict (dict): Input sample dict. Only the 'imgs' key
                is read here and it is required (a missing key raises
                ``KeyError``).

                - points (List[Tensor]): Point cloud of each sample
                  (not used by this method).
                - imgs (Tensor): Image tensor with shape (B, C, H, W).

            batch_data_samples (List[:obj:`Det3DDataSample`]): The det3d data
                samples. It usually includes information such as `metainfo`
                and `gt_pts_seg`.

        Returns:
            Dict[str, Tensor]: A dictionary of loss components from the
            decode head and, when configured, the auxiliary head.
        """
        # Backbone (+ neck, if any) features of the batched image tensor.
        x = self.extract_feat(batch_inputs_dict['imgs'])

        losses = dict()
        losses.update(self._decode_head_forward_train(x, batch_data_samples))
        if self.with_auxiliary_head:
            losses.update(
                self._auxiliary_head_forward_train(x, batch_data_samples))
        return losses

    def predict(self,
                batch_inputs_dict: dict,
                batch_data_samples: SampleList,
                rescale: bool = True) -> SampleList:
        """Predict semantic labels for a batch of samples.

        Args:
            batch_inputs_dict (dict): Input sample dict. Only the 'imgs' key
                is read here and it is required.

                - points (List[Tensor]): Point cloud of each sample
                  (not used by this method).
                - imgs (Tensor): Image tensor with shape (B, C, H, W).

            batch_data_samples (List[:obj:`Det3DDataSample`]): The det3d data
                samples. Their `metainfo` is forwarded to the decode head.
            rescale (bool): Accepted for API compatibility with
                ``EncoderDecoder3D.predict`` but not used by this
                implementation. Defaults to True.

        Returns:
            List[:obj:`Det3DDataSample`]: The input data samples, each
            updated in place with ``pred_pts_seg`` (PointData) whose
            ``pts_semantic_mask`` holds the decode head's prediction for
            that sample.
        """
        # The whole batch is handled in one pass: 'imgs' is a single batched
        # tensor and the decode head receives the metainfo of every sample.
        batch_input_metas = [
            data_sample.metainfo for data_sample in batch_data_samples
        ]
        x = self.extract_feat(batch_inputs_dict['imgs'])
        seg_labels_list = self.decode_head.predict(x, batch_input_metas,
                                                   self.test_cfg)
        return self.postprocess_result(seg_labels_list, batch_data_samples)

    def _forward(self,
                 batch_inputs_dict: dict,
                 batch_data_samples: OptSampleList = None) -> Tensor:
        """Network forward process without any post-processing.

        Args:
            batch_inputs_dict (dict): Input sample dict. Only the 'imgs' key
                is read here and it is required.

                - points (List[Tensor]): Point cloud of each sample
                  (not used by this method).
                - imgs (Tensor): Image tensor with shape (B, C, H, W).

            batch_data_samples (List[:obj:`Det3DDataSample`], optional):
                Accepted for interface compatibility; not used here.

        Returns:
            Tensor: Raw output of ``decode_head.forward`` on the extracted
            features.
        """
        x = self.extract_feat(batch_inputs_dict['imgs'])
        return self.decode_head.forward(x)

    def postprocess_result(self, seg_labels_list: List[Tensor],
                           batch_data_samples: SampleList) -> SampleList:
        """Attach per-sample predictions to the data samples.

        Args:
            seg_labels_list (List[Tensor]): Per-sample predictions returned
                by ``decode_head.predict`` (presumably per-point class
                labels; TODO confirm). Entry ``i`` is stored on
                ``batch_data_samples[i]``.
            batch_data_samples (List[:obj:`Det3DDataSample`]): The det3d data
                samples to update.

        Returns:
            List[:obj:`Det3DDataSample`]: The same list object that was
            passed in, with each sample mutated in place to carry
            ``pred_pts_seg`` (PointData) whose ``pts_semantic_mask`` is the
            corresponding entry of ``seg_labels_list``. No logits field is
            added.
        """
        # Index-based writeback: a longer prediction list than sample list
        # surfaces as an IndexError instead of being silently truncated.
        for i, seg_pred in enumerate(seg_labels_list):
            batch_data_samples[i].set_data(
                {'pred_pts_seg': PointData(**{'pts_semantic_mask': seg_pred})})
        return batch_data_samples