Spaces:
Running
on
L40S
Running
on
L40S
| # Copyright (C) 2024-present Naver Corporation. All rights reserved. | |
| # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). | |
| # | |
| # -------------------------------------------------------- | |
| # post process function for all heads: extract 3D points/confidence from output | |
| # -------------------------------------------------------- | |
| import torch | |
def postprocess(out, depth_mode, conf_mode):
    """
    Turn a raw prediction-head tensor into 3D points and, optionally, confidence.

    Args:
        out: 4-D head output. Its axis 1 is moved to the last position before
            slicing (presumably a B,C,H,W -> B,H,W,C layout change -- confirm
            against the calling head).
        depth_mode: mode triple forwarded unchanged to ``reg_dense_depth``.
        conf_mode: mode triple forwarded unchanged to ``reg_dense_conf``, or
            ``None`` to skip the confidence output.

    Returns:
        dict with key ``'pts3d'`` and, only when ``conf_mode`` is not ``None``,
        key ``'conf'``.
    """
    channels_last = out.permute(0, 2, 3, 1)

    # The first three trailing entries hold the raw point regression.
    raw_points = channels_last[..., :3]
    result = {'pts3d': reg_dense_depth(raw_points, mode=depth_mode)}

    if conf_mode is None:
        return result

    # The entry at index 3 holds the raw confidence activation.
    raw_conf = channels_last[..., 3]
    result['conf'] = reg_dense_conf(raw_conf, mode=conf_mode)
    return result
def reg_dense_depth(xyz, mode):
    """
    extract 3D points from prediction head output

    Args:
        xyz: raw 3-vector predictions; the last dimension is treated as the
            coordinate axis (the norm is taken over ``dim=-1``).
        mode: triple ``(name, vmin, vmax)`` with ``name`` one of
            ``'linear'``, ``'square'`` or ``'exp'``.

    Returns:
        * ``'linear'``: ``xyz`` itself when the bounds are ``(-inf, +inf)``,
          otherwise ``xyz`` clipped element-wise to ``[vmin, vmax]``.
        * ``'square'``: the unit direction of ``xyz`` scaled by ``d ** 2``.
        * ``'exp'``: the unit direction of ``xyz`` scaled by ``expm1(d)``.

        ``d`` is the Euclidean norm of ``xyz`` along the last dimension.

    Raises:
        ValueError: if ``name`` is not a known mode, or if finite bounds are
            requested for a non-linear mode (not supported).
    """
    mode, vmin, vmax = mode
    no_bounds = (vmin == -float('inf')) and (vmax == float('inf'))

    if mode == 'linear':
        if no_bounds:
            return xyz  # [-inf, +inf]
        return xyz.clip(min=vmin, max=vmax)

    # Reject unknown modes before doing any tensor work.
    if mode not in ('square', 'exp'):
        raise ValueError(f'bad {mode=}')

    # Explicit raise rather than `assert`: asserts are stripped under
    # `python -O`, which would let the bounds be silently ignored here.
    if not no_bounds:
        raise ValueError(
            f'bounds ({vmin}, {vmax}) are not supported for {mode=}')

    # distance to origin
    d = xyz.norm(dim=-1, keepdim=True)
    # Clamp the denominator so an all-zero vector does not divide by zero.
    xyz = xyz / d.clip(min=1e-8)

    if mode == 'square':
        return xyz * d.square()
    return xyz * torch.expm1(d)  # mode == 'exp'
def reg_dense_conf(x, mode):
    """
    Map raw confidence activations from the prediction head to confidences.

    Args:
        x: tensor of raw confidence activations.
        mode: triple ``(name, vmin, vmax)`` with ``name`` either ``'exp'`` or
            ``'sigmoid'``.

    Returns:
        * ``'exp'``: ``vmin + min(exp(x), vmax - vmin)``, i.e. values start
          above ``vmin`` and are capped at ``vmax``.
        * ``'sigmoid'``: ``sigmoid(x)`` rescaled to span ``vmin`` to ``vmax``.

    Raises:
        ValueError: if ``name`` is neither ``'exp'`` nor ``'sigmoid'``.
    """
    mode, vmin, vmax = mode

    if mode == 'sigmoid':
        span = vmax - vmin
        squashed = torch.sigmoid(x)
        return span * squashed + vmin

    if mode == 'exp':
        headroom = vmax - vmin
        # Cap the exponential so that vmin + result never exceeds vmax.
        capped = torch.exp(x).clamp(max=headroom)
        return vmin + capped

    raise ValueError(f'bad {mode=}')