| import torch |
| import torch.nn.functional as F |
|
|
# Optional compiled kernels: when the extension module cannot be imported the
# module still loads, and MSMV_CUDA records that only the PyTorch path exists.
try:
    from ._msmv_sampling_cuda import (
        _ms_deform_attn_cuda_c2345_forward,
        _ms_deform_attn_cuda_c2345_backward,
        _ms_deform_attn_cuda_c23456_forward,
        _ms_deform_attn_cuda_c23456_backward,
    )
except ImportError as e:
    print('Warning: failed to load one or more CUDA extensions, performance may be hurt.')
    print('Error message:', e)
    MSMV_CUDA = False
else:
    MSMV_CUDA = True
|
|
|
|
def msmv_sampling_pytorch(mlvl_feats, sampling_locations, scale_weights):
    """Pure-PyTorch multi-scale multi-view sampling built on 5D ``F.grid_sample``.

    Every pyramid level is sampled at the same 3D locations and the per-level
    samples are blended with ``scale_weights``.

    Args:
        mlvl_feats: sequence of 5D feature volumes, one per pyramid level,
            each laid out as [B, C, N, H, W].
        sampling_locations: [B, Q, P, 3] coordinates; they are rescaled with
            ``x * 2 - 1``, i.e. [0, 1] maps onto the [-1, 1] range used by
            ``F.grid_sample``.
        scale_weights: [B, Q, P, L] per-level weights with
            ``L == len(mlvl_feats)``.

    Returns:
        Tensor of shape [B, Q, C, P].

    Raises:
        AssertionError: if the last dimension of ``scale_weights`` differs
            from the number of feature levels.
    """
    assert scale_weights.shape[-1] == len(mlvl_feats)

    B, C, _, _, _ = mlvl_feats[0].shape
    _, Q, P, _ = sampling_locations.shape

    # Map to grid_sample's normalised range and add a singleton axis so the
    # grid is [B, Q, P, 1, 3], as required for a 5D input.
    grid = (sampling_locations * 2 - 1)[:, :, :, None, :]

    final = torch.zeros([B, C, Q, P], device=mlvl_feats[0].device)

    for lvl, feat in enumerate(mlvl_feats):
        # grid_sample returns [B, C, Q, P, 1]; drop the singleton axis.
        out = F.grid_sample(
            feat, grid, mode='bilinear',
            padding_mode='zeros', align_corners=True,
        )[..., 0]
        out = out * scale_weights[..., lvl].reshape(B, 1, Q, P)
        # Out-of-place add: an in-place ``+=`` cannot promote the accumulator
        # and raises for inputs wider than the default dtype (e.g. float64).
        final = final + out

    # [B, C, Q, P] -> [B, Q, C, P]
    return final.permute(0, 2, 1, 3)
|
|
|
|
class MSMVSamplingC2345(torch.autograd.Function):
    """Autograd bridge to the compiled four-level (C2..C5) sampling kernels.

    Tensor layouts are whatever the extension expects; presumably the same
    as the PyTorch fallback in this module -- confirm against the kernel.
    """

    @staticmethod
    def forward(ctx, feat_c2, feat_c3, feat_c4, feat_c5, sampling_locations, scale_weights):
        """Save every input for backward and run the extension's forward kernel."""
        inputs = (feat_c2, feat_c3, feat_c4, feat_c5, sampling_locations, scale_weights)
        ctx.save_for_backward(*inputs)

        assert callable(_ms_deform_attn_cuda_c2345_forward)
        return _ms_deform_attn_cuda_c2345_forward(*inputs)

    @staticmethod
    def backward(ctx, grad_output):
        """Return one gradient per forward input, in forward-argument order."""
        saved = ctx.saved_tensors

        assert callable(_ms_deform_attn_cuda_c2345_backward)
        # The kernel is handed a contiguous copy of the incoming gradient.
        d_c2, d_c3, d_c4, d_c5, d_locations, d_weights = _ms_deform_attn_cuda_c2345_backward(
            grad_output.contiguous(), *saved
        )

        return d_c2, d_c3, d_c4, d_c5, d_locations, d_weights
