"""Coordinate transform utils.""" |
|
|
|
|
|
import torch |
|
|
import MinkowskiEngine as Me |
|
|
from typing import List |
|
|
|
|
|
from ..reconstruction.frustum import \ |
|
|
generate_frustum, compute_camera2frustum_transform |
|
|
|
|
|
|
|
|
def transform_feat3d_coordinates(
    feat3d,
    intrinsic,
    image_size=(120, 160),
    depth_min=0.4,
    depth_max=6.0,
    voxel_size=0.03,
):
    """Transform feat3d coordinates to match the Uni3D coordinate system.

    Args:
        feat3d: Me.SparseTensor from occupancy-aware lifting.
        intrinsic: Camera intrinsic matrix (4x4), optionally batched (Bx4x4).
        image_size: Tuple of (height, width).
        depth_min: Minimum frustum depth in meters.
        depth_max: Maximum frustum depth in meters.
        voxel_size: Voxel size in meters.

    Returns:
        Me.SparseTensor with transformed coordinates.
    """
    device = feat3d.device
    coords = feat3d.C.clone()

    # Flip the x and y axes within the 256^3 frustum grid (column 0 is the batch index).
    coords[:, 1:3] = 256 - coords[:, 1:3]
    batch_indices = coords[:, 0].unique()

    # If a single intrinsic is given (or all batched intrinsics are identical),
    # the frustum transform only needs to be computed once.
    compute_once = True
    if intrinsic.dim() == 3:
        if len(batch_indices) > 1:
            compute_once = torch.allclose(intrinsic[0:1].expand_as(intrinsic), intrinsic, atol=1e-6)
        intrinsic_ref = intrinsic[0] if compute_once else None
    else:
        intrinsic_ref = intrinsic

    if compute_once:
        intrinsic_batch = intrinsic_ref
        intrinsic_inverse = torch.inverse(intrinsic_batch)
        frustum = generate_frustum(image_size, intrinsic_inverse, depth_min, depth_max)
        camera2frustum, padding_offsets = compute_camera2frustum_transform(
            frustum.to(device), voxel_size,
            frustum_dimensions=torch.tensor([256, 256, 256], device=device)
        )
        camera2frustum = camera2frustum.to(device)
        padding_offsets = padding_offsets.to(device)
        camera2frustum_inv = torch.inverse(camera2frustum).float()
        ones_offset = torch.tensor([1., 1., 1.], device=device)

    transformed_coords_list = []

    for batch_idx in batch_indices:
        batch_mask = coords[:, 0] == batch_idx
        batch_coords = coords[batch_mask, 1:].float()

        # Per-sample intrinsics: recompute the frustum transform for this batch element.
        if not compute_once:
            intrinsic_batch = intrinsic[int(batch_idx)]
            intrinsic_inverse = torch.inverse(intrinsic_batch)
            frustum = generate_frustum(image_size, intrinsic_inverse, depth_min, depth_max)
            camera2frustum, padding_offsets = compute_camera2frustum_transform(
                frustum.to(device), voxel_size,
                frustum_dimensions=torch.tensor([256, 256, 256], device=device)
            )
            camera2frustum = camera2frustum.float().to(device)
            padding_offsets = padding_offsets.to(device)
            camera2frustum_inv = torch.inverse(camera2frustum).float()
            ones_offset = torch.tensor([1., 1., 1.], device=device)

        # Undo the frustum padding and the one-voxel border before applying the transform.
        batch_coords_adjusted = batch_coords - padding_offsets - ones_offset

        # Homogeneous coordinates (N, 4) for the 4x4 matrix products below.
        homogenous_coords = torch.cat([
            batch_coords_adjusted,
            torch.ones(batch_coords_adjusted.shape[0], 1, device=device)
        ], dim=1)

        # Map frustum-grid coordinates into camera space and back into the frustum grid.
        world_coords = torch.mm(camera2frustum_inv, homogenous_coords.t())
        final_coords_homog = torch.mm(camera2frustum.float(), world_coords.float())
        final_coords = final_coords_homog.t()[:, :3]

        # Re-apply the padding offsets.
        final_coords = final_coords + padding_offsets

        # Prepend the batch-index column expected by MinkowskiEngine.
        batch_column = torch.full(
            (final_coords.shape[0], 1),
            batch_idx,
            device=device,
            dtype=torch.float32
        )
        final_batch_coords = torch.cat([batch_column, final_coords], dim=1)
        transformed_coords_list.append(final_batch_coords)

    transformed_coords = torch.cat(transformed_coords_list, dim=0)

    transformed_feat3d = Me.SparseTensor(
        features=feat3d.F,
        coordinates=transformed_coords.int(),
        tensor_stride=feat3d.tensor_stride,
        quantization_mode=feat3d.quantization_mode
    )

    return transformed_feat3d
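# Usage sketch for transform_feat3d_coordinates (illustrative only; `feat3d` is assumed
# to be the lifted Me.SparseTensor on the 256^3 frustum grid and `intrinsic` a 4x4
# camera matrix at the resolution expected by generate_frustum):
#
#   aligned = transform_feat3d_coordinates(feat3d, intrinsic, image_size=(120, 160))
#   # aligned.F is feat3d.F unchanged; only the coordinates are re-indexed.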
|
|
def fuse_sparse_tensors(tensor1: Me.SparseTensor, tensor2: Me.SparseTensor) -> Me.SparseTensor:
    """Fuse two sparse tensors by concatenating their features over the union of their coordinates.

    Args:
        tensor1 (Me.SparseTensor): First sparse tensor.
        tensor2 (Me.SparseTensor): Second sparse tensor.

    Returns:
        Me.SparseTensor: Fused sparse tensor with concatenated features.
    """
    device = tensor1.device
    dtype = tensor1.F.dtype

    coords1, feats1 = tensor1.C, tensor1.F
    coords2, feats2 = tensor2.C, tensor2.F

    feat_dim1, feat_dim2 = feats1.shape[1], feats2.shape[1]
    fused_feat_dim = feat_dim1 + feat_dim2

    # Stack all coordinates and find the unique set, keeping each original row's
    # index into that set.
    all_coords = torch.cat([coords1, coords2], dim=0)
    n_coords1 = coords1.shape[0]

    unique_coords, inverse_indices = torch.unique(all_coords, dim=0, return_inverse=True)
    n_unique = unique_coords.shape[0]

    # Split the inverse mapping back into the two inputs.
    inv_indices_1 = inverse_indices[:n_coords1]
    inv_indices_2 = inverse_indices[n_coords1:]

    # Scatter each input's features into its slice of the fused feature matrix;
    # coordinates present in only one input keep zeros in the other slice.
    fused_features = torch.zeros(n_unique, fused_feat_dim, device=device, dtype=dtype)
    fused_features[inv_indices_1, :feat_dim1] = feats1
    fused_features[inv_indices_2, feat_dim1:] = feats2

    fused_tensor = Me.SparseTensor(
        features=fused_features,
        coordinates=unique_coords.int(),
        tensor_stride=tensor1.tensor_stride,
        quantization_mode=tensor1.quantization_mode
    )
    return fused_tensor
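# Usage sketch for fuse_sparse_tensors (illustrative; toy shapes on CPU, the two
# inputs are assumed to share a tensor stride and live on the same 3D grid):
#
#   coords = torch.tensor([[0, 2, 4, 6], [0, 8, 10, 12]], dtype=torch.int32)
#   a = Me.SparseTensor(torch.rand(2, 32), coordinates=coords)
#   b = Me.SparseTensor(torch.rand(2, 16), coordinates=coords.clone())
#   fused = fuse_sparse_tensors(a, b)
#   # fused.F.shape == (2, 48): columns [0, 32) come from `a`, [32, 48) from `b`.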
|
|
def generate_multiscale_feat3d(transformed_feat3d: Me.SparseTensor) -> List[Me.SparseTensor]:
    """Generate multi-scale sparse 3D features from transformed_feat3d to match
    the sparse_multi_scale_features structure.

    Args:
        transformed_feat3d (Me.SparseTensor): Input sparse tensor from
            occupancy-aware lifting (256^3 grid).

    Returns:
        List[Me.SparseTensor]: Multi-scale sparse tensors at scales
            [1/2, 1/4, 1/8], i.e. grid sizes [128, 64, 32].
    """
    device = transformed_feat3d.device

    # Strided max pooling halves the resolution at every application.
    pooling_op = Me.MinkowskiMaxPooling(
        kernel_size=3,
        stride=2,
        dimension=3
    ).to(device)

    multi_scale_feat3d = []
    current_tensor = transformed_feat3d
    target_strides = [2, 4, 8]

    for target_stride in target_strides:
        pooled_tensor = pooling_op(current_tensor)

        # Re-wrap the tensor if pooling did not produce the expected tensor stride
        # (tensor_stride is a per-axis list in MinkowskiEngine, hence the list comparison).
        if list(pooled_tensor.tensor_stride) != [target_stride] * 3:
            pooled_tensor = Me.SparseTensor(
                features=pooled_tensor.F,
                coordinates=pooled_tensor.C,
                tensor_stride=target_stride,
                quantization_mode=pooled_tensor.quantization_mode
            )

        multi_scale_feat3d.append(pooled_tensor)
        current_tensor = pooled_tensor

    return multi_scale_feat3d
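

# Minimal smoke test for the multi-scale pyramid (a sketch, not part of the training
# pipeline). It assumes MinkowskiEngine runs on CPU and that the module is executed
# in its package context (e.g. `python -m <package>.coordinate_transforms`, path name
# hypothetical), since the frustum import above is relative. All data is synthetic.
if __name__ == "__main__":
    torch.manual_seed(0)
    num_points, feat_dim = 512, 32

    # Batched MinkowskiEngine coordinates: [batch_index, x, y, z] on the 256^3 grid.
    demo_coords = torch.cat([
        torch.zeros(num_points, 1, dtype=torch.int32),
        torch.randint(0, 256, (num_points, 3), dtype=torch.int32),
    ], dim=1)
    demo_tensor = Me.SparseTensor(torch.rand(num_points, feat_dim), coordinates=demo_coords)

    for level, scale_tensor in enumerate(generate_multiscale_feat3d(demo_tensor), start=1):
        print(f"scale 1/{2 ** level}: stride={scale_tensor.tensor_stride}, "
              f"points={scale_tensor.C.shape[0]}")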