slam3r-i2p_demo / slam3r /inference.py
siyan824's picture
init
8bd45de
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# utilities needed for the inference
# --------------------------------------------------------
import torch
import numpy as np
from .utils.misc import invalid_to_zeros
from .utils.geometry import geotrf, inv
def loss_of_one_batch(loss_func, batch, model, criterion, device,
                      use_amp=False, ret=None,
                      assist_model=None, train=False, epoch=0,
                      args=None):
    """Dispatch one batch to the loss routine selected by ``loss_func``.

    Args:
        loss_func: name of the routine to run; one of ``"i2p"``,
            ``"i2p_corr_score"`` or ``"l2w"``.
        batch, model, criterion, device, use_amp, ret: forwarded unchanged
            to the selected routine.
        assist_model, train, epoch: accepted for call-site compatibility;
            not used by this function.
        args: object whose ``ref_id`` attribute is read for ``"i2p"`` and
            ``"i2p_corr_score"``, and whose ``ref_ids`` attribute is read
            for ``"l2w"``. Must not be ``None`` for those modes.

    Returns:
        Whatever the selected routine returns.

    Raises:
        NotImplementedError: if ``loss_func`` is not a supported name.
    """
    if loss_func == "i2p":
        return loss_of_one_batch_multiview(batch, model, criterion,
                                           device, use_amp, ret,
                                           args.ref_id)
    if loss_func == "i2p_corr_score":
        return loss_of_one_batch_multiview_corr_score(batch, model, criterion,
                                                      device, use_amp, ret,
                                                      args.ref_id)
    if loss_func == "l2w":
        # The l2w path always uses view 0 as the coordinate frame, drops it
        # from the inputs and zeroes out invalid points.
        return loss_of_one_batch_l2w(
            batch, model, criterion,
            device, use_amp, ret,
            ref_ids=args.ref_ids, coord_frame_id=0,
            exclude_ident=True, to_zero=True
        )
    # Name the offending value so a typo in the config is easy to spot.
    raise NotImplementedError(
        f"unsupported loss_func {loss_func!r}; "
        "expected 'i2p', 'i2p_corr_score' or 'l2w'"
    )
def loss_of_one_batch_multiview(batch, model, criterion, device,
                                use_amp=False, ret=None, ref_id=-1):
    """Run the Image-to-Points model on one batch and compute its loss.

    Args:
        batch: sequence of per-view dicts. Entries named ``img``, ``pts3d``,
            ``valid_mask`` and ``camera_pose`` are moved to ``device`` in
            place when present.
        model: callable invoked as ``model(views, ref_id=ref_id)``; must
            return one prediction per view.
        criterion: callable invoked as ``criterion(views, preds,
            ref_id=ref_id)``, or ``None`` to skip the loss.
        device: target device for the tensors listed above.
        use_amp: whether the forward pass runs under CUDA autocast.
        ret: optional key; when truthy only ``result[ret]`` is returned.
        ref_id: index of the reference view; ``-1`` selects the middle view.

    Returns:
        A dict with ``views``, ``preds``, ``loss`` and 1-based ``pred{k}`` /
        ``view{k}`` entries, or the single entry named by ``ret``.
    """
    views = batch
    tensor_keys = ('img', 'pts3d', 'valid_mask', 'camera_pose')  # pseudo_focal
    for view in views:
        for key in tensor_keys:
            if key in view:
                view[key] = view[key].to(device, non_blocking=True)

    # -1 is the sentinel for "use the middle view as reference".
    if ref_id == -1:
        ref_id = (len(views) - 1) // 2

    with torch.cuda.amp.autocast(enabled=bool(use_amp)):
        preds = model(views, ref_id=ref_id)
        assert len(preds) == len(views)

        # The loss is always evaluated with autocast disabled.
        with torch.cuda.amp.autocast(enabled=False):
            loss = None if criterion is None else criterion(views, preds, ref_id=ref_id)

    result = dict(views=views, preds=preds, loss=loss)
    for number, pred in enumerate(preds, start=1):
        result[f'pred{number}'] = pred
        result[f'view{number}'] = views[number - 1]

    if ret:
        return result[ret]
    return result
