Spaces:
				
			
			
	
			
			
		Runtime error
		
	
	
	
			
			
	
	
	
	
		
		
		Runtime error
		
	Upload 2 files
Browse files- diffusionsfm/inference/ddim.py +181 -0
 - diffusionsfm/inference/predict.py +96 -0
 
    	
        diffusionsfm/inference/ddim.py
    ADDED
    
    | 
         @@ -0,0 +1,181 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            import torch
         
     | 
| 2 | 
         
            +
            import random
         
     | 
| 3 | 
         
            +
            import numpy as np
         
     | 
| 4 | 
         
            +
            from tqdm.auto import tqdm
         
     | 
| 5 | 
         
            +
             
     | 
| 6 | 
         
            +
            from diffusionsfm.utils.rays import compute_ndc_coordinates
         
     | 
| 7 | 
         
            +
             
     | 
| 8 | 
         
            +
             
     | 
| 9 | 
         
            +
            def inference_ddim(
         
     | 
| 10 | 
         
            +
                model,
         
     | 
| 11 | 
         
            +
                images,
         
     | 
| 12 | 
         
            +
                device,
         
     | 
| 13 | 
         
            +
                crop_parameters=None,
         
     | 
| 14 | 
         
            +
                eta=0,
         
     | 
| 15 | 
         
            +
                num_inference_steps=100,
         
     | 
| 16 | 
         
            +
                pbar=True,
         
     | 
| 17 | 
         
            +
                stop_iteration=None,
         
     | 
| 18 | 
         
            +
                num_patches_x=16,
         
     | 
| 19 | 
         
            +
                num_patches_y=16,
         
     | 
| 20 | 
         
            +
                visualize=False,
         
     | 
| 21 | 
         
            +
                max_num_images=8,
         
     | 
| 22 | 
         
            +
                seed=0,
         
     | 
| 23 | 
         
            +
            ):
         
     | 
| 24 | 
         
            +
                """
         
     | 
| 25 | 
         
            +
                Implements DDIM-style inference.
         
     | 
| 26 | 
         
            +
             
     | 
| 27 | 
         
            +
                To get multiple samples, batch the images multiple times.
         
     | 
| 28 | 
         
            +
             
     | 
| 29 | 
         
            +
                Args:
         
     | 
| 30 | 
         
            +
                    model: Ray Diffuser.
         
     | 
| 31 | 
         
            +
                    images (torch.Tensor): (B, N, C, H, W).
         
     | 
| 32 | 
         
            +
                    patch_rays_gt (torch.Tensor): If provided, the patch rays which are ground
         
     | 
| 33 | 
         
            +
                        truth (B, N, P, 6).
         
     | 
| 34 | 
         
            +
                    eta (float, optional): Stochasticity coefficient. 0 is completely deterministic,
         
     | 
| 35 | 
         
            +
                        1 is equivalent to DDPM. (Default: 0)
         
     | 
| 36 | 
         
            +
                    num_inference_steps (int, optional): Number of inference steps. (Default: 100)
         
     | 
| 37 | 
         
            +
                    pbar (bool, optional): Whether to show progress bar. (Default: True)
         
     | 
| 38 | 
         
            +
                """
         
     | 
| 39 | 
         
            +
                timesteps = model.noise_scheduler.compute_inference_timesteps(num_inference_steps)
         
     | 
| 40 | 
         
            +
                batch_size = images.shape[0]
         
     | 
| 41 | 
         
            +
                num_images = images.shape[1]
         
     | 
| 42 | 
         
            +
             
     | 
| 43 | 
         
            +
                if isinstance(eta, list):
         
     | 
| 44 | 
         
            +
                    eta_0, eta_1 = float(eta[0]), float(eta[1])
         
     | 
| 45 | 
         
            +
                else:
         
     | 
| 46 | 
         
            +
                    eta_0, eta_1 = 0, 0
         
     | 
| 47 | 
         
            +
             
     | 
| 48 | 
         
            +
                # Fixing seed
         
     | 
| 49 | 
         
            +
                if seed is not None:
         
     | 
| 50 | 
         
            +
                    torch.manual_seed(seed)
         
     | 
| 51 | 
         
            +
                    random.seed(seed)
         
     | 
| 52 | 
         
            +
                    np.random.seed(seed)
         
     | 
| 53 | 
         
            +
             
     | 
