File size: 2,960 Bytes
82b898c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3717f04
82b898c
 
 
 
 
 
 
 
 
 
 
3717f04
82b898c
 
 
 
3717f04
82b898c
 
 
3717f04
82b898c
 
 
 
 
 
 
 
 
 
3717f04
82b898c
3717f04
82b898c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3717f04
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import numpy as np
import torch
import cv2
import open3d as o3d
from dust3r.post_process import estimate_focal_knowing_depth
from dust3r.utils.geometry import inv

def estimate_focal(pts3d_i, pp=None):
    """Estimate one focal length (in pixels) from a per-pixel 3D point map.

    Args:
        pts3d_i: torch tensor of 3D points laid out as (H, W, 3).
        pp: principal point (cx, cy). May be a torch tensor, a NumPy array
            or any 2-element sequence; defaults to the image centre (W/2, H/2).

    Returns:
        The focal length as a Python float, computed by dust3r's
        ``estimate_focal_knowing_depth`` in 'weiszfeld' mode.
    """
    if pp is None:
        H, W, THREE = pts3d_i.shape
        assert THREE == 3
        pp = torch.tensor((W/2, H/2), device=pts3d_i.device)
    elif isinstance(pp, torch.Tensor):
        # keep pp on the same device as the points it is combined with
        pp = pp.to(pts3d_i.device)
    else:
        # fast_pnp / solve_cemara accept tuple or ndarray principal points;
        # convert so the .unsqueeze call below works for them as well.
        pp = torch.tensor(pp, dtype=torch.float32, device=pts3d_i.device)
    focal = estimate_focal_knowing_depth(pts3d_i.unsqueeze(0), pp.unsqueeze(0), focal_mode='weiszfeld').ravel()
    return float(focal)

def pixel_grid(H, W):
    """Return an (H, W, 2) float32 array whose entry [y, x] holds the pixel coordinate (x, y)."""
    cols, rows = np.meshgrid(np.arange(W), np.arange(H))  # both shaped (H, W)
    coords = np.stack((cols, rows), axis=-1)
    return coords.astype(np.float32)

def sRT_to_4x4(scale, R, T, device):
    """Assemble a 4x4 homogeneous transform from a scale, a rotation and a translation.

    The upper-left 3x3 block is ``R * scale``, the last column holds the
    flattened ``T`` (left unscaled), and the bottom row stays [0, 0, 0, 1].
    The matrix is a ``torch.eye(4)`` tensor allocated on ``device``.
    """
    mat = torch.eye(4, device=device)
    scaled_rotation = R * scale
    mat[:3, :3] = scaled_rotation
    mat[:3, 3] = T.ravel()  # translation is intentionally not scaled
    return mat

def to_numpy(tensor):
    """Convert a torch tensor to a NumPy array; return any other value unchanged.

    The tensor is detached before conversion. ``Tensor.numpy()`` raises on
    tensors that require grad, so without the detach those inputs would crash.
    """
    if isinstance(tensor, torch.Tensor):
        return tensor.detach().cpu().numpy()
    return tensor

def fast_pnp(pts3d, focal, msk, device, pp=None, niter_PnP=10):
    """Recover a camera pose (and optionally its focal) from a 3D point map via RANSAC-PnP.

    Each masked pixel gives a 2D-3D correspondence: its pixel coordinate and
    the 3D point stored at that pixel. ``cv2.solvePnPRansac`` (SQPnP solver,
    reprojection threshold 5, ``niter_PnP`` iterations) runs once per
    candidate focal. The first candidate reaching the highest inlier count wins.

    Args:
        pts3d: (H, W, 3) point map, torch tensor or NumPy array.
        focal: focal length in pixels, or None to sweep 21 geometrically
            spaced candidates from max(W, H)/2 to 3*max(W, H).
        msk: per-pixel mask selecting which points to use.
        device: torch device for the returned pose matrix.
        pp: principal point (cx, cy); defaults to (W/2, H/2).
        niter_PnP: RANSAC iteration count.

    Returns:
        ``(focal, cam_to_world)`` with ``cam_to_world`` a 4x4 torch tensor.
        Returns None when fewer than 4 pixels are masked, or when no candidate
        yields a successful solve with at least one inlier.
    """
    if msk.sum() < 4:
        return None  # PnP requires at least 4 correspondences
    pts3d = to_numpy(pts3d)
    msk = to_numpy(msk)

    H, W, THREE = pts3d.shape
    assert THREE == 3
    pixels = pixel_grid(H, W)

    if focal is not None:
        candidates = [focal]
    else:
        side = max(W, H)
        candidates = np.geomspace(side/2, side*3, 21)

    pp = (W/2, H/2) if pp is None else to_numpy(pp)

    # the correspondences do not depend on the focal, so select them once
    object_points = pts3d[msk]
    image_points = pixels[msk]

    best_score, best_fit = 0, None
    for f in candidates:
        K = np.float32([(f, 0, pp[0]), (0, f, pp[1]), (0, 0, 1)])
        ok, rvec, tvec, inliers = cv2.solvePnPRansac(
            object_points, image_points, K, None,
            iterationsCount=niter_PnP, reprojectionError=5, flags=cv2.SOLVEPNP_SQPNP)
        # strict '>' keeps the earliest candidate on ties
        if ok and len(inliers) > best_score:
            best_score, best_fit = len(inliers), (rvec, tvec, f)

    if best_fit is None:
        return None

    rvec, tvec, chosen_focal = best_fit
    rot_world_to_cam = cv2.Rodrigues(rvec)[0]
    R = torch.from_numpy(rot_world_to_cam)
    T = torch.from_numpy(tvec)
    return chosen_focal, inv(sRT_to_4x4(1, R, T, device))  # cam to world

def solve_cemara(pts3d, msk, device, focal=None, pp=None):
    """Build Open3D pinhole camera parameters from a 3D point map.

    When ``focal`` is None it is first estimated with ``estimate_focal``
    (which needs ``pts3d`` as a torch tensor). The pose is then recovered by
    ``fast_pnp`` using that single focal.

    Args:
        pts3d: (H, W, 3) point map.
        msk: per-pixel mask of points to use for PnP.
        device: torch device used by ``fast_pnp`` for the pose matrix.
        focal: focal length in pixels, or None to estimate it.
        pp: principal point (cx, cy); defaults to (W/2, H/2).

    Returns:
        ``(camera_parameters, focal)``. ``camera_parameters`` is an
        ``o3d.camera.PinholeCameraParameters`` with fx = fy = the PnP focal,
        the principal point above, and an extrinsic equal to the inverse of
        fast_pnp's cam-to-world pose. If PnP fails, returns ``(None, focal)``.
    """
    if focal is None:
        focal = estimate_focal(pts3d, pp)

    pnp_result = fast_pnp(pts3d, focal, msk, device, pp)
    if pnp_result is None:
        return None, focal
    fitted_focal, cam_to_world = pnp_result

    H, W, _ = pts3d.shape
    if pp is None:
        cx, cy = W/2, H/2
    else:
        cx, cy = pp[0], pp[1]

    intrinsic = o3d.camera.PinholeCameraIntrinsic()
    intrinsic.set_intrinsics(W, H, fitted_focal, fitted_focal, cx, cy)

    params = o3d.camera.PinholeCameraParameters()
    params.intrinsic = intrinsic
    # fast_pnp yields cam-to-world; the extrinsic stored here is its inverse
    params.extrinsic = torch.inverse(cam_to_world).cpu().numpy()
    return params, fitted_focal