def loss_of_one_batch_multiview_corr_score(batch, model, criterion, device,
                                           use_amp=False, ret=None, ref_id=-1):
    """Run the model with correlation scores and compute the pseudo-confidence loss.

    For every non-reference view, the prediction's ``pseudo_conf`` and
    ``conf`` entries are each reduced to one value per batch element,
    squashed with ``x / (1 + x)``, and compared with a mean absolute
    difference.

    Args:
        batch: sequence of per-view dicts. Entries named ``img``, ``pts3d``,
            ``valid_mask`` and ``camera_pose`` are moved to ``device`` in
            place when present.
        model: callable invoked as
            ``model(views, ref_id=ref_id, return_corr_score=True)``; must
            return one prediction dict per view.
        criterion: not used by this function.
        device: target device for the tensors listed above.
        use_amp: whether the computation runs under CUDA autocast.
        ret: optional key; when truthy only ``result[ret]`` is returned.
        ref_id: index of the reference view; ``-1`` selects the middle view.

    Returns:
        A dict with ``views``, ``preds``, ``loss`` and 1-based ``pred{k}`` /
        ``view{k}`` entries, or the single entry named by ``ret``. ``loss``
        is the list ``[total, details]``, where ``details`` maps
        ``pseudo_conf_loss_{i}`` to the loss of view ``i``.
    """
    views = batch
    tensor_keys = ('img', 'pts3d', 'valid_mask', 'camera_pose')  # pseudo_focal
    for view in views:
        for key in tensor_keys:
            if key in view:
                view[key] = view[key].to(device, non_blocking=True)

    # -1 is the sentinel for "use the middle view as reference".
    if ref_id == -1:
        ref_id = (len(views) - 1) // 2

    total_loss = 0  # stays the int 0 when there is no non-reference view
    loss_details = {}
    with torch.cuda.amp.autocast(enabled=bool(use_amp)):
        preds = model(views, ref_id=ref_id, return_corr_score=True)
        assert len(preds) == len(views)

        for view_idx, pred in enumerate(preds):
            if view_idx == ref_id:
                continue

            patch_scores = pred['pseudo_conf']  # (B,S) per the original annotation
            mean_true = (pred['conf'] - 1.).mean(dim=(1, 2))  # (B,)
            mean_pseudo = torch.exp(patch_scores).mean(dim=1)  # (B,)

            # Squash both quantities with x / (1 + x) before comparing them.
            squashed_pseudo = mean_pseudo / (1 + mean_pseudo)
            squashed_true = mean_true / (1 + mean_true)
            loss = torch.abs(squashed_pseudo - squashed_true).mean()

            total_loss += loss
            loss_details[f'pseudo_conf_loss_{view_idx}'] = loss

    all_loss = [total_loss, loss_details]
    result = dict(views=views, preds=preds, loss=all_loss)
    for number, pred in enumerate(preds, start=1):
        result[f'pred{number}'] = pred
        result[f'view{number}'] = views[number - 1]

    if ret:
        return result[ret]
    return result
def get_multiview_scale(pts:list, valid:list, norm_mode='avg_dis'):
    """Compute one joint normalization factor for several pointmaps.

    Args:
        pts: list of point tensors, each with at least 3 dimensions and a
            last dimension of size 3.
        valid: list of validity masks, one per entry of ``pts``.
        norm_mode: string of the form ``"<norm>_<dis>"``. Only ``"avg"`` is
            accepted for ``<norm>``; ``<dis>`` is ``"dis"`` (plain distance
            to the origin) or ``"log1p"`` (``log1p`` of that distance).

    Returns:
        The summed distance divided by the accumulated valid-point count,
        clipped to at least ``1e-8`` and given trailing singleton dimensions
        until it has as many dimensions as ``pts[0]``.

    Raises:
        ValueError: if either part of ``norm_mode`` is not recognised.
    """
    # adapted from DUSt3R
    for pointmap in pts:
        assert pointmap.ndim >= 3 and pointmap.shape[-1] == 3
    assert len(pts) == len(valid)

    norm_mode, dis_mode = norm_mode.split('_')
    if norm_mode != 'avg':
        raise ValueError(f'bad {norm_mode=}')

    # Gather the points of all views together (joint normalization).
    gathered = []
    all_nnz = 0
    for pointmap, mask in zip(pts, valid):
        # Original annotation: nnz is (B,), the returned points are (B,H*W,3).
        zeroed_pts, nnz = invalid_to_zeros(pointmap, mask, ndim=3)
        gathered.append(zeroed_pts)
        all_nnz += nnz

    # Distance of every point to the origin.
    all_dis = torch.cat(gathered, dim=1).norm(dim=-1)
    if dis_mode == 'log1p':
        all_dis = torch.log1p(all_dis)
    elif dis_mode != 'dis':
        raise ValueError(f'bad {dis_mode=}')

    # The epsilon and the clip guard against views without any valid point.
    norm_factor = (all_dis.sum(dim=1) / (all_nnz + 1e-8)).clip(min=1e-8)

    # Append singleton dims so the factor broadcasts against the pointmaps.
    for _ in range(max(pts[0].ndim - norm_factor.ndim, 0)):
        norm_factor = norm_factor.unsqueeze(-1)
    return norm_factor