| 54 | 
         
            +
                with torch.no_grad():
         
     | 
| 55 | 
         
            +
                    x_tau = torch.randn(
         
     | 
| 56 | 
         
            +
                        batch_size,
         
     | 
| 57 | 
         
            +
                        num_images,
         
     | 
| 58 | 
         
            +
                        model.ray_out if hasattr(model, "ray_out") else model.ray_dim,
         
     | 
| 59 | 
         
            +
                        num_patches_x,
         
     | 
| 60 | 
         
            +
                        num_patches_y,
         
     | 
| 61 | 
         
            +
                        device=device,
         
     | 
| 62 | 
         
            +
                    )
         
     | 
| 63 | 
         
            +
             
     | 
| 64 | 
         
            +
                    if visualize:
         
     | 
| 65 | 
         
            +
                        x_taus = [x_tau]
         
     | 
| 66 | 
         
            +
                        all_pred = []
         
     | 
| 67 | 
         
            +
                        noise_samples = []
         
     | 
| 68 | 
         
            +
             
     | 
| 69 | 
         
            +
                    image_features = model.feature_extractor(images, autoresize=True)
         
     | 
| 70 | 
         
            +
             
     | 
| 71 | 
         
            +
                    if model.append_ndc:
         
     | 
| 72 | 
         
            +
                        ndc_coordinates = compute_ndc_coordinates(
         
     | 
| 73 | 
         
            +
                            crop_parameters=crop_parameters,
         
     | 
| 74 | 
         
            +
                            no_crop_param_device="cpu",
         
     | 
| 75 | 
         
            +
                            num_patches_x=model.width,
         
     | 
| 76 | 
         
            +
                            num_patches_y=model.width,
         
     | 
| 77 | 
         
            +
                            distortion_coeffs=None,
         
     | 
| 78 | 
         
            +
                        )[..., :2].to(device)
         
     | 
| 79 | 
         
            +
                        ndc_coordinates = ndc_coordinates.permute(0, 1, 4, 2, 3)
         
     | 
| 80 | 
         
            +
                    else:
         
     | 
| 81 | 
         
            +
                        ndc_coordinates = None
         
     | 
| 82 | 
         
            +
             
     | 
| 83 | 
         
            +
                    if stop_iteration is None:
         
     | 
| 84 | 
         
            +
                        loop = range(len(timesteps))
         
     | 
| 85 | 
         
            +
                    else:
         
     | 
| 86 | 
         
            +
                        loop = range(len(timesteps) - stop_iteration + 1)
         
     | 
| 87 | 
         
            +
                    loop = tqdm(loop) if pbar else loop
         
     | 
| 88 | 
         
            +
             
     | 
| 89 | 
         
            +
                    for t in loop:
         
     | 
| 90 | 
         
            +
                        tau = timesteps[t]
         
     | 
| 91 | 
         
            +
             
     | 
| 92 | 
         
            +
                        if tau > 0 and eta_1 > 0:
         
     | 
| 93 | 
         
            +
                            z = torch.randn(
         
     | 
| 94 | 
         
            +
                                batch_size,
         
     | 
| 95 | 
         
            +
                                num_images,
         
     | 
| 96 | 
         
            +
                                model.ray_out if hasattr(model, "ray_out") else model.ray_dim,
         
     | 
| 97 | 
         
            +
                                num_patches_x,
         
     | 
| 98 | 
         
            +
                                num_patches_y,
         
     | 
| 99 | 
         
            +
                                device=device,
         
     | 
| 100 | 
         
            +
                            )
         
     | 
| 101 | 
         
            +
                        else:
         
     | 
| 102 | 
         
            +
                            z = 0
         
     | 
| 103 | 
         
            +
             
     | 
| 104 | 
         
            +
                        alpha = model.noise_scheduler.alphas_cumprod[tau]
         
     | 
| 105 | 
         
            +
                        if tau > 0:
         
     | 
| 106 | 
         
            +
                            tau_prev = timesteps[t + 1]
         
     | 
| 107 | 
         
            +
                            alpha_prev = model.noise_scheduler.alphas_cumprod[tau_prev]
         
     | 
| 108 | 
         
            +
                        else:
         
     | 
| 109 | 
         
            +
                            alpha_prev = torch.tensor(1.0, device=device).float()
         
     | 
| 110 | 
         
            +
             
     | 
