import numpy as np
import torch
import torch.nn as nn
from segment_anything import SamPredictor, sam_model_registry

from .models.data_processor import DataProcessor
from .models.mean_vfe import MeanVFE
from .models.spconv_backbone_voxelnext import VoxelResBackBone8xVoxelNeXt
from .models.voxelnext_head import VoxelNeXtHead
from .utils.image_projection import _proj_voxel_image


class VoxelNeXt(nn.Module):
    """Bundle of the point-cloud preprocessing, voxel encoder, 3D backbone and head.

    This class defines no ``forward``; ``Model`` drives the sub-modules
    (``data_processor``, ``voxelization``, ``backbone_3d``, ``dense_head``)
    one by one.

    Args:
        model_cfg: Config object supporting both attribute access and
            ``.get(key, default)``. ``POINT_CLOUD_RANGE``, ``DATA_PROCESSOR``
            and ``USED_FEATURE_LIST`` are read as attributes (so they must be
            present); the remaining keys are read with ``.get``.
    """

    def __init__(self, model_cfg):
        super().__init__()

        point_cloud_range = np.array(model_cfg.POINT_CLOUD_RANGE, dtype=np.float32)
        self.data_processor = DataProcessor(
            model_cfg.DATA_PROCESSOR,
            point_cloud_range=point_cloud_range,
            training=False,
            num_point_features=len(model_cfg.USED_FEATURE_LIST),
        )

        input_channels = model_cfg.get('INPUT_CHANNELS', 5)
        grid_size = np.array(model_cfg.get('GRID_SIZE', [1440, 1440, 40]))
        class_names = model_cfg.get('CLASS_NAMES')
        kernel_size_head = model_cfg.get('KERNEL_SIZE_HEAD', 1)

        # Stored as plain tensor attributes (not registered buffers), so they
        # do not follow ``.to(device)``; users move them explicitly.
        self.point_cloud_range = torch.Tensor(
            model_cfg.get('POINT_CLOUD_RANGE', [-51.2, -51.2, -5.0, 51.2, 51.2, 3.0])
        )
        self.voxel_size = torch.Tensor(model_cfg.get('VOXEL_SIZE', [0.075, 0.075, 0.2]))

        class_names_each_head = model_cfg.get('CLASS_NAMES_EACH_HEAD')
        separate_head_cfg = model_cfg.get('SEPARATE_HEAD_CFG')
        post_processing = model_cfg.get('POST_PROCESSING')

        self.voxelization = MeanVFE()
        self.backbone_3d = VoxelResBackBone8xVoxelNeXt(input_channels, grid_size)
        self.dense_head = VoxelNeXtHead(
            class_names,
            self.point_cloud_range,
            self.voxel_size,
            kernel_size_head,
            class_names_each_head,
            separate_head_cfg,
            post_processing,
        )