def loss_of_one_batch_l2w(batch, model, criterion, device,
                          use_amp=False, ret=None,
                          ref_ids=-1, coord_frame_id=0,
                          exclude_ident=True, to_zero=True):
    """Run the Local-to-World model on one batch and compute its loss.

    Args:
        batch: list of per-view dicts. Entries named ``img``, ``pts3d``,
            ``pts3d_cam``, ``valid_mask`` and ``camera_pose`` are moved to
            ``device`` in place when present. The list itself is modified
            when ``exclude_ident`` is set.
        model: callable invoked as ``model(views, ref_ids=ref_ids)``; must
            return one prediction per remaining view.
        criterion: callable invoked as ``criterion(views, preds,
            ref_id=ref_ids, ref_camera=w2c, norm_scale=norm_factor_world)``,
            or ``None`` to skip the loss.
        device: target device for the tensors listed above.
        use_amp: whether the forward pass runs under CUDA autocast.
        ret: optional key; when truthy only ``result[ret]`` is returned.
        ref_ids: indices of the supporting frames (excluding the coordinate
            frame). ``-1`` selects every view except the last one, ``-2``
            selects a random half of the views; otherwise a list is required.
        coord_frame_id: index of the view whose camera frame all input and
            output pointmaps are expressed in; ``-1`` picks one at random.
        exclude_ident: whether to remove the coordinate frame from the views
            to simulate real-life inference scenarios.
        to_zero: whether to set invalid points to zero.

    Returns:
        A dict with ``views``, ``preds``, ``loss`` and 1-based ``pred{k}`` /
        ``view{k}`` entries, or the single entry named by ``ret``.
    """
    views = batch
    tensor_keys = ('img', 'pts3d', 'pts3d_cam', 'valid_mask', 'camera_pose')  # pseudo_focal
    for view in views:
        for key in tensor_keys:
            if key in view:
                view[key] = view[key].to(device, non_blocking=True)

    if coord_frame_id == -1:
        # Randomly select a camera as the target camera.
        coord_frame_id = np.random.randint(0, len(views))

    # Inverse of the coordinate frame's camera pose.
    w2c = inv(views[coord_frame_id]['camera_pose'])

    # Exclude the frame that has the identity pose.
    if exclude_ident:
        views.pop(coord_frame_id)

    num_views = len(views)
    if ref_ids == -1:
        ref_ids = list(range(num_views - 1))  # all views except the last one
    elif ref_ids == -2:
        # Select half of the views at random.
        ref_ids = np.random.choice(num_views, num_views // 2, replace=False).tolist()
    else:
        assert isinstance(ref_ids, list)

    # Transform the supporting frames into the target coordinate frame.
    for ref_idx in ref_ids:
        ref_view = views[ref_idx]
        ref_view['pts3d_world'] = geotrf(w2c, ref_view['pts3d'])

    # One shared scale for all supporting frames.
    norm_factor_world = get_multiview_scale(
        [views[ref_idx]['pts3d_world'] for ref_idx in ref_ids],
        [views[ref_idx]['valid_mask'] for ref_idx in ref_ids],
        norm_mode='avg_dis',
    )

    for view_idx, view in enumerate(views):
        if view_idx in ref_ids:
            world_pts = view['pts3d_world'].permute(0, 3, 1, 2)
            view['pts3d_world'] = world_pts / norm_factor_world
        else:
            # Every non-supporting view is normalized by its own scale.
            norm_factor_src = get_multiview_scale(
                [view['pts3d_cam']],
                [view['valid_mask']],
                norm_mode='avg_dis',
            )
            cam_pts = view['pts3d_cam'].permute(0, 3, 1, 2)
            view['pts3d_cam'] = cam_pts / norm_factor_src

    if to_zero:
        for view_idx, view in enumerate(views):
            mask = view['valid_mask'].unsqueeze(1).float()  # B,1,H,W per the original annotation
            pts_key = 'pts3d_world' if view_idx in ref_ids else 'pts3d_cam'
            view[pts_key] = view[pts_key] * mask

    with torch.cuda.amp.autocast(enabled=bool(use_amp)):
        preds = model(views, ref_ids=ref_ids)
        assert len(preds) == len(views)

        # The loss is always evaluated with autocast disabled.
        with torch.cuda.amp.autocast(enabled=False):
            if criterion is None:
                loss = None
            else:
                loss = criterion(views, preds, ref_id=ref_ids,
                                 ref_camera=w2c, norm_scale=norm_factor_world)

    result = dict(views=views, preds=preds, loss=loss)
    for number, pred in enumerate(preds, start=1):
        result[f'pred{number}'] = pred
        result[f'view{number}'] = views[number - 1]

    if ret:
        return result[ret]
    return result