File size: 2,493 Bytes
5e82535
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
from diffusionsfm.inference.ddim import inference_ddim
from diffusionsfm.utils.rays import (
    Rays,
    rays_to_cameras,
    rays_to_cameras_homography,
)


def predict_cameras(
    model,
    images,
    device,
    crop_parameters=None,
    stop_iteration=None,
    num_patches_x=16,
    num_patches_y=16,
    additional_timesteps=(),
    calculate_intrinsics=False,
    max_num_images=8,
    mode=None,
    return_rays=False,
    use_homogeneous=False,
    seed=0,
):
    """
    Predict cameras for a single scene by running DDIM inference on its images.

    Args:
        model: Diffusion model passed to ``inference_ddim``. Its
            ``depth_resolution`` attribute is forwarded to the ray-to-camera
            conversion.
        images (torch.Tensor): (N, C, H, W)
        device: Device passed through to ``inference_ddim``.
        crop_parameters (torch.Tensor): (N, 4) or None
        stop_iteration: Passed through to ``inference_ddim``.
        num_patches_x (int): Horizontal size of the spatial ray map.
        num_patches_y (int): Vertical size of the spatial ray map.
        additional_timesteps (iterable of int): Indices into the intermediate
            predictions returned by ``inference_ddim``. A camera is also
            recovered for each of them.
        calculate_intrinsics (bool): If True, convert rays with
            ``rays_to_cameras_homography``; otherwise use ``rays_to_cameras``.
        max_num_images (int): Passed through to ``inference_ddim``.
        mode: Passed to ``Rays.from_spatial``.
        return_rays (bool): Also return the rays each camera was recovered
            from.
        use_homogeneous (bool): Passed to ``Rays.from_spatial``.
        seed (int): Not used by this function (it is not forwarded to
            ``inference_ddim``). NOTE(review): confirm whether it should be.

    Returns:
        If ``return_rays`` is False:
            ``(pred_cam, additional_predictions, spatial_rays)``, where
            ``additional_predictions`` holds one camera per entry of
            ``additional_timesteps``.
        If ``return_rays`` is True:
            ``((pred_cam, spatial_rays), additional_predictions)``, where each
            entry of ``additional_predictions`` is a ``(cam, rays)`` tuple.
        The two layouts differ; both are kept as-is for existing callers.
    """
    ray_to_cam = (
        rays_to_cameras_homography if calculate_intrinsics else rays_to_cameras
    )

    # inference_ddim takes a batch of scenes, so add a leading batch dim of 1.
    # crop_parameters defaults to None, so only unsqueeze it when present.
    batched_crop_parameters = (
        crop_parameters.unsqueeze(0) if crop_parameters is not None else None
    )

    rays_final, _, pred_intermediate, _ = inference_ddim(
        model,
        images.unsqueeze(0),
        device,
        crop_parameters=batched_crop_parameters,
        pbar=False,
        stop_iteration=stop_iteration,
        eta=[1, 0],
        num_inference_steps=100,
        num_patches_x=num_patches_x,
        num_patches_y=num_patches_y,
        # NOTE(review): presumably required so intermediate predictions
        # (indexed below by timestep) are populated — confirm in inference_ddim.
        visualize=True,
        max_num_images=max_num_images,
    )

    def _to_camera(batched_rays):
        """Convert the single scene in a batched ray prediction to (cam, rays)."""
        spatial = Rays.from_spatial(
            batched_rays[0],
            mode=mode,
            num_patches_x=num_patches_x,
            num_patches_y=num_patches_y,
            use_homogeneous=use_homogeneous,
        )
        # Final and intermediate predictions come from the same model, so
        # they are converted with identical settings, including
        # depth_resolution.
        cam = ray_to_cam(
            spatial,
            crop_parameters,
            num_patches_x=num_patches_x,
            num_patches_y=num_patches_y,
            depth_resolution=model.depth_resolution,
            average_centers=True,
            directions_from_averaged_center=True,
        )
        return cam, spatial

    pred_cam, spatial_rays = _to_camera(rays_final)

    additional_predictions = []
    for t in additional_timesteps:
        cam, ray = _to_camera(pred_intermediate[t])
        additional_predictions.append((cam, ray) if return_rays else cam)

    if return_rays:
        return (pred_cam, spatial_rays), additional_predictions
    return pred_cam, additional_predictions, spatial_rays