# File size: 5,240 Bytes
# f53b39e
# (line-number gutter 1-134 from the file viewer omitted)
# --------------------------------------------------------
# Dummy optimizer for visualizing pairs
# --------------------------------------------------------
import numpy as np
import torch
import torch.nn as nn
import cv2
from dust3r.cloud_opt_flow.base_opt import BasePCOptimizer
from dust3r.utils.geometry import inv, geotrf, depthmap_to_absolute_camera_coordinates
from dust3r.cloud_opt_flow.commons import edge_str
from dust3r.post_process import estimate_focal_knowing_depth
class PairViewer(BasePCOptimizer):
    """Dummy optimizer used only to visualize the results for one image pair.

    Nothing is optimized here: focals, principal points, relative poses and
    depthmaps are computed once in ``__init__`` directly from the raw
    pairwise predictions and then frozen. The pair must be symmetrized, i.e.
    the instance must hold exactly the two edges ``'0_1'`` and ``'1_0'``.

    NOTE(review): ``conf_i``/``conf_j``, ``pred_i``/``pred_j``, ``imshapes``,
    ``n_imgs``, ``n_edges``, ``verbose``, ``device`` and ``get_masks()`` are
    presumably provided by ``BasePCOptimizer`` -- confirm against that class.
    """

    def __init__(self, *args, **kwargs):
        """Build the viewer; all arguments are forwarded to ``BasePCOptimizer``.

        Raises:
            AssertionError: if the input is not a symmetrized pair (2 edges).
        """
        super().__init__(*args, **kwargs)
        assert self.is_symmetrized and self.n_edges == 2
        self.has_im_poses = True

        # compute all parameters directly from raw input
        focals = []
        pps = []
        rel_poses = []
        confs = []
        for i in range(self.n_imgs):
            edge = edge_str(i, 1 - i)
            conf = float(self.conf_i[edge].mean() * self.conf_j[edge].mean())
            if self.verbose:
                print(f' - {conf=:.3} for edge {i}-{1-i}')
            confs.append(conf)

            H, W = self.imshapes[i]
            # principal point is taken to be the image center
            pp = torch.tensor((W/2, H/2))
            focal = float(estimate_focal_knowing_depth(self.pred_i[edge][None], pp, focal_mode='weiszfeld'))
            focals.append(focal)
            pps.append(pp)

            rel_poses.append(self._estimate_rel_pose(i, focal, H, W))

        # let's use the pair with the most confidence
        if confs[0] > confs[1]:
            # ptcloud is expressed in camera1
            self._ref_edge = '0_1'
            im_poses = [torch.eye(4), rel_poses[1]]  # I, cam2-to-cam1
            depth = [self.pred_i['0_1'][..., 2],
                     geotrf(inv(rel_poses[1]), self.pred_j['0_1'])[..., 2]]
        else:
            # ptcloud is expressed in camera2
            self._ref_edge = '1_0'
            im_poses = [rel_poses[0], torch.eye(4)]  # cam1-to-cam2, I
            depth = [geotrf(inv(rel_poses[0]), self.pred_j['1_0'])[..., 2],
                     self.pred_i['1_0'][..., 2]]

        self.im_poses = nn.Parameter(torch.stack(im_poses, dim=0), requires_grad=False)
        self.focals = nn.Parameter(torch.tensor(focals), requires_grad=False)
        self.pp = nn.Parameter(torch.stack(pps, dim=0), requires_grad=False)
        self.depth = nn.ParameterList(depth)
        # nothing in this dummy optimizer is trainable
        for p in self.parameters():
            p.requires_grad = False

    def _estimate_rel_pose(self, i, focal, H, W):
        """Estimate the pose of image ``i`` relative to the other camera by PnP.

        The 3D points of image ``i`` predicted in the other camera's frame
        (``pred_j`` of edge ``(1-i, i)``) are matched with their own pixel
        coordinates. Returns a 4x4 float32 tensor (cam-to-world, where the
        world is the other camera); identity when PnP fails.
        """
        # (x, y) pixel coordinates arranged as an (H, W, 2) grid
        pixels = np.mgrid[:W, :H].T.astype(np.float32)
        pts3d = self.pred_j[edge_str(1 - i, i)].numpy()
        assert pts3d.shape[:2] == (H, W)
        msk = self.get_masks()[i].numpy()
        K = np.float32([(focal, 0, W / 2), (0, focal, H / 2), (0, 0, 1)])

        pose = np.eye(4)  # fallback when the pose cannot be estimated
        try:
            success, R, T, _ = cv2.solvePnPRansac(pts3d[msk], pixels[msk], K, None,
                                                  iterationsCount=100, reprojectionError=5,
                                                  flags=cv2.SOLVEPNP_SQPNP)
        except cv2.error:
            # e.g. too few masked points for the solver
            success = False
        if success:
            R = cv2.Rodrigues(R)[0]  # world to cam
            pose = inv(np.r_[np.c_[R, T], [(0, 0, 0, 1)]])  # cam to world
        return torch.from_numpy(pose.astype(np.float32))

    def _set_depthmap(self, idx, depth, force=False):
        """No-op: depthmaps are fixed at construction time."""
        if self.verbose:
            print('_set_depthmap is ignored in PairViewer')

    def get_depthmaps(self, raw=False):
        """Return the list of per-image depthmaps (``raw`` is unused)."""
        # NOTE(review): moving to self.device mirrors get_intrinsics -- confirm.
        return [d.to(self.device) for d in self.depth]

    def _set_focal(self, idx, focal, force=False):
        """Overwrite the focal of image ``idx`` (``force`` is unused)."""
        self.focals[idx] = focal

    def get_focals(self):
        """Return the focals parameter, one value per image."""
        return self.focals

    def get_known_focal_mask(self):
        """Return a bool tensor, True where the focal does not require grad."""
        return torch.tensor([not (p.requires_grad) for p in self.focals])

    def get_principal_points(self):
        """Return the principal points parameter, one (x, y) row per image."""
        return self.pp

    def get_intrinsics(self):
        """Return one 3x3 pinhole intrinsics matrix per image, on ``self.device``."""
        focals = self.get_focals()
        pps = self.get_principal_points()
        K = torch.zeros((len(focals), 3, 3), device=self.device)
        # each `k` is a view into K, so the assignments fill K in place
        for k, focal, pp in zip(K, focals, pps):
            k[0, 0] = k[1, 1] = focal
            k[:2, 2] = pp
            k[2, 2] = 1
        return K

    def get_im_poses(self):
        """Return the stacked 4x4 poses of the two images."""
        return self.im_poses

    def depth_to_pts3d(self, raw_pts=False):
        """Return the 3D points of both images as a list of two tensors.

        Args:
            raw_pts: when True, return the raw pairwise predictions of the
                edge selected in ``__init__``; otherwise unproject the stored
                depthmaps with the estimated intrinsics and poses.
        """
        if raw_pts:
            # Use the edge recorded in __init__ rather than testing whether
            # the first pose sums to 4: the PnP fallback can also yield an
            # identity pose for image 0 when camera 2 is the reference.
            if self._ref_edge == '0_1':
                raw = [self.pred_i['0_1'], self.pred_j['0_1']]
            else:
                raw = [self.pred_j['1_0'], self.pred_i['1_0']]
            return [pts.to(self.device) for pts in raw]

        pts3d = []
        for d, intrinsics, im_pose in zip(self.depth, self.get_intrinsics(), self.get_im_poses()):
            pts, _ = depthmap_to_absolute_camera_coordinates(d.cpu().numpy(),
                                                             intrinsics.cpu().numpy(),
                                                             im_pose.cpu().numpy())
            pts3d.append(torch.from_numpy(pts).to(device=self.device))
        return pts3d

    def forward(self):
        """There is no loss to compute; always returns NaN."""
        return float('nan')