# Source: nvpanoptix-3d/nvpanoptix_3d/utils/sparse_tensor.py
# Last change: "Update model inference code and environment setup instructions (#4)"
# (commit f4a0919 by vpraveen-nv, verified)
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Sparse tensor utils."""
import torch
import MinkowskiEngine as Me
import torch.nn.functional as F
from typing import Optional, Tuple, Dict
def sparse_cat_union(a: Me.SparseTensor, b: Me.SparseTensor):
    """Channel-concatenate two sparse tensors and add them together.

    Both operands must share a coordinate manager and a tensor stride.

    The features of ``a`` are zero-padded on the right by ``b``'s channel
    count. The features of ``b`` are zero-padded on the left by ``a``'s
    channel count. Both padded feature matrices are re-wrapped on their
    original coordinate map keys and summed with ``+``. That addition is what
    merges the two coordinate sets (hence "union"; this relies on
    MinkowskiEngine's SparseTensor addition semantics).

    Short-circuits:
        * If ``a`` has no features, ``b`` is returned unchanged, and vice
          versa. The result then has that operand's channel count, not the
          padded sum.
        * If padding either feature matrix raises, a warning is printed and
          the corresponding input is returned unchanged.

    Raises:
        AssertionError: if the coordinate managers or tensor strides differ.
    """
    manager = a.coordinate_manager
    shared_stride = a.tensor_stride
    assert manager == b.coordinate_manager, "different coords_man"
    assert a.tensor_stride == b.tensor_stride, "different tensor_stride"

    def _is_empty(operand):
        return operand.F.size(0) == 0 or operand.F.numel() == 0

    # Nothing to concatenate with: hand back the other operand as-is.
    if _is_empty(a):
        return b
    if _is_empty(b):
        return a

    # a -> [a | 0...0]
    try:
        padded_a = F.pad(a.F, (0, b.F.shape[1]))
    except Exception as err:
        print("Warning: Got error in feats_a:", err)
        return a
    # b -> [0...0 | b]
    try:
        padded_b = F.pad(b.F, (a.F.shape[1], 0))
    except Exception as err:
        print("Warning: Got error in feats_b:", err)
        return b

    def _rewrap(features, source):
        # Bind the padded features to the source's existing coordinate map key.
        return Me.SparseTensor(
            features=features,
            coordinate_map_key=source.coordinate_key,
            coordinate_manager=manager,
            tensor_stride=shared_stride,
        )

    return _rewrap(padded_a, a) + _rewrap(padded_b, b)
def to_dense(
    tensor: Me.SparseTensor,
    shape: Optional[torch.Size] = None,
    min_coordinate: Optional[torch.IntTensor] = None,
    contract_stride: bool = True,
    default_value: float = 0.0
) -> Tuple[torch.Tensor, torch.IntTensor, torch.IntTensor]:
    """Convert the :attr:`MinkowskiEngine.SparseTensor` to a torch dense tensor.

    Args:
        :attr:`tensor` (MinkowskiEngine.SparseTensor): the tensor to densify.
        :attr:`shape` (torch.Size, optional): The size of the output tensor,
            ``[Batch, Channel, Spatial...]``. If its channel entry differs from
            the feature dimension of ``tensor``, it is replaced by it.
            Required when ``tensor`` is empty.
        :attr:`min_coordinate` (torch.IntTensor or int, optional): The min
            coordinates of the output tensor. Any int32 tensor (CPU or CUDA)
            with ``D`` elements is accepted; it must be divisible by the
            current :attr:`tensor_stride`. If the integer ``0`` is given, the
            origin is used. If ``None``, the minimum spatial coordinate is
            computed and returned. The coordinates are NOT shifted by it, so
            the dense grid still starts at the origin. A negative minimum
            raises ``ValueError``.
        :attr:`contract_stride` (bool, optional): The output coordinates
            will be divided by the tensor stride to make features spatially
            contiguous. True by default.
        :attr:`default_value` (float, optional): fill value for every dense
            cell that receives no feature, including the whole output when
            ``tensor`` is empty.

    Returns:
        :attr:`tensor` (torch.Tensor): the torch tensor with size `[Batch
        Dim, Feature Dim, Spatial Dim..., Spatial Dim]`.
        :attr:`min_coordinate` (torch.IntTensor): the minimum coordinate of
        the output. For a non-empty input it has shape ``(1, D)`` when given
        as a 1-D tensor, ``0`` or ``None``. For an empty input it is a zero
        vector of shape ``(D,)``.
        :attr:`tensor_stride` (torch.IntTensor): the D-dimensional vector
        defining the stride between tensor elements.

    Raises:
        AssertionError: on a malformed ``shape`` / ``min_coordinate``, on a
            ``min_coordinate`` not divisible by the stride, or when an empty
            tensor is densified without ``shape``.
        ValueError: if ``min_coordinate`` is None and a coordinate is negative.
    """
    # The integer 0 is a documented shortcut for "use the origin". Detect it
    # before the tensor validation below, which would otherwise reject it.
    use_origin = isinstance(min_coordinate, int) and min_coordinate == 0
    if min_coordinate is not None and not use_origin:
        assert (
            isinstance(min_coordinate, torch.Tensor)
            and min_coordinate.dtype == torch.int32
        )
        assert min_coordinate.numel() == tensor._D
    if shape is not None:
        assert isinstance(shape, torch.Size)
        assert len(shape) == tensor._D + 2  # batch and channel
        if shape[1] != tensor._F.size(1):
            # The channel count always follows the features.
            shape = torch.Size([shape[0], tensor._F.size(1), *shape[2:]])

    # Empty tensor: nothing to scatter. The output must still be filled with
    # default_value, exactly like untouched cells in the non-empty path.
    if len(tensor) == 0:
        assert shape is not None, "shape is required to densify an empty tensor"
        return (
            torch.full(shape, default_value, dtype=tensor.dtype, device=tensor.device),
            torch.zeros(tensor._D, dtype=torch.int32, device=tensor.device),
            torch.IntTensor(tensor.tensor_stride),
        )

    # Use int tensors for all stride arithmetic.
    tensor_stride = torch.IntTensor(tensor.tensor_stride).to(tensor.device)

    # Column 0 of C is used as the batch index, the rest as spatial coords.
    batch_indices = tensor.C[:, 0]
    if min_coordinate is None:
        min_coordinate, _ = tensor.C.min(0, keepdim=True)
        min_coordinate = min_coordinate[:, 1:]
        if not torch.all(min_coordinate >= 0):
            raise ValueError(
                f"Coordinate has a negative value: {min_coordinate}. Please provide min_coordinate argument"
            )
        # Coordinates are not shifted by the computed minimum; the dense grid
        # stays anchored at the origin.
        coords = tensor.C[:, 1:]
    elif use_origin:
        # Report the origin as a tensor so the return type matches other paths.
        min_coordinate = torch.zeros(
            (1, tensor._D), dtype=torch.int32, device=tensor.device
        )
        coords = tensor.C[:, 1:]
    else:
        min_coordinate = min_coordinate.to(tensor.device)
        if min_coordinate.ndim == 1:
            min_coordinate = min_coordinate.unsqueeze(0)
        coords = tensor.C[:, 1:] - min_coordinate
    assert (
        min_coordinate % tensor_stride
    ).sum() == 0, "The minimum coordinates must be divisible by the tensor stride."
    if coords.ndim == 1:
        coords = coords.unsqueeze(1)

    # Return the contracted tensor.
    if contract_stride:
        coords = torch.div(coords, tensor_stride, rounding_mode="floor")

    nchannels = tensor.F.size(1)
    if shape is None:
        size = coords.max(0)[0] + 1
        shape = torch.Size(
            [batch_indices.max() + 1, nchannels, *size.cpu().numpy()]
        )

    dense_F = torch.full(
        shape, dtype=tensor.F.dtype,
        device=tensor.F.device, fill_value=default_value
    )
    tcoords = coords.t().long()
    batch_indices = batch_indices.long()
    # Advanced indexing: one (batch, :, x, y, z, ...) target per feature row.
    indices = (batch_indices, slice(None), *tcoords)
    dense_F[indices] = tensor.F
    return dense_F, min_coordinate, torch.IntTensor(tensor.tensor_stride)
def _thicken_grid(grid, grid_dims, frustum_mask):
"""Thicken grid."""
device = frustum_mask.device
offsets = torch.nonzero(torch.ones(3, 3, 3, device=device)).long()
locs_grid = grid.nonzero(as_tuple=False)
locs = locs_grid.unsqueeze(1).repeat(1, 27, 1)
locs += offsets
locs = locs.view(-1, 3)
masks = ((locs >= 0) & (locs < torch.as_tensor(grid_dims, device=device))).all(-1)
locs = locs[masks]
thicken = torch.zeros(grid_dims, dtype=torch.bool, device=device)
thicken[locs[:, 0], locs[:, 1], locs[:, 2]] = True
# frustum culling
thicken = thicken & frustum_mask
return thicken
def prepare_instance_masks_thicken(
    instances: torch.Tensor,
    semantic_mapping: Dict[int, int],
    distance_field: torch.Tensor,
    frustum_mask: torch.Tensor,
    iso_value: float = 1.0,
    truncation: float = 3.0,
    downsample_factor: int = 1
) -> Dict[int, Tuple[torch.Tensor, int]]:
    """Build a thickened, frustum-culled occupancy mask for each instance.

    For every ``instance_id -> semantic_class`` entry of ``semantic_mapping``:
        1. The distance field is kept where ``instances == instance_id`` and
           set to ``truncation`` everywhere else.
        2. It is thresholded with ``abs(value) < iso_value``.
        3. If ``downsample_factor != 1``, it is max-pooled down to a
           ``256 // downsample_factor`` grid.
        4. It is dilated and AND-ed with ``frustum_mask`` by ``_thicken_grid``.

    Args:
        instances: instance-id volume; compared element-wise to each id.
            Presumably 256^3 (the grid size is hard-coded) — TODO confirm.
        semantic_mapping: instance id -> semantic class.
        distance_field: distance values; squeezed, then indexed with the
            instance mask, so it must squeeze to the shape of ``instances``.
        frustum_mask: frustum mask at the 256^3 resolution. It is resized
            with nearest-neighbour interpolation when downsampling.
        iso_value: threshold on the absolute distance.
        truncation: value assigned outside the instance.
        downsample_factor: integer divisor of 256 for the output resolution.

    Returns:
        Dict mapping each instance id to ``(mask, semantic_class)``, where
        ``mask`` is a CPU tensor of the (possibly downsampled) grid size.

    Raises:
        AssertionError: if ``downsample_factor`` is not an int dividing 256.
    """
    # The base grid is fixed at 256^3, so the factor must divide 256 exactly.
    assert isinstance(downsample_factor, int) and 256 % downsample_factor == 0
    grid_dims = [256, 256, 256]
    need_rescale = downsample_factor != 1
    if need_rescale:
        grid_dims = [dim // downsample_factor for dim in grid_dims]
        frustum_mask = F.interpolate(
            frustum_mask[None, None].float(), size=grid_dims, mode="nearest"
        ).squeeze(0, 1).bool()

    # Squeeze once instead of once per instance.
    dense_distance = distance_field.squeeze()

    instance_information = {}
    for instance_id, semantic_class in semantic_mapping.items():
        instance_mask: torch.Tensor = (instances == instance_id)
        instance_distance_field = torch.full_like(
            instance_mask,
            dtype=torch.float,
            fill_value=truncation
        )
        instance_distance_field[instance_mask] = dense_distance[instance_mask]
        instance_distance_field_masked = instance_distance_field.abs() < iso_value
        if need_rescale:
            # Max-pooling a 0/1 volume keeps a coarse cell occupied if any
            # fine cell in its window is occupied.
            instance_distance_field_masked = F.max_pool3d(
                instance_distance_field_masked[None, None].float(),
                kernel_size=downsample_factor + 1,
                stride=downsample_factor,
                padding=1
            ).squeeze(0, 1).bool()
        instance_grid = _thicken_grid(
            instance_distance_field_masked,
            grid_dims,
            frustum_mask
        )
        instance_grid = instance_grid.to(torch.device("cpu"), non_blocking=True)
        instance_information[instance_id] = instance_grid, semantic_class

    # Device-to-host copies issued with non_blocking=True may still be in
    # flight when the loop ends. Synchronise once so callers never read
    # partially-copied masks. _thicken_grid builds on frustum_mask's device.
    if instance_information and frustum_mask.is_cuda:
        torch.cuda.synchronize(frustum_mask.device)
    return instance_information
def mask_invalid_sparse_voxels(
    grid: Me.SparseTensor,
    mask=None, frustum_dim=(256, 256, 256)
) -> Me.SparseTensor:
    """Prune voxels of ``grid`` that fall outside the frustum volume.

    A voxel is kept when both conditions hold:
        * each of its spatial coordinates ``c`` (columns 1-3 of ``grid.C``)
          satisfies ``0 <= c < dim - 1`` for the matching entry of
          ``frustum_dim``;
        * if ``mask`` is given, the corresponding entry of ``mask`` is nonzero.

    Args:
        grid: sparse tensor to filter.
        mask: optional per-row mask aligned with the rows of ``grid.C``.
            It is converted to bool before being combined.
        frustum_dim: extent of the grid along each of the three spatial axes.

    Returns:
        One of:
            * ``grid`` itself when every voxel is valid;
            * a pruned sparse tensor when some voxels are invalid;
            * the tuple ``({}, {})`` when no voxel survives. Callers must
              handle this sentinel.
    """
    coords = grid.C
    # NOTE(review): the upper bound ``dim - 1`` also rejects index ``dim - 1``
    # itself; kept as in the original — confirm that this is intended.
    valid_mask = (coords[:, 1] < frustum_dim[0] - 1) & (coords[:, 1] >= 0) & \
        (coords[:, 2] < frustum_dim[1] - 1) & (coords[:, 2] >= 0) & \
        (coords[:, 3] < frustum_dim[2] - 1) & (coords[:, 3] >= 0)
    if mask is not None:
        # MinkowskiPruning expects a boolean mask. Multiplying by a non-bool
        # mask would instead promote valid_mask to that mask's dtype.
        valid_mask = valid_mask & mask.bool()

    num_valid_coordinates = int(valid_mask.sum())
    if num_valid_coordinates == 0:
        # Sentinel for "nothing left", kept for existing callers.
        return {}, {}

    # Only prune if at least one voxel was rejected.
    if num_valid_coordinates < coords.size(0):
        grid = Me.MinkowskiPruning()(grid, valid_mask)
    return grid