EscherNet / dust3r /heads /postprocess.py
kxhit
update
5f093a6
raw
history blame
1.62 kB
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# Post-processing functions for all heads: extract 3D points/confidence from the raw head output
# --------------------------------------------------------
import torch
def postprocess(out, depth_mode, conf_mode):
    """
    Convert a raw prediction-head output into 3D points and, optionally, confidence.

    Args:
        out: channels-first head output. It is presumably (B, C, H, W), since it
            is permuted with (0, 2, 3, 1) below; confirm against the head.
            Channels 0-2 are regressed into 3D points. Channel 3 is regressed
            into confidence when ``conf_mode`` is given.
        depth_mode: ``(name, vmin, vmax)`` triple forwarded to ``reg_dense_depth``.
        conf_mode: ``(name, vmin, vmax)`` triple forwarded to ``reg_dense_conf``,
            or ``None`` to skip the confidence output.

    Returns:
        dict with key ``'pts3d'`` and, when ``conf_mode`` is not None, ``'conf'``.
    """
    # Move channels to the last axis (B,H,W,C). C must be >= 4 when
    # confidence is requested, because channel 3 is read below.
    channels_last = out.permute(0, 2, 3, 1)

    result = {'pts3d': reg_dense_depth(channels_last[..., :3], mode=depth_mode)}
    if conf_mode is not None:
        result['conf'] = reg_dense_conf(channels_last[..., 3], mode=conf_mode)
    return result
def reg_dense_depth(xyz, mode):
    """
    Regress 3D points from raw head output according to ``mode``.

    Args:
        xyz: tensor whose last dimension holds point coordinates (the norm is
            taken over dim=-1). ``postprocess`` passes the first three
            channels of a channels-last map; other callers are unverified.
        mode: ``(name, vmin, vmax)`` triple.
            ``name`` is one of ``'linear'``, ``'square'`` or ``'exp'``.
            ``vmin``/``vmax`` must currently be ``-inf``/``+inf``.

    Returns:
        A tensor shaped like ``xyz``.
        - ``'linear'`` returns ``xyz`` itself, unchanged.
        - ``'square'`` keeps each vector's direction and replaces its norm
          ``d`` with ``d**2``.
        - ``'exp'`` keeps each vector's direction and replaces its norm
          ``d`` with ``exp(d) - 1``.

    Raises:
        AssertionError: if finite bounds are supplied (bounds are unsupported).
        ValueError: if ``name`` is not a recognised mode.
    """
    name, vmin, vmax = mode

    unbounded = vmin == -float('inf') and vmax == float('inf')
    # Bounded regression is not supported. While this assertion is active,
    # the clipping fallback in the linear branch below is unreachable.
    assert unbounded

    if name == 'linear':
        return xyz if unbounded else xyz.clip(min=vmin, max=vmax)

    # Split each vector into its length and a unit direction. The 1e-8 floor
    # avoids division by zero for zero-length vectors.
    dist = xyz.norm(dim=-1, keepdim=True)
    direction = xyz / dist.clip(min=1e-8)

    if name == 'square':
        scale = dist.square()
    elif name == 'exp':
        scale = torch.expm1(dist)
    else:
        raise ValueError(f'bad mode={name!r}')
    return direction * scale
def reg_dense_conf(x, mode):
    """
    Map raw head output to a confidence value according to ``mode``.

    Args:
        x: tensor of raw confidence logits.
        mode: ``(name, vmin, vmax)`` triple.
            ``name`` is ``'exp'`` or ``'sigmoid'``.

    Returns:
        - ``'exp'``: ``vmin + min(exp(x), vmax - vmin)``, i.e. values in
          ``(vmin, vmax]``.
        - ``'sigmoid'``: ``vmin + (vmax - vmin) * sigmoid(x)``, i.e. values
          in ``(vmin, vmax)``.

    Raises:
        ValueError: if ``name`` is not a recognised mode.
    """
    name, vmin, vmax = mode

    if name == 'sigmoid':
        return (vmax - vmin) * torch.sigmoid(x) + vmin
    if name == 'exp':
        # Cap the exponential so the result never exceeds vmax.
        capped = x.exp().clip(max=vmax - vmin)
        return vmin + capped
    raise ValueError(f'bad mode={name!r}')