File size: 5,615 Bytes
34d1f8b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 |
from typing import Dict, List, Optional
import torch
from torch import Tensor
from mmdet3d.models import Base3DDetector
from mmdet3d.registry import MODELS
from mmdet3d.structures import Det3DDataSample
@MODELS.register_module()
class DSVT(Base3DDetector):
"""DSVT detector."""
def __init__(self,
voxel_encoder: Optional[dict] = None,
middle_encoder: Optional[dict] = None,
backbone: Optional[dict] = None,
neck: Optional[dict] = None,
map2bev: Optional[dict] = None,
bbox_head: Optional[dict] = None,
train_cfg: Optional[dict] = None,
test_cfg: Optional[dict] = None,
init_cfg: Optional[dict] = None,
data_preprocessor: Optional[dict] = None,
**kwargs):
super(DSVT, self).__init__(
init_cfg=init_cfg, data_preprocessor=data_preprocessor, **kwargs)
if voxel_encoder:
self.voxel_encoder = MODELS.build(voxel_encoder)
if middle_encoder:
self.middle_encoder = MODELS.build(middle_encoder)
if backbone:
self.backbone = MODELS.build(backbone)
self.map2bev = MODELS.build(map2bev)
if neck is not None:
self.neck = MODELS.build(neck)
if bbox_head:
bbox_head.update(train_cfg=train_cfg, test_cfg=test_cfg)
self.bbox_head = MODELS.build(bbox_head)
self.train_cfg = train_cfg
self.test_cfg = test_cfg
@property
def with_bbox(self):
"""bool: Whether the detector has a 3D box head."""
return hasattr(self, 'bbox_head') and self.bbox_head is not None
@property
def with_backbone(self):
"""bool: Whether the detector has a 3D backbone."""
return hasattr(self, 'backbone') and self.backbone is not None
@property
def with_voxel_encoder(self):
"""bool: Whether the detector has a voxel encoder."""
return hasattr(self,
'voxel_encoder') and self.voxel_encoder is not None
@property
def with_middle_encoder(self):
"""bool: Whether the detector has a middle encoder."""
return hasattr(self,
'middle_encoder') and self.middle_encoder is not None
def _forward(self):
pass
def extract_feat(self, batch_inputs_dict: dict) -> tuple:
"""Extract features from images and points.
Args:
batch_inputs_dict (dict): Dict of batch inputs. It
contains
- points (List[tensor]): Point cloud of multiple inputs.
- imgs (tensor): Image tensor with shape (B, C, H, W).
Returns:
tuple: Two elements in tuple arrange as
image features and point cloud features.
"""
batch_out_dict = self.voxel_encoder(batch_inputs_dict)
batch_out_dict = self.middle_encoder(batch_out_dict)
batch_out_dict = self.map2bev(batch_out_dict)
multi_feats = self.backbone(batch_out_dict['spatial_features'])
feats = self.neck(multi_feats)
return feats
def loss(self, batch_inputs_dict: Dict[List, torch.Tensor],
batch_data_samples: List[Det3DDataSample],
**kwargs) -> List[Det3DDataSample]:
"""
Args:
batch_inputs_dict (dict): The model input dict which include
'points' and `imgs` keys.
- points (list[torch.Tensor]): Point cloud of each sample.
- imgs (torch.Tensor): Tensor of batch images, has shape
(B, C, H ,W)
batch_data_samples (List[:obj:`Det3DDataSample`]): The Data
Samples. It usually includes information such as
`gt_instance_3d`, .
Returns:
dict[str, Tensor]: A dictionary of loss components.
"""
pts_feats = self.extract_feat(batch_inputs_dict)
losses = dict()
loss = self.bbox_head.loss(pts_feats, batch_data_samples)
losses.update(loss)
return losses
def predict(self, batch_inputs_dict: Dict[str, Optional[Tensor]],
batch_data_samples: List[Det3DDataSample],
**kwargs) -> List[Det3DDataSample]:
"""Forward of testing.
Args:
batch_inputs_dict (dict): The model input dict which include
'points' keys.
- points (list[torch.Tensor]): Point cloud of each sample.
batch_data_samples (List[:obj:`Det3DDataSample`]): The Data
Samples. It usually includes information such as
`gt_instance_3d`.
Returns:
list[:obj:`Det3DDataSample`]: Detection results of the
input sample. Each Det3DDataSample usually contain
'pred_instances_3d'. And the ``pred_instances_3d`` usually
contains following keys.
- scores_3d (Tensor): Classification scores, has a shape
(num_instances, )
- labels_3d (Tensor): Labels of bboxes, has a shape
(num_instances, ).
- bbox_3d (:obj:`BaseInstance3DBoxes`): Prediction of bboxes,
contains a tensor with shape (num_instances, 7).
"""
pts_feats = self.extract_feat(batch_inputs_dict)
results_list_3d = self.bbox_head.predict(pts_feats, batch_data_samples)
detsamples = self.add_pred_to_datasample(batch_data_samples,
results_list_3d)
return detsamples
|