Spaces:
Configuration error
Configuration error
import torch.nn.functional as F | |
import torch | |
from lib.config import cfg | |
def raw2outputs(raw, z_vals, rays_d, raw_noise_std=0, white_bkgd=False):
    """Transforms model's predictions to semantically meaningful values.

    The leading (batch) dimensions may have any rank as long as they agree
    between `raw`, `z_vals` and `rays_d`.

    Args:
        raw: [..., num_samples along ray, 4]. Prediction from model. The
            first three channels go through a sigmoid to give colour, the
            fourth goes through a ReLU to give density.
        z_vals: [..., num_samples along ray]. Integration time.
        rays_d: [..., 3]. Direction of each ray.
        raw_noise_std: Standard deviation of the Gaussian noise added to the
            raw density channel before the ReLU. No noise when <= 0.
        white_bkgd: If True, the colour is composited onto a white
            background using the left-over (1 - acc_map) opacity.

    Returns:
        rgb_map: [..., 3]. Estimated RGB color of a ray.
        disp_map: [...]. Disparity map. Inverse of depth map.
        acc_map: [...]. Sum of weights along each ray.
        weights: [..., num_samples]. Weights assigned to each sampled color.
        depth_map: [...]. Estimated distance to object.
    """
    # Length of the interval that follows every sample. The last sample has
    # no successor, so it is given a very large interval (1e10). The padding
    # is shaped from z_vals so a single sample per ray also works.
    dists = z_vals[..., 1:] - z_vals[..., :-1]
    dists = torch.cat(
        [dists, torch.full_like(z_vals[..., :1], 1e10)],
        -1)  # [..., N_samples]

    # z_vals are measured in units of rays_d, so scale by the norm of the
    # direction to obtain the actual interval length.
    dists = dists * torch.norm(rays_d[..., None, :], dim=-1)

    rgb = torch.sigmoid(raw[..., :3])  # [..., N_samples, 3]

    density = raw[..., 3]
    if raw_noise_std > 0.:
        # randn_like keeps the noise on the same device and in the same dtype
        # as the prediction; a bare torch.randn would always land on the CPU.
        density = density + torch.randn_like(density) * raw_noise_std

    # Opacity of every interval.
    alpha = 1. - torch.exp(-F.relu(density) * dists)  # [..., N_samples]

    # Exclusive cumulative product of (1 - alpha): the fraction of light that
    # reaches each sample. The 1e-10 keeps the product from collapsing to an
    # exact zero.
    transmittance = torch.cumprod(
        torch.cat([torch.ones_like(alpha[..., :1]), 1. - alpha + 1e-10], -1),
        -1)[..., :-1]
    weights = alpha * transmittance

    rgb_map = torch.sum(weights[..., None] * rgb, -2)  # [..., 3]
    depth_map = torch.sum(weights * z_vals, -1)
    acc_map = torch.sum(weights, -1)

    # NOTE: a ray whose weights sum to exactly zero gives 0 / 0 here, so its
    # disparity is NaN (behaviour kept from the original implementation).
    disp_map = 1. / torch.max(1e-10 * torch.ones_like(depth_map),
                              depth_map / acc_map)

    if white_bkgd:
        rgb_map = rgb_map + (1. - acc_map[..., None])

    return rgb_map, disp_map, acc_map, weights, depth_map
# Hierarchical sampling (section 5.2) | |
def sample_pdf(bins, weights, N_samples, det=False):
    """Draw samples from a piecewise-constant PDF by inverting its CDF.

    `weights` is normalised into a PDF over the intervals delimited by
    `bins`; uniform values in [0, 1] are then mapped through the inverse CDF
    with linear interpolation inside each interval.

    Args:
        bins: [..., num_bins]. Interval edges. The last dimension must match
            `weights.shape[-1] + 1`; size-1 dimensions are broadcast.
        weights: [..., num_bins - 1]. Unnormalised weight of every interval.
        N_samples: Number of samples to draw for every batch entry.
        det: If True, use evenly spaced values in [0, 1] instead of random
            ones, which makes the result deterministic.

    Returns:
        samples: [..., N_samples]. Sampled positions.
    """
    # Get pdf
    weights = weights + 1e-5  # prevent nans
    pdf = weights / torch.sum(weights, -1, keepdim=True)
    cdf = torch.cumsum(pdf, -1)
    cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf],
                    -1)  # (batch, len(bins))

    # Take uniform samples
    batch_shape = list(cdf.shape[:-1])
    if det:
        u = torch.linspace(0., 1., steps=N_samples).to(cdf)
        u = u.expand(batch_shape + [N_samples])
    else:
        u = torch.rand(batch_shape + [N_samples]).to(cdf)

    # Invert CDF
    u = u.contiguous()
    if hasattr(torch, 'searchsorted'):
        # Built-in implementation; no extra package required.
        inds = torch.searchsorted(cdf, u, right=True)
    else:
        # Legacy fallback for PyTorch builds without torch.searchsorted.
        from torchsearchsorted import searchsorted
        inds = searchsorted(cdf, u, side='right')

    # Indices of the CDF entries that bracket every u, clamped to stay valid.
    below = torch.clamp(inds - 1, min=0)
    above = torch.clamp(inds, max=cdf.shape[-1] - 1)

    # Gather along the last axis directly; this avoids materialising a
    # (batch, N_samples, len(bins)) view and works for any batch rank.
    bins = bins.expand(cdf.shape)
    cdf_below = torch.gather(cdf, -1, below)
    cdf_above = torch.gather(cdf, -1, above)
    bins_below = torch.gather(bins, -1, below)
    bins_above = torch.gather(bins, -1, above)

    # Guard against (near-)empty intervals before dividing.
    denom = cdf_above - cdf_below
    denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom)
    t = (u - cdf_below) / denom
    samples = bins_below + t * (bins_above - bins_below)

    return samples