File size: 5,462 Bytes
34d1f8b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 |
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List
from torch import Tensor
from mmdet3d.models import EncoderDecoder3D
from mmdet3d.registry import MODELS
from mmdet3d.structures import PointData
from mmdet3d.structures.det3d_data_sample import OptSampleList, SampleList
@MODELS.register_module()
class RangeImageSegmentor(EncoderDecoder3D):
    """3D semantic segmentor that runs on range images.

    All forward paths of this class read the tensor stored under the
    ``'imgs'`` key of the input dict and pass it to ``extract_feat``; the
    ``'points'`` key is not read by any method here.
    """

    def loss(self, batch_inputs_dict: dict,
             batch_data_samples: SampleList) -> Dict[str, Tensor]:
        """Calculate losses from a batch of inputs and data samples.

        Args:
            batch_inputs_dict (dict): Input sample dict. Only the ``'imgs'``
                key is read, and it is required (indexing raises
                ``KeyError`` if it is missing).

                - imgs (Tensor): Image tensor has shape (B, C, H, W).
            batch_data_samples (List[:obj:`Det3DDataSample`]): The det3d data
                samples. It usually includes information such as `metainfo`
                and `gt_pts_seg`.

        Returns:
            Dict[str, Tensor]: A dictionary of loss components from the
            decode head and, when ``with_auxiliary_head`` is true, the
            auxiliary head.
        """
        # extract features using backbone
        x = self.extract_feat(batch_inputs_dict['imgs'])

        losses = dict()
        losses.update(self._decode_head_forward_train(x, batch_data_samples))

        if self.with_auxiliary_head:
            losses.update(
                self._auxiliary_head_forward_train(x, batch_data_samples))
        return losses

    def predict(self,
                batch_inputs_dict: dict,
                batch_data_samples: SampleList,
                rescale: bool = True) -> SampleList:
        """Predict semantic labels for a batch of range images.

        Args:
            batch_inputs_dict (dict): Input sample dict. Only the ``'imgs'``
                key is read, and it is required.

                - imgs (Tensor): Image tensor has shape (B, C, H, W).
            batch_data_samples (List[:obj:`Det3DDataSample`]): The det3d data
                samples. The ``metainfo`` of each sample is forwarded to the
                decode head.
            rescale (bool): Not used by this segmentor; accepted only to keep
                the ``EncoderDecoder3D.predict`` signature. Defaults to True.

        Returns:
            List[:obj:`Det3DDataSample`]: The input data samples, each
            updated in place with ``pred_pts_seg`` (PointData) whose
            ``pts_semantic_mask`` holds the labels returned by the decode
            head for that sample.
        """
        # The whole batch goes through the network in a single pass; the
        # decode head receives the per-sample metainfo alongside features.
        batch_input_metas = [
            data_sample.metainfo for data_sample in batch_data_samples
        ]
        x = self.extract_feat(batch_inputs_dict['imgs'])
        seg_labels_list = self.decode_head.predict(x, batch_input_metas,
                                                   self.test_cfg)
        return self.postprocess_result(seg_labels_list, batch_data_samples)

    def _forward(self,
                 batch_inputs_dict: dict,
                 batch_data_samples: OptSampleList = None) -> Tensor:
        """Network forward process without post-processing.

        Args:
            batch_inputs_dict (dict): Input sample dict. Only the ``'imgs'``
                key is read, and it is required.

                - imgs (Tensor): Image tensor has shape (B, C, H, W).
            batch_data_samples (List[:obj:`Det3DDataSample`], optional):
                Not used by this method; accepted for interface
                compatibility. Defaults to None.

        Returns:
            Tensor: Raw output of ``decode_head.forward``.
        """
        x = self.extract_feat(batch_inputs_dict['imgs'])
        return self.decode_head.forward(x)

    def postprocess_result(self, seg_labels_list: List[Tensor],
                           batch_data_samples: SampleList) -> SampleList:
        """Attach per-sample semantic labels to the data samples.

        Args:
            seg_labels_list (List[Tensor]): Predicted semantic labels, one
                entry per data sample, in the same order as
                ``batch_data_samples``.
            batch_data_samples (List[:obj:`Det3DDataSample`]): The det3d data
                samples. It usually includes information such as `metainfo`
                and `gt_pts_seg`.

        Returns:
            List[:obj:`Det3DDataSample`]: ``batch_data_samples`` itself, with
            each sample updated in place to contain ``pred_pts_seg``
            (PointData) whose ``pts_semantic_mask`` is the matching entry of
            ``seg_labels_list``.

        Raises:
            ValueError: If the number of predictions differs from the number
                of data samples.
        """
        # A mismatch would otherwise either leave samples without
        # predictions or fail with an unhelpful IndexError.
        if len(seg_labels_list) != len(batch_data_samples):
            raise ValueError(
                f'Got {len(seg_labels_list)} segmentation results for '
                f'{len(batch_data_samples)} data samples.')
        for data_sample, seg_pred in zip(batch_data_samples,
                                         seg_labels_list):
            data_sample.set_data(
                {'pred_pts_seg': PointData(pts_semantic_mask=seg_pred)})
        return batch_data_samples
|