class Model(nn.Module):
    """Combine a SAM predictor with VoxelNeXt to turn a 2D click into a 3D box.

    Args:
        model_cfg: Config object; ``SAM_TYPE``, ``SAM_CHECKPOINT`` and
            ``VOXELNEXT_CHECKPOINT`` are read with ``.get`` and fall back to
            the defaults below. It is also forwarded to ``VoxelNeXt``.
        device: Device both networks are moved to.
    """

    def __init__(self, model_cfg, device="cuda"):
        super().__init__()

        sam_type = model_cfg.get('SAM_TYPE', "vit_b")
        sam_checkpoint = model_cfg.get('SAM_CHECKPOINT', "/data/sam_vit_b_01ec64.pth")
        sam = sam_model_registry[sam_type](checkpoint=sam_checkpoint).to(device=device)
        self.sam_predictor = SamPredictor(sam)

        voxelnext_checkpoint = model_cfg.get(
            'VOXELNEXT_CHECKPOINT', "/data/voxelnext_nuscenes_kernel1.pth"
        )
        # map_location keeps loading independent of the device the checkpoint
        # was saved from (e.g. a CUDA checkpoint on a CPU-only machine).
        model_dict = torch.load(voxelnext_checkpoint, map_location=device)
        # NOTE(review): the modules are never switched to ``.eval()`` here, so
        # they stay in nn.Module's default training mode - confirm intended.
        self.voxelnext = VoxelNeXt(model_cfg).to(device=device)
        self.voxelnext.load_state_dict(model_dict)

        # Backbone outputs cached per image_id (never evicted).
        self.point_features = {}
        self.device = device

    def image_embedding(self, image):
        """Hand ``image`` to the SAM predictor via ``set_image``."""
        self.sam_predictor.set_image(image)

    def _encode_points(self, data_dict):
        """Preprocess, voxelize and run the 3D backbone on one point cloud."""
        data_dict = self.voxelnext.data_processor.forward(data_dict=data_dict)
        for key in ('voxels', 'voxel_num_points', 'voxel_coords'):
            data_dict[key] = torch.Tensor(data_dict[key]).to(self.device)
        data_dict = self.voxelnext.voxelization(data_dict)

        # Prepend an all-zero column to the voxel coordinates (single sample,
        # see batch_size below).
        coords = data_dict['voxel_coords']
        batch_column = torch.zeros(
            (coords.shape[0], 1), device=coords.device, dtype=coords.dtype
        )
        data_dict['voxel_coords'] = torch.cat([batch_column, coords], dim=1)
        data_dict['batch_size'] = 1

        return self.voxelnext.backbone_3d(data_dict)

    def point_embedding(self, data_dict, image_id):
        """Run VoxelNeXt on a point cloud, reusing cached backbone features.

        Args:
            data_dict: Raw input dict for the data processor. Only consumed
                when ``image_id`` has not been seen before.
            image_id: Hashable cache key for the backbone output.

        Returns:
            ``(pred_dicts, voxel_coords)``: the dense head output and the
            ``out_voxels`` rows selected by ``pred_dicts[0]['voxel_ids']``,
            multiplied by the head's ``feature_map_stride``.
        """
        # Check the cache first so a hit skips preprocessing and voxelization
        # as well as the backbone.
        if image_id in self.point_features:
            data_dict = self.point_features[image_id]
        else:
            data_dict = self._encode_points(data_dict)
            self.point_features[image_id] = data_dict

        pred_dicts = self.voxelnext.dense_head(data_dict)
        voxel_ids = pred_dicts[0]['voxel_ids'].squeeze(-1)
        stride = self.voxelnext.dense_head.feature_map_stride
        voxel_coords = data_dict['out_voxels'][voxel_ids] * stride
        return pred_dicts, voxel_coords

    def generate_3D_box(self, lidar2img_rt, mask, voxel_coords, pred_dicts, quality_score=0.1):
        """Pick the predicted 3D box that best matches a 2D mask.

        Voxels are projected into the image. Among boxes whose score exceeds
        ``quality_score``, the highest-scoring one whose voxel lands on a mask
        pixel is returned; if none lands on the mask, the one whose projection
        is closest to the mask centroid is returned instead.

        Args:
            lidar2img_rt: Projection passed through to ``_proj_voxel_image``.
            mask: 2D mask indexed ``[row, col]`` (array or tensor).
            voxel_coords: Voxel coordinates, one per prediction.
            pred_dicts: Head output; ``pred_dicts[0]`` provides
                ``pred_scores`` and ``pred_boxes``.
            quality_score: Minimum score for a box to be considered.

        Returns:
            The selected box, or ``None`` when no box passes the score
            threshold or the fallback is needed but the mask is empty.
        """
        scores = pred_dicts[0]['pred_scores']
        boxes = pred_dicts[0]['pred_boxes']

        mask_extra = scores > quality_score
        if not mask_extra.any():
            print("no high quality 3D box related.")
            return None

        device = voxel_coords.device
        points_image, _ = _proj_voxel_image(
            voxel_coords,
            lidar2img_rt,
            self.voxelnext.voxel_size.to(device),
            self.voxelnext.point_cloud_range.to(device),
        )

        if torch.is_tensor(mask):
            mask = mask.detach().cpu().numpy()
        mask = np.asarray(mask, dtype=bool)

        # Integer pixel coordinates, one (col, row) pair per voxel.
        pixels = points_image.permute(1, 0).int().cpu().numpy()
        cols, rows = pixels[:, 0], pixels[:, 1]
        inside = (cols >= 0) & (rows >= 0) & (cols < mask.shape[1]) & (rows < mask.shape[0])
        on_mask = np.zeros(pixels.shape[0], dtype=bool)
        on_mask[inside] = mask[rows[inside], cols[inside]]

        selected = torch.from_numpy(on_mask).to(mask_extra.device) & mask_extra
        if selected.any():
            best = scores[selected].argmax()
            return boxes[selected][best]

        # Fallback: nearest projected voxel to the mask centroid.
        mask_rows, mask_cols = torch.nonzero(torch.as_tensor(mask), as_tuple=True)
        if mask_rows.numel() == 0:
            # The centroid of an empty mask is undefined (NaN).
            print("empty mask, no 3D box related.")
            return None
        mask_center = (
            torch.stack([mask_cols.float().mean(), mask_rows.float().mean()])
            .to(boxes.device)
            .unsqueeze(1)
        )
        dist = ((points_image - mask_center) ** 2).sum(0)
        nearest = dist[mask_extra].argmin()
        return boxes[mask_extra][nearest]

    def forward(self, image, point_dict, prompt_point, lidar2img_rt, image_id, quality_score=0.1):
        """Segment the prompted object in 2D and look up its 3D box.

        Args:
            image: Image given to the SAM predictor.
            point_dict: Raw point-cloud dict for ``point_embedding``.
            prompt_point: Prompt coordinates passed to SAM as ``point_coords``
                with a single positive label.
            lidar2img_rt: Projection used by ``generate_3D_box``.
            image_id: Cache key for the point-cloud features.
            quality_score: Minimum box score, see ``generate_3D_box``.

        Returns:
            ``(mask, box3d)``: the first SAM mask and the matched box
            (``None`` if no box was found).
        """
        self.image_embedding(image)
        pred_dicts, voxel_coords = self.point_embedding(point_dict, image_id)

        masks, scores, _ = self.sam_predictor.predict(
            point_coords=prompt_point, point_labels=np.array([1])
        )
        mask = masks[0]

        box3d = self.generate_3D_box(
            lidar2img_rt, mask, voxel_coords, pred_dicts, quality_score=quality_score
        )
        return mask, box3d


if __name__ == '__main__':
    # NOTE(review): cfg_from_yaml_file, cfg and NuScenesDataset are not
    # imported in the visible part of this file; this smoke test needs them
    # brought into scope before it can run.
    cfg_dataset = 'nuscenes_dataset.yaml'
    cfg_model = 'config.yaml'

    dataset_cfg = cfg_from_yaml_file(cfg_dataset, cfg)
    model_cfg = cfg_from_yaml_file(cfg_model, cfg)

    nuscenes_dataset = NuScenesDataset(dataset_cfg)
    model = Model(model_cfg)

    index = 0
    data_dict = nuscenes_dataset._get_points(index)
    # point_embedding requires a cache key; use the sample index.
    model.point_embedding(data_dict, image_id=index)