|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Sparse tensor utils.""" |
|
|
|
|
|
import torch |
|
|
import MinkowskiEngine as Me |
|
|
import torch.nn.functional as F |
|
|
from typing import Optional, Tuple, Dict |
|
|
|
|
|
|
|
|
def sparse_cat_union(a: Me.SparseTensor, b: Me.SparseTensor):
    """Channel-concatenate two sparse tensors over the union of their coordinates.

    Both operands must live on the same coordinate manager and share the same
    tensor stride (checked with ``assert``). The features of ``a`` are
    right-padded with ``b.F.shape[1]`` zero columns, and the features of ``b``
    are left-padded with ``a.F.shape[1]`` zero columns. Each padded feature
    matrix is re-wrapped on its operand's original coordinate key, and the two
    results are combined with ``+``.

    NOTE(review): this presumably relies on MinkowskiEngine's ``+`` producing
    the union of the two coordinate sets (summing features where both exist).
    Because the padded column ranges are disjoint, that sum equals a channel
    concatenation. Confirm against the MinkowskiEngine version in use.

    NOTE(review): the early returns (one operand has no features) and the
    padding-error fallbacks return a single operand unchanged. That tensor has
    only that operand's channel count, not ``C_a + C_b``.

    Args:
        a: left operand; its channels come first in the result.
        b: right operand; its channels come last in the result.

    Returns:
        The merged sparse tensor, or one of the inputs unchanged on the
        short-circuit / error paths described above.
    """
    manager = a.coordinate_manager
    stride = a.tensor_stride
    assert manager == b.coordinate_manager, "different coords_man"
    assert a.tensor_stride == b.tensor_stride, "different tensor_stride"

    def _has_no_features(sparse):
        # Empty when there are no rows, or no elements at all.
        feats = sparse.F
        return feats.size(0) == 0 or feats.numel() == 0

    # Nothing to merge with: hand back the other operand untouched.
    if _has_no_features(a):
        return b
    if _has_no_features(b):
        return a

    # Pad a's features on the right by b's width. The width lookup stays
    # inside the try so that any failure takes the best-effort fallback.
    try:
        left_feats = F.pad(a.F, (0, b.F.shape[1]))
    except Exception as exc:
        print("Warning: Got error in feats_a:", exc)
        return a

    # Pad b's features on the left by a's width.
    try:
        right_feats = F.pad(b.F, (a.F.shape[1], 0))
    except Exception as exc:
        print("Warning: Got error in feats_b:", exc)
        return b

    def _rewrap(features, source):
        # Same coordinates as `source`, new (padded) features.
        return Me.SparseTensor(
            features=features,
            coordinate_map_key=source.coordinate_key,
            coordinate_manager=manager,
            tensor_stride=stride,
        )

    return _rewrap(left_feats, a) + _rewrap(right_feats, b)
|
|
|
|
|
|
|
|
def to_dense(
    tensor: Me.SparseTensor,
    shape: Optional[torch.Size] = None,
    min_coordinate: Optional[torch.IntTensor] = None,
    contract_stride: bool = True,
    default_value: float = 0.0
) -> Tuple[torch.Tensor, torch.IntTensor, torch.IntTensor]:
    """Convert the :attr:`MinkowskiEngine.SparseTensor` to a torch dense tensor.

    Args:
        :attr:`tensor` (MinkowskiEngine.SparseTensor): the sparse tensor to
        densify. Column 0 of ``tensor.C`` is used as the batch index and the
        remaining columns as spatial coordinates.

        :attr:`shape` (torch.Size, optional): the size of the output tensor,
        ``[Batch, Channels, Spatial...]``. If ``shape[1]`` differs from the
        feature size of ``tensor`` it is replaced by the feature size. The
        shape is inferred from the data when omitted. It is required when
        ``tensor`` is empty.

        :attr:`min_coordinate` (torch.IntTensor or int, optional): the min
        coordinates of the output tensor, subtracted from every coordinate.
        Must be divisible by the current :attr:`tensor_stride`. If ``0`` is
        given, the origin is used as the min coordinate. If ``None``, the
        per-axis minimum of the coordinates is computed and returned, but the
        coordinates are NOT shifted by it. In that case dense index ``i``
        corresponds to coordinate ``tensor_stride * i`` (or ``i`` when
        ``contract_stride`` is False). Negative coordinates raise
        :class:`ValueError`.

        :attr:`contract_stride` (bool, optional): the output coordinates
        will be divided (floor) by the tensor stride to make features
        spatially contiguous. True by default.

        :attr:`default_value` (float, optional): fill value for dense cells
        that receive no feature. 0.0 by default.

    Returns:
        :attr:`tensor` (torch.Tensor): the torch tensor with size `[Batch
        Dim, Feature Dim, Spatial Dim..., Spatial Dim]`. When
        :attr:`min_coordinate` was given explicitly, the coordinate of each
        feature is `min_coordinate + tensor_stride * [the coordinate of the
        dense tensor]` (with ``contract_stride=True``).

        :attr:`min_coordinate` (torch.IntTensor): the D-dimensional vector
        defining the minimum coordinate of the output tensor.

        :attr:`tensor_stride` (torch.IntTensor): the D-dimensional vector
        defining the stride between tensor elements (on the CPU).

    Raises:
        ValueError: if ``min_coordinate`` is None and a coordinate is negative.
        AssertionError: on malformed ``shape`` / ``min_coordinate`` arguments,
            a missing ``shape`` for an empty tensor, or a ``min_coordinate``
            not divisible by the tensor stride.
    """
    # The documented ``min_coordinate=0`` shorthand (the origin) used to be
    # rejected by the IntTensor assertion below, so it could never be used.
    # Normalize it to an explicit zero vector; subtracting it is a no-op.
    if (
        isinstance(min_coordinate, int)
        and not isinstance(min_coordinate, bool)
        and min_coordinate == 0
    ):
        min_coordinate = torch.zeros(tensor._D, dtype=torch.int32)

    if min_coordinate is not None:
        assert isinstance(min_coordinate, torch.IntTensor)
        assert min_coordinate.numel() == tensor._D
    if shape is not None:
        assert isinstance(shape, torch.Size)
        assert len(shape) == tensor._D + 2
        if shape[1] != tensor._F.size(1):
            # Force the channel dimension to match the actual feature size.
            shape = torch.Size([shape[0], tensor._F.size(1), *shape[2:]])

    # Returned stride lives on the CPU in both the empty and non-empty paths.
    cpu_stride = torch.IntTensor(tensor.tensor_stride)

    if len(tensor) == 0:
        assert shape is not None, "shape is required to densify an empty tensor"
        return (
            # Honor default_value here too, consistent with the non-empty path.
            torch.full(shape, fill_value=default_value,
                       dtype=tensor.dtype, device=tensor.device),
            torch.zeros(tensor._D, dtype=torch.int32, device=tensor.device),
            cpu_stride,
        )

    stride = cpu_stride.to(tensor.device)
    batch_indices = tensor.C[:, 0]

    if min_coordinate is None:
        min_coordinate, _ = tensor.C.min(0, keepdim=True)
        min_coordinate = min_coordinate[:, 1:]
        if not torch.all(min_coordinate >= 0):
            raise ValueError(
                f"Coordinate has a negative value: {min_coordinate}. Please provide min_coordinate argument"
            )
        # NOTE: coordinates are deliberately not shifted by the computed
        # minimum, so dense indices stay anchored at the origin.
        coords = tensor.C[:, 1:]
    else:
        min_coordinate = min_coordinate.to(tensor.device)
        if min_coordinate.ndim == 1:
            min_coordinate = min_coordinate.unsqueeze(0)
        coords = tensor.C[:, 1:] - min_coordinate

    assert (
        min_coordinate % stride
    ).sum() == 0, "The minimum coordinates must be divisible by the tensor stride."

    if contract_stride:
        coords = torch.div(coords, stride, rounding_mode="floor")

    nchannels = tensor.F.size(1)
    if shape is None:
        # Tight bounding box starting at index 0 on every spatial axis.
        size = coords.max(0)[0] + 1
        shape = torch.Size(
            [int(batch_indices.max()) + 1, nchannels, *size.tolist()]
        )

    dense_F = torch.full(
        shape, dtype=tensor.F.dtype,
        device=tensor.F.device, fill_value=default_value
    )

    # Advanced index (batch, all channels, x, y, z, ...): one row per voxel,
    # so each voxel's feature vector is scattered into its dense cell.
    indices = (batch_indices.long(), slice(None), *coords.t().long())
    dense_F[indices] = tensor.F

    return dense_F, min_coordinate, cpu_stride
|
|
|
|
|
|
|
|
def _thicken_grid(grid, grid_dims, frustum_mask):
    """Dilate the occupied voxels of a 3-D grid and clip the result to a frustum.

    Every nonzero voxel of ``grid`` marks the 27 cells obtained by adding an
    offset in ``{0, 1, 2}`` along each axis (itself included). Cells outside
    ``[0, grid_dims)`` are discarded. The resulting boolean volume is then
    AND-ed with ``frustum_mask``.

    NOTE(review): the offsets come from ``nonzero()`` of an all-ones 3x3x3 cube,
    so they are all non-negative. The dilation therefore only grows toward
    higher indices, and the ``>= 0`` bound check can never fail. A symmetric
    ``{-1, 0, 1}`` neighbourhood may have been intended. Confirm before
    changing, since that would shift every mask by one voxel.

    Args:
        grid: 3-D tensor whose nonzero entries are the seed voxels.
        grid_dims: spatial size (three ints) of the volume being built.
        frustum_mask: combined with the dilated volume via ``&``. Its device
            decides where intermediate and output tensors are allocated.

    Returns:
        ``dilated & frustum_mask``, where ``dilated`` is a bool tensor of shape
        ``grid_dims`` (boolean if ``frustum_mask`` is boolean — presumably the
        case for callers; verify).
    """
    device = frustum_mask.device

    # (27, 3) table of per-axis offsets, each entry in {0, 1, 2}.
    neighbourhood = torch.ones(3, 3, 3, device=device).nonzero().long()
    # (N, 3) indices of occupied voxels.
    seeds = grid.nonzero(as_tuple=False)

    # Broadcast every seed against every offset, then flatten to (N * 27, 3).
    candidates = (seeds.unsqueeze(1) + neighbourhood).reshape(-1, 3)
    upper = torch.as_tensor(grid_dims, device=device)
    inside = ((candidates >= 0) & (candidates < upper)).all(-1)
    candidates = candidates[inside]

    dilated = torch.zeros(grid_dims, dtype=torch.bool, device=device)
    dilated[candidates.unbind(dim=1)] = True

    # Frustum culling.
    return dilated & frustum_mask
|
|
|
|
|
|
|
|
def prepare_instance_masks_thicken(
    instances: torch.Tensor,
    semantic_mapping: Dict[int, int],
    distance_field: torch.Tensor,
    frustum_mask: torch.Tensor,
    iso_value: float = 1.0,
    truncation: float = 3.0,
    downsample_factor: int = 1
) -> Dict[int, Tuple[torch.Tensor, int]]:
    """Build a thickened, frustum-culled occupancy mask for every instance.

    For each ``instance_id -> semantic_class`` pair in ``semantic_mapping``:

    1. Voxels of ``instances`` equal to ``instance_id`` take their value from
       ``distance_field.squeeze()``. All other voxels take ``truncation``.
    2. Voxels whose absolute value is below ``iso_value`` form the instance
       surface.
    3. When ``downsample_factor != 1``, the surface is reduced to a
       ``256 // downsample_factor`` grid with a 3-D max-pool. ``frustum_mask``
       is reduced with nearest-neighbour interpolation.
    4. The surface is dilated by ``_thicken_grid`` and intersected with
       ``frustum_mask``.

    Assumes ``instances``, the squeezed ``distance_field`` and ``frustum_mask``
    are 256^3 volumes on the same device. TODO confirm with callers.

    Args:
        instances: per-voxel instance ids.
        semantic_mapping: instance id -> semantic class; only these ids are
            processed.
        distance_field: distance values; ``squeeze()`` must yield the same
            shape as ``instances``.
        frustum_mask: voxels allowed to appear in the output masks.
        iso_value: surface threshold on the absolute distance.
        truncation: value given to voxels outside the instance.
        downsample_factor: integer divisor of 256 for the output resolution.

    Returns:
        Dict mapping each instance id to ``(mask, semantic_class)``, where
        ``mask`` is a CPU tensor of spatial size ``[256 // downsample_factor] * 3``.

    Raises:
        AssertionError: if ``downsample_factor`` is not an int dividing 256.
    """
    assert isinstance(downsample_factor, int) and 256 % downsample_factor == 0

    grid_dims = [256, 256, 256]
    need_rescale = downsample_factor != 1
    if need_rescale:
        grid_dims = (torch.as_tensor(grid_dims) // downsample_factor).tolist()
        frustum_mask = F.interpolate(frustum_mask[None, None].float(),
                                     size=grid_dims, mode="nearest").squeeze(0, 1).bool()

    # Loop-invariant: squeeze the distance field once instead of per instance.
    distance_values = distance_field.squeeze()

    instance_information = {}

    for instance_id, semantic_class in semantic_mapping.items():
        instance_mask: torch.Tensor = (instances == instance_id)
        instance_distance_field = torch.full_like(
            instance_mask,
            dtype=torch.float,
            fill_value=truncation
        )
        instance_distance_field[instance_mask] = distance_values[instance_mask]
        instance_surface = instance_distance_field.abs() < iso_value

        if need_rescale:
            # Max-pool so any surface voxel inside a pooling window survives
            # the downsampling.
            instance_surface = F.max_pool3d(
                instance_surface[None, None].float(),
                kernel_size=downsample_factor + 1,
                stride=downsample_factor,
                padding=1
            ).squeeze(0, 1).bool()

        instance_grid = _thicken_grid(instance_surface, grid_dims, frustum_mask)
        # Queue the device-to-host copy asynchronously so copies can overlap
        # with the work for subsequent instances; synchronized below.
        instance_information[instance_id] = (
            instance_grid.to(torch.device("cpu"), non_blocking=True),
            semantic_class,
        )

    # A non_blocking GPU->CPU copy can return before the data has landed on
    # the host. _thicken_grid allocates on frustum_mask's device, so one
    # synchronize on that device guarantees every returned mask is complete.
    if frustum_mask.is_cuda:
        torch.cuda.synchronize(frustum_mask.device)

    return instance_information
|
|
|
|
|
|
|
|
def mask_invalid_sparse_voxels(
    grid: Me.SparseTensor,
    mask=None, frustum_dim=(256, 256, 256)
) -> Me.SparseTensor:
    """Prune voxels of ``grid`` whose spatial coordinates fall outside the frustum.

    A voxel is kept when, for each spatial axis ``i`` in 0..2 (coordinate
    column ``i + 1`` of ``grid.C``), ``0 <= coord < frustum_dim[i] - 1``. If
    ``mask`` is given, its entry for the voxel must also be truthy. Column 0
    of ``grid.C`` is not checked.

    NOTE(review): the upper bound also rejects index ``frustum_dim[i] - 1``,
    leaving a one-voxel margin at the far faces. Confirm that this is
    intentional.

    Args:
        grid: sparse tensor to filter.
        mask: optional per-voxel mask, combined with the bounds test by
            multiplication.
        frustum_dim: per-axis grid size; any indexable of three ints. The
            default is now a tuple (it used to be a shared mutable list).

    Returns:
        The pruned sparse tensor, or ``grid`` itself when every voxel is valid.
        NOTE(review): when no voxel is valid, the tuple ``({}, {})`` is
        returned instead of a SparseTensor, despite the annotation. Callers
        must handle that sentinel.
    """
    coords = grid.C
    valid_mask = (coords[:, 1] < frustum_dim[0] - 1) & (coords[:, 1] >= 0) & \
                 (coords[:, 2] < frustum_dim[1] - 1) & (coords[:, 2] >= 0) & \
                 (coords[:, 3] < frustum_dim[2] - 1) & (coords[:, 3] >= 0)
    if mask is not None:
        valid_mask = valid_mask * mask
    num_valid_coordinates = valid_mask.sum()

    if num_valid_coordinates == 0:
        return {}, {}

    num_masked_voxels = coords.size(0) - num_valid_coordinates
    grids_needs_to_be_pruned = num_masked_voxels > 0

    # Only pay for pruning when at least one voxel is actually rejected.
    if grids_needs_to_be_pruned:
        grid = Me.MinkowskiPruning()(grid, valid_mask)

    return grid
|
|
|