| 111 | 
         
            +
                        sigma_t = (
         
     | 
| 112 | 
         
            +
                            torch.sqrt((1 - alpha_prev) / (1 - alpha))
         
     | 
| 113 | 
         
            +
                            * torch.sqrt(1 - alpha / alpha_prev)
         
     | 
| 114 | 
         
            +
                        )
         
     | 
| 115 | 
         
            +
             
     | 
| 116 | 
         
            +
                        if num_images > max_num_images:
         
     | 
| 117 | 
         
            +
                            eps_pred = torch.zeros_like(x_tau)
         
     | 
| 118 | 
         
            +
                            noise_sample = torch.zeros_like(x_tau)
         
     | 
| 119 | 
         
            +
             
     | 
| 120 | 
         
            +
                            # Randomly split image indices (excluding index 0), then prepend 0 to each split
         
     | 
| 121 | 
         
            +
                            indices_split = torch.split(
         
     | 
| 122 | 
         
            +
                                torch.randperm(num_images - 1) + 1, max_num_images - 1
         
     | 
| 123 | 
         
            +
                            )
         
     | 
| 124 | 
         
            +
             
     | 
| 125 | 
         
            +
                            for indices in indices_split:
         
     | 
| 126 | 
         
            +
                                indices = torch.cat((torch.tensor([0]), indices))  # Ensure index 0 is always included
         
     | 
| 127 | 
         
            +
             
     | 
| 128 | 
         
            +
                                eps_pred_ind, noise_sample_ind = model(
         
     | 
| 129 | 
         
            +
                                    features=image_features[:, indices],
         
     | 
| 130 | 
         
            +
                                    rays_noisy=x_tau[:, indices],
         
     | 
| 131 | 
         
            +
                                    t=int(tau),
         
     | 
| 132 | 
         
            +
                                    ndc_coordinates=ndc_coordinates[:, indices],
         
     | 
| 133 | 
         
            +
                                    indices=indices,
         
     | 
| 134 | 
         
            +
                                )
         
     | 
| 135 | 
         
            +
             
     | 
| 136 | 
         
            +
                                eps_pred[:, indices] += eps_pred_ind
         
     | 
| 137 | 
         
            +
             
     | 
| 138 | 
         
            +
                                if noise_sample_ind is not None:
         
     | 
| 139 | 
         
            +
                                    noise_sample[:, indices] += noise_sample_ind
         
     | 
| 140 | 
         
            +
             
     | 
| 141 | 
         
            +
                            # Average over splits for the shared reference index (0)
         
     | 
| 142 | 
         
            +
                            eps_pred[:, 0] /= len(indices_split)
         
     | 
| 143 | 
         
            +
                            noise_sample[:, 0] /= len(indices_split)
         
     | 
| 144 | 
         
            +
                        else:
         
     | 
| 145 | 
         
            +
                            eps_pred, noise_sample = model(
         
     | 
| 146 | 
         
            +
                                features=image_features,
         
     | 
| 147 | 
         
            +
                                rays_noisy=x_tau,
         
     | 
| 148 | 
         
            +
                                t=int(tau),
         
     | 
| 149 | 
         
            +
                                ndc_coordinates=ndc_coordinates,
         
     | 
| 150 | 
         
            +
                            )
         
     | 
| 151 | 
         
            +
                            
         
     | 
| 152 | 
         
            +
                        if model.use_homogeneous:
         
     | 
| 153 | 
         
            +
                            p1 = eps_pred[:, :, :4]
         
     | 
| 154 | 
         
            +
                            p2 = eps_pred[:, :, 4:]
         
     | 
| 155 | 
         
            +
             
     | 
| 156 | 
         
            +
                            c1 = torch.linalg.norm(p1, dim=2, keepdim=True)
         
     | 
| 157 | 
         
            +
                            c2 = torch.linalg.norm(p2, dim=2, keepdim=True)
         
     | 
| 158 | 
         
            +
                            eps_pred[:, :, :4] = p1 / c1
         
     | 
| 159 | 
         
            +
                            eps_pred[:, :, 4:] = p2 / c2
         
     | 
| 160 | 
         
            +
             
     | 
| 161 | 
         
            +
                        if visualize:
         
     | 
| 162 | 
         
            +
                            all_pred.append(eps_pred.clone())
         
     | 
| 163 | 
         
            +
                            noise_samples.append(noise_sample)
         
     | 
