SparseBev / models /csrc /wrapper.py
chinmaygarde's picture
Attempt an export to ONNX.
3ea6165 unverified
import torch
import torch.nn.functional as F
try:
from ._msmv_sampling_cuda import _ms_deform_attn_cuda_c2345_forward, _ms_deform_attn_cuda_c2345_backward
from ._msmv_sampling_cuda import _ms_deform_attn_cuda_c23456_forward, _ms_deform_attn_cuda_c23456_backward
MSMV_CUDA = True
except ImportError as e:
print('Warning: failed to load one or more CUDA extensions, performance may be hurt.')
print('Error message:', e)
MSMV_CUDA = False
def msmv_sampling_pytorch(mlvl_feats, sampling_locations, scale_weights):
"""
value: [B, N, H1W1 + H2W2..., C]
sampling_locations: [B, Q, P, 3]
scale_weights: [B, Q, P, 4]
"""
assert scale_weights.shape[-1] == len(mlvl_feats)
B, C, _, _, _ = mlvl_feats[0].shape
_, Q, P, _ = sampling_locations.shape
sampling_locations = sampling_locations * 2 - 1
sampling_locations = sampling_locations[:, :, :, None, :] # [B, Q, P, 1, 3]
final = torch.zeros([B, C, Q, P], device=mlvl_feats[0].device)
for lvl, feat in enumerate(mlvl_feats):
out = F.grid_sample(
feat, sampling_locations, mode='bilinear',
padding_mode='zeros', align_corners=True,
)[..., 0] # [B, C, Q, P]
out = out * scale_weights[..., lvl].reshape(B, 1, Q, P)
final += out
return final.permute(0, 2, 1, 3)
class MSMVSamplingC2345(torch.autograd.Function):
@staticmethod
def forward(ctx, feat_c2, feat_c3, feat_c4, feat_c5, sampling_locations, scale_weights):
ctx.save_for_backward(feat_c2, feat_c3, feat_c4, feat_c5, sampling_locations, scale_weights)
assert callable(_ms_deform_attn_cuda_c2345_forward)
return _ms_deform_attn_cuda_c2345_forward(
feat_c2, feat_c3, feat_c4, feat_c5,
sampling_locations, scale_weights)
@staticmethod
def backward(ctx, grad_output):
feat_c2, feat_c3, feat_c4, feat_c5, sampling_locations, scale_weights = ctx.saved_tensors
assert callable(_ms_deform_attn_cuda_c2345_backward)
grad_value_c2, grad_value_c3, grad_value_c4, grad_value_c5, grad_sampling_loc, grad_attn_weight = _ms_deform_attn_cuda_c2345_backward(grad_output.contiguous(),
feat_c2, feat_c3, feat_c4, feat_c5,
sampling_locations, scale_weights
)
return grad_value_c2, grad_value_c3, grad_value_c4, grad_value_c5, grad_sampling_loc, grad_attn_weight
class MSMVSamplingC23456(torch.autograd.Function):
@staticmethod
def forward(ctx, feat_c2, feat_c3, feat_c4, feat_c5, feat_c6, sampling_locations, scale_weights):
ctx.save_for_backward(feat_c2, feat_c3, feat_c4, feat_c5, feat_c6, sampling_locations, scale_weights)
assert callable(_ms_deform_attn_cuda_c23456_forward)
return _ms_deform_attn_cuda_c23456_forward(
feat_c2, feat_c3, feat_c4, feat_c5, feat_c6,
sampling_locations, scale_weights)
@staticmethod
def backward(ctx, grad_output):
feat_c2, feat_c3, feat_c4, feat_c5, feat_c6, sampling_locations, scale_weights = ctx.saved_tensors
assert callable(_ms_deform_attn_cuda_c23456_backward)
grad_value_c2, grad_value_c3, grad_value_c4, grad_value_c5, grad_value_c6, grad_sampling_loc, grad_attn_weight = _ms_deform_attn_cuda_c23456_backward(grad_output.contiguous(),
feat_c2, feat_c3, feat_c4, feat_c5, feat_c6,
sampling_locations, scale_weights
)
return grad_value_c2, grad_value_c3, grad_value_c4, grad_value_c5, grad_value_c6, grad_sampling_loc, grad_attn_weight
def msmv_sampling(mlvl_feats, sampling_locations, scale_weights):
if len(mlvl_feats) == 4 and MSMV_CUDA:
return MSMVSamplingC2345.apply(*mlvl_feats, sampling_locations, scale_weights)
elif len(mlvl_feats) == 5 and MSMV_CUDA:
return MSMVSamplingC23456.apply(*mlvl_feats, sampling_locations, scale_weights)
else:
return msmv_sampling_pytorch(mlvl_feats, sampling_locations, scale_weights)
def msmv_sampling_onnx(mlvl_feats, uv, view_idx, scale_weights):
"""
ONNX-compatible multi-scale multi-view sampling using 4D F.grid_sample.
Replaces the 5D volumetric grid_sample used in msmv_sampling_pytorch with
separate per-view 4D grid_samples followed by a torch.gather for view
selection. All ops are in ONNX opset 16.
Args:
mlvl_feats: list of [BTG, C, N, H, W] channel-first feature maps
uv: [BTG, Q, P, 2] normalised (u, v) in [0, 1]
view_idx: [BTG, Q, P] integer camera-view indices
scale_weights:[BTG, Q, P, L] softmax weights over pyramid levels
Returns:
[BTG, Q, C, P]
"""
BTG, C, N, _, _ = mlvl_feats[0].shape
_, Q, P, _ = uv.shape
# Convert UV from [0, 1] to [-1, 1] for F.grid_sample
uv_gs = uv * 2.0 - 1.0 # [BTG, Q, P, 2]
# Tile UV for all N views: [BTG*N, Q, P, 2]
# Use expand+contiguous+reshape (maps to ONNX Expand, better CoreML EP support
# than repeat_interleave which maps to ONNX Tile and can trip up CoreML)
uv_gs = uv_gs.unsqueeze(1).expand(BTG, N, Q, P, 2).contiguous().reshape(BTG * N, Q, P, 2)
# Pre-expand view_idx for gathering along the N dim: [BTG, C, 1, Q, P]
view_idx_g = view_idx[:, None, None, :, :].expand(BTG, C, 1, Q, P)
final = torch.zeros(BTG, C, Q, P, device=mlvl_feats[0].device, dtype=mlvl_feats[0].dtype)
for lvl, feat in enumerate(mlvl_feats):
_, _, _, H_lvl, W_lvl = feat.shape
# [BTG, C, N, H, W] -> [BTG, N, C, H, W] -> [BTG*N, C, H, W]
feat_4d = feat.permute(0, 2, 1, 3, 4).reshape(BTG * N, C, H_lvl, W_lvl)
# 4D grid_sample: [BTG*N, C, Q, P]
sampled = F.grid_sample(feat_4d, uv_gs, mode='bilinear', padding_mode='zeros', align_corners=True)
# [BTG*N, C, Q, P] -> [BTG, N, C, Q, P] -> [BTG, C, N, Q, P]
sampled = sampled.reshape(BTG, N, C, Q, P).permute(0, 2, 1, 3, 4)
# Gather the selected camera view: [BTG, C, 1, Q, P] -> [BTG, C, Q, P]
sampled = torch.gather(sampled, 2, view_idx_g).squeeze(2)
# Accumulate with per-level scale weight
w = scale_weights[..., lvl].reshape(BTG, 1, Q, P)
final = final + sampled * w
return final.permute(0, 2, 1, 3) # [BTG, Q, C, P]