|
|
|
from typing import Dict, List |
|
|
|
from torch import Tensor |
|
|
|
from mmdet3d.models import EncoderDecoder3D |
|
from mmdet3d.registry import MODELS |
|
from mmdet3d.structures import PointData |
|
from mmdet3d.structures.det3d_data_sample import OptSampleList, SampleList |
|
|
|
|
|
@MODELS.register_module()
class RangeImageSegmentor(EncoderDecoder3D):
    """Encoder-decoder segmentor driven by the ``'imgs'`` input tensor.

    Every entry point reads ``batch_inputs_dict['imgs']``, passes it through
    ``extract_feat`` and hands the resulting features to the decode head.
    No other key of the input dict is read by this class.
    """

    def loss(self, batch_inputs_dict: dict,
             batch_data_samples: SampleList) -> Dict[str, Tensor]:
        """Compute the training losses for one batch.

        Args:
            batch_inputs_dict (dict): Batched inputs. Only the ``'imgs'``
                entry is read here.
            batch_data_samples (List[:obj:`Det3DDataSample`]): Data samples
                of the batch, forwarded unchanged to the head training
                helpers.

        Returns:
            Dict[str, Tensor]: Losses returned by the decode head, merged
            with the auxiliary head losses when
            ``self.with_auxiliary_head`` is true.
        """
        feats = self.extract_feat(batch_inputs_dict['imgs'])

        # Decode-head losses are inserted first; auxiliary losses are merged
        # afterwards, so on a key clash the auxiliary value would win.
        losses: Dict[str, Tensor] = {}
        losses.update(
            self._decode_head_forward_train(feats, batch_data_samples))

        if self.with_auxiliary_head:
            losses.update(
                self._auxiliary_head_forward_train(feats,
                                                   batch_data_samples))
        return losses

    def predict(self,
                batch_inputs_dict: dict,
                batch_data_samples: SampleList,
                rescale: bool = True) -> SampleList:
        """Run inference and attach the predictions to the data samples.

        Args:
            batch_inputs_dict (dict): Batched inputs. Only the ``'imgs'``
                entry is read here.
            batch_data_samples (List[:obj:`Det3DDataSample`]): Data samples
                of the batch. Their ``metainfo`` is passed to the decode
                head and the samples receive the predictions.
            rescale (bool): Not used by this implementation.
                Defaults to True.

        Returns:
            List[:obj:`Det3DDataSample`]: The result of
            :meth:`postprocess_result`, i.e. the given data samples with
            ``pred_pts_seg`` set on each of them.
        """
        # Meta information is collected before the features are computed,
        # one entry per data sample and in the same order.
        batch_input_metas = [
            data_sample.metainfo for data_sample in batch_data_samples
        ]

        feats = self.extract_feat(batch_inputs_dict['imgs'])
        seg_labels_list = self.decode_head.predict(feats, batch_input_metas,
                                                   self.test_cfg)

        return self.postprocess_result(seg_labels_list, batch_data_samples)

    def _forward(self,
                 batch_inputs_dict: dict,
                 batch_data_samples: OptSampleList = None) -> Tensor:
        """Run the backbone and decode head without any post-processing.

        Args:
            batch_inputs_dict (dict): Batched inputs. Only the ``'imgs'``
                entry is read here.
            batch_data_samples (List[:obj:`Det3DDataSample`], optional):
                Not used by this implementation. Defaults to None.

        Returns:
            Tensor: Whatever ``self.decode_head.forward`` returns for the
            extracted features.
        """
        feats = self.extract_feat(batch_inputs_dict['imgs'])
        return self.decode_head.forward(feats)

    def postprocess_result(self, seg_labels_list: List[Tensor],
                           batch_data_samples: SampleList) -> SampleList:
        """Store each prediction on the data sample at the same index.

        Args:
            seg_labels_list (List[Tensor]): Per-sample segmentation
                predictions as produced by the decode head.
            batch_data_samples (List[:obj:`Det3DDataSample`]): Data samples
                to update. The list is modified in place.

        Returns:
            List[:obj:`Det3DDataSample`]: The same ``batch_data_samples``
            list, where each sample with a matching prediction now carries:

            - ``pred_pts_seg`` (PointData): Holds the prediction under
              ``pts_semantic_mask``.
        """
        # Index-based access is intentional: a prediction list longer than
        # the sample list raises IndexError instead of being truncated.
        for idx, labels in enumerate(seg_labels_list):
            prediction = PointData(pts_semantic_mask=labels)
            batch_data_samples[idx].set_data({'pred_pts_seg': prediction})
        return batch_data_samples
|
|