| 164 | 
         
            +
                            
         
     | 
| 165 | 
         
            +
                        # TODO: Can simplify this a lot
         
     | 
| 166 | 
         
            +
                        x0_pred = eps_pred.clone()
         
     | 
| 167 | 
         
            +
                        eps_pred = (x_tau - torch.sqrt(alpha) * eps_pred) / torch.sqrt(
         
     | 
| 168 | 
         
            +
                            1 - alpha
         
     | 
| 169 | 
         
            +
                        )
         
     | 
| 170 | 
         
            +
             
     | 
| 171 | 
         
            +
                        dir_x_tau = torch.sqrt(1 - alpha_prev - eta_0*sigma_t**2) * eps_pred
         
     | 
| 172 | 
         
            +
                        noise = eta_1 * sigma_t * z
         
     | 
| 173 | 
         
            +
             
     | 
| 174 | 
         
            +
                        new_x_tau = torch.sqrt(alpha_prev) * x0_pred + dir_x_tau + noise
         
     | 
| 175 | 
         
            +
                        x_tau = new_x_tau
         
     | 
| 176 | 
         
            +
             
     | 
| 177 | 
         
            +
                        if visualize:
         
     | 
| 178 | 
         
            +
                            x_taus.append(x_tau.detach().clone())
         
     | 
| 179 | 
         
            +
                if visualize:
         
     | 
| 180 | 
         
            +
                    return x_tau, x_taus, all_pred, noise_samples
         
     | 
| 181 | 
         
            +
                return x_tau
         
     | 
    	
        diffusionsfm/inference/predict.py
    ADDED
    
    | 
         @@ -0,0 +1,96 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            from diffusionsfm.inference.ddim import inference_ddim
         
     | 
| 2 | 
         
            +
            from diffusionsfm.utils.rays import (
         
     | 
| 3 | 
         
            +
                Rays,
         
     | 
| 4 | 
         
            +
                rays_to_cameras,
         
     | 
| 5 | 
         
            +
                rays_to_cameras_homography,
         
     | 
| 6 | 
         
            +
            )
         
     | 
| 7 | 
         
            +
             
     | 
| 8 | 
         
            +
             
     | 
| 9 | 
         
            +
            def predict_cameras(
         
     | 
| 10 | 
         
            +
                model,
         
     | 
| 11 | 
         
            +
                images,
         
     | 
| 12 | 
         
            +
                device,
         
     | 
| 13 | 
         
            +
                crop_parameters=None,
         
     | 
| 14 | 
         
            +
                stop_iteration=None,
         
     | 
| 15 | 
         
            +
                num_patches_x=16,
         
     | 
| 16 | 
         
            +
                num_patches_y=16,
         
     | 
| 17 | 
         
            +
                additional_timesteps=(),
         
     | 
| 18 | 
         
            +
                calculate_intrinsics=False,
         
     | 
| 19 | 
         
            +
                max_num_images=8,
         
     | 
| 20 | 
         
            +
                mode=None,
         
     | 
| 21 | 
         
            +
                return_rays=False,
         
     | 
| 22 | 
         
            +
                use_homogeneous=False,
         
     | 
| 23 | 
         
            +
                seed=0,
         
     | 
| 24 | 
         
            +
            ):
         
     | 
| 25 | 
         
            +
                """
         
     | 
| 26 | 
         
            +
                Args:
         
     | 
| 27 | 
         
            +
                    images (torch.Tensor): (N, C, H, W)
         
     | 
| 28 | 
         
            +
                    crop_parameters (torch.Tensor): (N, 4) or None
         
     | 
| 29 | 
         
            +
                """
         
     | 
| 30 | 
         
            +
                if calculate_intrinsics:
         
     | 
| 31 | 
         
            +
                    ray_to_cam = rays_to_cameras_homography
         
     | 
| 32 | 
         
            +
                else:
         
     | 
| 33 | 
         
            +
                    ray_to_cam = rays_to_cameras
         
     | 
| 34 | 
         
            +
             
     | 
| 35 | 
         
            +
                get_spatial_rays = Rays.from_spatial
         
     | 
| 36 | 
         
            +
             
     | 