|
|
|
|
class MSMVSamplingC23456(torch.autograd.Function):
    """Autograd bridge to the compiled five-level (C2..C6) sampling kernels.

    Tensor layouts are whatever the extension expects; presumably the same
    as the PyTorch fallback in this module -- confirm against the kernel.
    """

    @staticmethod
    def forward(ctx, feat_c2, feat_c3, feat_c4, feat_c5, feat_c6, sampling_locations, scale_weights):
        """Save every input for backward and run the extension's forward kernel."""
        inputs = (feat_c2, feat_c3, feat_c4, feat_c5, feat_c6, sampling_locations, scale_weights)
        ctx.save_for_backward(*inputs)

        assert callable(_ms_deform_attn_cuda_c23456_forward)
        return _ms_deform_attn_cuda_c23456_forward(*inputs)

    @staticmethod
    def backward(ctx, grad_output):
        """Return one gradient per forward input, in forward-argument order."""
        saved = ctx.saved_tensors

        assert callable(_ms_deform_attn_cuda_c23456_backward)
        # The kernel is handed a contiguous copy of the incoming gradient.
        d_c2, d_c3, d_c4, d_c5, d_c6, d_locations, d_weights = _ms_deform_attn_cuda_c23456_backward(
            grad_output.contiguous(), *saved
        )

        return d_c2, d_c3, d_c4, d_c5, d_c6, d_locations, d_weights
|
|
|
|
def msmv_sampling(mlvl_feats, sampling_locations, scale_weights):
    """Dispatch to a compiled kernel when one fits, else to the PyTorch path.

    The compiled kernels are used only when the extension was imported
    (``MSMV_CUDA``) and exactly four or five feature levels are given; every
    other case goes through ``msmv_sampling_pytorch``.
    """
    num_levels = len(mlvl_feats)

    if MSMV_CUDA and num_levels == 4:
        return MSMVSamplingC2345.apply(*mlvl_feats, sampling_locations, scale_weights)

    if MSMV_CUDA and num_levels == 5:
        return MSMVSamplingC23456.apply(*mlvl_feats, sampling_locations, scale_weights)

    return msmv_sampling_pytorch(mlvl_feats, sampling_locations, scale_weights)
|
|
|
|
def msmv_sampling_onnx(mlvl_feats, uv, view_idx, scale_weights):
    """
    Multi-scale multi-view sampling expressed with 4D F.grid_sample only.

    Instead of one 5D volumetric grid_sample per level, every camera view is
    sampled with a 4D grid_sample and the requested view is then picked with
    torch.gather, which keeps the graph exportable to ONNX (opset 16).

    Args:
        mlvl_feats:   list of [BTG, C, N, H, W] channel-first feature maps
        uv:           [BTG, Q, P, 2] normalised (u, v) in [0, 1]
        view_idx:     [BTG, Q, P] integer camera-view indices
        scale_weights:[BTG, Q, P, L] weights over pyramid levels
    Returns:
        [BTG, Q, C, P]
    """
    reference = mlvl_feats[0]
    BTG, C, N, _, _ = reference.shape
    _, Q, P, _ = uv.shape

    # [0, 1] -> [-1, 1], then repeat the same grid for each of the N views:
    # [BTG, Q, P, 2] -> [BTG * N, Q, P, 2].
    grid = uv * 2.0 - 1.0
    grid = grid[:, None].expand(BTG, N, Q, P, 2).contiguous().reshape(BTG * N, Q, P, 2)

    # Gather index broadcast over channels: [BTG, Q, P] -> [BTG, C, 1, Q, P].
    gather_index = view_idx.unsqueeze(1).unsqueeze(1).expand(BTG, C, 1, Q, P)

    accum = torch.zeros(BTG, C, Q, P, device=reference.device, dtype=reference.dtype)

    for level, volume in enumerate(mlvl_feats):
        _, _, _, height, width = volume.shape

        # Fold the view axis into the batch: [BTG, C, N, H, W] -> [BTG * N, C, H, W].
        per_view = volume.permute(0, 2, 1, 3, 4).reshape(BTG * N, C, height, width)

        # Sample every view at every location: -> [BTG * N, C, Q, P].
        per_view = F.grid_sample(
            per_view, grid, mode='bilinear', padding_mode='zeros', align_corners=True
        )

        # Restore the view axis next to the channels: -> [BTG, C, N, Q, P].
        per_view = per_view.reshape(BTG, N, C, Q, P).permute(0, 2, 1, 3, 4)

        # Keep only the view named by view_idx: -> [BTG, C, Q, P].
        chosen = torch.gather(per_view, 2, gather_index).squeeze(2)

        level_weight = scale_weights[..., level].reshape(BTG, 1, Q, P)
        accum = accum + chosen * level_weight

    # [BTG, C, Q, P] -> [BTG, Q, C, P]
    return accum.permute(0, 2, 1, 3)
|
|