| 37 | 
         
            +
                rays_final, rays_intermediate, pred_intermediate, _ = inference_ddim(
         
     | 
| 38 | 
         
            +
                    model,
         
     | 
| 39 | 
         
            +
                    images.unsqueeze(0),
         
     | 
| 40 | 
         
            +
                    device,
         
     | 
| 41 | 
         
            +
                    crop_parameters=crop_parameters.unsqueeze(0),
         
     | 
| 42 | 
         
            +
                    pbar=False,
         
     | 
| 43 | 
         
            +
                    stop_iteration=stop_iteration,
         
     | 
| 44 | 
         
            +
                    eta=[1, 0],
         
     | 
| 45 | 
         
            +
                    num_inference_steps=100,
         
     | 
| 46 | 
         
            +
                    num_patches_x=num_patches_x,
         
     | 
| 47 | 
         
            +
                    num_patches_y=num_patches_y,
         
     | 
| 48 | 
         
            +
                    visualize=True,
         
     | 
| 49 | 
         
            +
                    max_num_images=max_num_images,
         
     | 
| 50 | 
         
            +
                )
         
     | 
| 51 | 
         
            +
             
     | 
| 52 | 
         
            +
                spatial_rays = get_spatial_rays(
         
     | 
| 53 | 
         
            +
                    rays_final[0],
         
     | 
| 54 | 
         
            +
                    mode=mode,
         
     | 
| 55 | 
         
            +
                    num_patches_x=num_patches_x,
         
     | 
| 56 | 
         
            +
                    num_patches_y=num_patches_y,
         
     | 
| 57 | 
         
            +
                    use_homogeneous=use_homogeneous,
         
     | 
| 58 | 
         
            +
                )
         
     | 
| 59 | 
         
            +
             
     | 
| 60 | 
         
            +
                pred_cam = ray_to_cam(
         
     | 
| 61 | 
         
            +
                    spatial_rays,
         
     | 
| 62 | 
         
            +
                    crop_parameters,
         
     | 
| 63 | 
         
            +
                    num_patches_x=num_patches_x,
         
     | 
| 64 | 
         
            +
                    num_patches_y=num_patches_y,
         
     | 
| 65 | 
         
            +
                    depth_resolution=model.depth_resolution,
         
     | 
| 66 | 
         
            +
                    average_centers=True,
         
     | 
| 67 | 
         
            +
                    directions_from_averaged_center=True,
         
     | 
| 68 | 
         
            +
                )
         
     | 
| 69 | 
         
            +
             
     | 
| 70 | 
         
            +
                additional_predictions = []
         
     | 
| 71 | 
         
            +
                for t in additional_timesteps:
         
     | 
| 72 | 
         
            +
                    ray = pred_intermediate[t]
         
     | 
| 73 | 
         
            +
             
     | 
| 74 | 
         
            +
                    ray = get_spatial_rays(
         
     | 
| 75 | 
         
            +
                        ray[0],
         
     | 
| 76 | 
         
            +
                        mode=mode,
         
     | 
| 77 | 
         
            +
                        num_patches_x=num_patches_x,
         
     | 
| 78 | 
         
            +
                        num_patches_y=num_patches_y,
         
     | 
| 79 | 
         
            +
                        use_homogeneous=use_homogeneous,
         
     | 
| 80 | 
         
            +
                    )
         
     | 
| 81 | 
         
            +
             
     | 
| 82 | 
         
            +
                    cam = ray_to_cam(
         
     | 
| 83 | 
         
            +
                        ray,
         
     | 
| 84 | 
         
            +
                        crop_parameters,
         
     | 
| 85 | 
         
            +
                        num_patches_x=num_patches_x,
         
     | 
| 86 | 
         
            +
                        num_patches_y=num_patches_y,
         
     | 
| 87 | 
         
            +
                        average_centers=True,
         
     | 
| 88 | 
         
            +
                        directions_from_averaged_center=True,
         
     | 
| 89 | 
         
            +
                    )
         
     | 
| 90 | 
         
            +
                    if return_rays:
         
     | 
| 91 | 
         
            +
                        cam = (cam, ray)
         
     | 
| 92 | 
         
            +
                    additional_predictions.append(cam)
         
     | 
| 93 | 
         
            +
             
     | 
| 94 | 
         
            +
                if return_rays:
         
     | 
| 95 | 
         
            +
                    return (pred_cam, spatial_rays), additional_predictions
         
     | 
| 96 | 
         
            +
                return pred_cam, additional_predictions, spatial_rays
         
     |