Spaces:
Running
on
Zero
Running
on
Zero
| # from internal import math | |
| import numpy as np | |
| import torch | |
def searchsorted(a, v):
    """Bracket each query point in `v` between neighbouring entries of `a`.

    Args:
        a: tensor of sorted reference points (sorted along the last dimension).
        v: tensor of query points; it does not need to be sorted. Its leading
            dimensions must match or broadcast against those of `a`; the last
            dimension may have a different size.

    Returns:
        (idx_lo, idx_hi) such that a[idx_lo] <= v < a[idx_hi]. When a query is
        outside the range covered by `a`, both indices collapse to the first
        or to the last index of `a`.
    """
    positions = torch.arange(a.shape[-1], device=a.device)
    column = positions[..., :, None]
    # is_ge[..., j, k] is True when query k is at or above reference point j.
    is_ge = v[..., None, :] >= a[..., :, None]
    # Largest index whose reference point is <= the query (index 0 if none).
    lo_candidates = torch.where(is_ge, column, positions[..., :1, None])
    idx_lo = torch.max(lo_candidates, -2).values
    # Smallest index whose reference point is > the query (last index if none).
    hi_candidates = torch.where(~is_ge, column, positions[..., -1:, None])
    idx_hi = torch.min(hi_candidates, -2).values
    return idx_lo, idx_hi
def query(tq, t, y, outside_value=0):
    """Look up the values of the step function (t, y) at locations tq.

    Args:
        tq: tensor of query locations.
        t: tensor of sorted step-function endpoints, passed to `searchsorted`
            as the reference points.
        y: tensor of step-function values, gathered along the last dimension.
        outside_value: value returned wherever `searchsorted` reports the same
            lower and upper index for a query (its out-of-range signal).

    Returns:
        Tensor of step-function values at `tq`, in the dtype of `y`.
    """
    idx_lo, idx_hi = searchsorted(t, tq)
    y_lo = torch.take_along_dim(y, idx_lo, dim=-1)
    # Build the fill tensor from the gathered values, not from the integer
    # index tensor: otherwise a fractional or NaN `outside_value` cannot be
    # represented and the two branches of `torch.where` disagree on dtype.
    fill = torch.full_like(y_lo, outside_value)
    return torch.where(idx_lo == idx_hi, fill, y_lo)
def inner_outer(t0, t1, y1):
    """Construct inner and outer measures on (t1, y1) for t0.

    Args:
        t0: tensor of sorted interval endpoints to produce measures for.
        t1: tensor of sorted endpoints of the source step function.
        y1: tensor of per-interval values of the source step function.

    Returns:
        (y0_inner, y0_outer): for every interval of `t0`, differences of the
        cumulative sum of `y1` taken at the bracketing indices returned by
        `searchsorted`. The inner measure is zero wherever the brackets of
        adjacent endpoints cross.
    """
    # Cumulative sum with a leading zero, so it can be indexed by endpoint.
    cy1 = torch.cat([torch.zeros_like(y1[..., :1]),
                     torch.cumsum(y1, dim=-1)],
                    dim=-1)
    idx_lo, idx_hi = searchsorted(t1, t0)
    cy1_lo = torch.take_along_dim(cy1, idx_lo, dim=-1)
    cy1_hi = torch.take_along_dim(cy1, idx_hi, dim=-1)
    y0_outer = cy1_hi[..., 1:] - cy1_lo[..., :-1]
    inner_mass = cy1_lo[..., 1:] - cy1_hi[..., :-1]
    # The zero branch takes its dtype from the floating-point values rather
    # than from the integer index tensor, so `torch.where` sees one dtype.
    y0_inner = torch.where(idx_hi[..., :-1] <= idx_lo[..., 1:],
                           inner_mass, torch.zeros_like(inner_mass))
    return y0_inner, y0_outer
def lossfun_outer(t, w, t_env, w_env):
    """The proposal weight should be an upper envelope on the nerf weight.

    Args:
        t: tensor of interval endpoints for the weights `w`.
        w: tensor of weights that the envelope should bound from above.
        t_env: tensor of interval endpoints for the envelope weights.
        w_env: tensor of envelope (proposal) weights.

    Returns:
        Per-interval loss; zero wherever the outer measure already covers `w`.
    """
    eps = torch.finfo(t.dtype).eps
    # Only the outer measure is penalised: we assume w_inner <= w <= w_outer,
    # and pulling w_outer up is more effective than pushing w_inner down.
    w_outer = inner_outer(t, t_env, w_env)[1]
    shortfall = (w - w_outer).clamp_min(0)
    # Scaled half-quadratic loss that gives a constant gradient at w_outer = 0.
    return shortfall ** 2 / (w + eps)
def weight_to_pdf(t, w):
    """Turn a vector of weights that sums to 1 into a PDF that integrates to 1.

    Each weight is divided by the width of its interval in `t`; widths are
    clamped from below by the machine epsilon of `t`'s dtype so that
    zero-width intervals do not divide by zero.
    """
    smallest_width = torch.finfo(t.dtype).eps
    widths = t[..., 1:] - t[..., :-1]
    return w / widths.clamp_min(smallest_width)
def pdf_to_weight(t, p):
    """Turn a PDF that integrates to 1 into a vector of weights that sums to 1.

    Each density value is multiplied by the width of its interval in `t`.
    """
    widths = t[..., 1:] - t[..., :-1]
    return p * widths
def max_dilate(t, w, dilation, domain=(-torch.inf, torch.inf)):
    """Dilate (via max-pooling) a non-negative step function.

    Args:
        t: tensor of interval endpoints of the step function.
        w: tensor of per-interval values.
        dilation: amount by which every interval is widened on both sides.
        domain: (minval, maxval) used to clip the dilated endpoints.

    Returns:
        (t_dilate, w_dilate): the sorted, clipped union of the original and
        the widened endpoints, and for each resulting interval the maximum
        value among the widened input intervals that cover its left endpoint.
    """
    widened_lo = t[..., :-1] - dilation
    widened_hi = t[..., 1:] + dilation
    merged = torch.cat([t, widened_lo, widened_hi], dim=-1)
    t_dilate = torch.sort(merged, dim=-1).values
    t_dilate = torch.clip(t_dilate, *domain)
    # covers[..., k, i] is True when widened interval i contains endpoint k.
    probe = t_dilate[..., None]
    covers = (widened_lo[..., None, :] <= probe) & (widened_hi[..., None, :] > probe)
    values = w[..., None, :]
    pooled = torch.where(covers, values, torch.zeros_like(values))
    # The last endpoint does not start an interval, so its value is dropped.
    w_dilate = torch.max(pooled, dim=-1).values[..., :-1]
    return t_dilate, w_dilate
def max_dilate_weights(t,
                       w,
                       dilation,
                       domain=(-torch.inf, torch.inf),
                       renormalize=False):
    """Dilate (via max-pooling) a set of weights.

    The weights are converted to a PDF, the PDF is dilated with `max_dilate`,
    and the result is converted back to weights on the dilated endpoints.

    Args:
        t: tensor of interval endpoints.
        w: tensor of per-interval weights.
        dilation: amount by which every interval is widened on both sides.
        domain: (minval, maxval) used to clip the dilated endpoints.
        renormalize: if True, rescale the dilated weights to sum to 1 along
            the last dimension.

    Returns:
        (t_dilate, w_dilate): dilated endpoints and their weights.
    """
    eps = torch.finfo(w.dtype).eps
    density = weight_to_pdf(t, w)
    t_dilate, density_dilate = max_dilate(t, density, dilation, domain=domain)
    w_dilate = pdf_to_weight(t_dilate, density_dilate)
    if not renormalize:
        return t_dilate, w_dilate
    # Clamp the normaliser so an all-zero weight vector does not divide by 0.
    total = torch.sum(w_dilate, dim=-1, keepdim=True).clamp_min(eps)
    w_dilate /= total
    return t_dilate, w_dilate
def integrate_weights(w):
    """Compute the cumulative sum of w, assuming all weight vectors sum to 1.

    The output's size on the last dimension is one greater than that of the
    input, because we're computing the integral corresponding to the endpoints
    of a step function, not the integral of the interior/bin values.

    Args:
        w: Tensor, which will be integrated along the last axis. This is
            assumed to sum to 1 along the last axis, and this function will
            (silently) break if that is not the case.

    Returns:
        cw0: Tensor, the integral of w, where cw0[..., 0] = 0 and
            cw0[..., -1] = 1. For floating-point `w` it has the dtype of `w`.
    """
    cw = torch.cumsum(w[..., :-1], dim=-1).clamp_max(1)
    shape = cw.shape[:-1] + (1,)
    # Pad in the dtype of the cumulative sum so that reduced-precision inputs
    # are not silently upcast by `torch.cat`; non-floating inputs keep the
    # previous behaviour of padding with the default floating dtype.
    pad_dtype = cw.dtype if cw.is_floating_point() else torch.get_default_dtype()
    # Ensure that the CDF starts with exactly 0 and ends with exactly 1.
    start = torch.zeros(shape, dtype=pad_dtype, device=cw.device)
    end = torch.ones(shape, dtype=pad_dtype, device=cw.device)
    cw0 = torch.cat([start, cw, end], dim=-1)
    return cw0
def integrate_weights_np(w):
    """NumPy variant: cumulative sum of w, assuming each vector sums to 1.

    The output is one element longer than the input along the last axis,
    because it holds the integral at the endpoints of a step function rather
    than at its interior/bin values.

    Args:
        w: array integrated along the last axis. It is assumed to sum to 1
            along that axis; the result is (silently) wrong otherwise.

    Returns:
        cw0: array, the integral of w, where cw0[..., 0] = 0 and
            cw0[..., -1] = 1.
    """
    running = np.cumsum(w[..., :-1], axis=-1)
    cw = np.minimum(1, running)
    edge_shape = cw.shape[:-1] + (1,)
    # Pin the CDF to start at exactly 0 and to end at exactly 1.
    pieces = [np.zeros(edge_shape), cw, np.ones(edge_shape)]
    return np.concatenate(pieces, axis=-1)
def invert_cdf(u, t, w_logits):
    """Invert the CDF defined by (t, w) at the points specified by u in [0, 1).

    Args:
        u: tensor of CDF values at which to evaluate the inverse.
        t: tensor of bin endpoint coordinates.
        w_logits: tensor of logits; a softmax over the last dimension turns
            them into the bin weights `w`.

    Returns:
        The result of `math.sorted_interp(u, cw, t)`, where `cw` is the
        integrated softmax of `w_logits`.
    """
    # NOTE(review): `math` here must be the project module providing
    # `sorted_interp`; its import (`from internal import math`) is commented
    # out at the top of this file -- confirm it is in scope before calling.
    cdf = integrate_weights(torch.softmax(w_logits, dim=-1))
    # Interpolate into the inverse CDF.
    return math.sorted_interp(u, cdf, t)
def invert_cdf_np(u, t, w_logits):
    """Invert the CDF defined by (t, w) at the points specified by u in [0, 1).

    NumPy counterpart of `invert_cdf`.

    Args:
        u: array of CDF values at which to evaluate the inverse.
        t: array of bin endpoint coordinates, sorted along the last axis.
        w_logits: array of logits; a softmax over the last axis turns them
            into the bin weights.

    Returns:
        Array of interpolated `t` values, one per entry of `u`.
    """
    w_logits = np.asarray(w_logits)
    # Subtract the per-row maximum before exponentiating: the softmax value
    # is unchanged mathematically, but large logits no longer overflow to
    # inf (which would yield NaN weights).
    shifted = w_logits - np.max(w_logits, axis=-1, keepdims=True)
    unnormalized = np.exp(shifted)
    w = unnormalized / unnormalized.sum(axis=-1, keepdims=True)
    cw = integrate_weights_np(w)
    t = np.asarray(t)
    # Interpolate into the inverse CDF.
    if cw.ndim == 1 and t.ndim == 1:
        return np.interp(u, cw, t)
    # `np.interp` only accepts 1-D sample points, so batched CDFs are handled
    # row by row, broadcasting over the leading dimensions.
    batched_interp = np.vectorize(np.interp, signature='(n),(m),(m)->(n)')
    return batched_interp(u, cw, t)
def sample(rand,
           t,
           w_logits,
           num_samples,
           single_jitter=False,
           deterministic_center=False):
    """Piecewise-Constant PDF sampling from a step function.

    Args:
        rand: random number generator (or None for `linspace` sampling). Only
            its truthiness is inspected here; the jitter itself is drawn with
            `torch.rand`.
        t: [..., num_bins + 1], bin endpoint coordinates (must be sorted)
        w_logits: [..., num_bins], logits corresponding to bin weights
        num_samples: int, the number of samples.
        single_jitter: bool, if True, jitter every sample along each ray by
            the same amount in the inverse CDF. Otherwise, jitter each sample
            independently.
        deterministic_center: bool, if False, when `rand` is None return
            samples that linspace the entire PDF. If True, skip the front and
            back of the linspace so that the centers of each PDF interval are
            returned.

    Returns:
        t_samples: the inverse CDF evaluated at the drawn positions, as
            returned by `invert_cdf`.
    """
    eps = torch.finfo(t.dtype).eps
    device = t.device

    # Draw uniform samples.
    if not rand:
        if deterministic_center:
            pad = 1 / (2 * num_samples)
            u = torch.linspace(pad, 1. - pad - eps, num_samples, device=device)
        else:
            u = torch.linspace(0, 1. - eps, num_samples, device=device)
        u = torch.broadcast_to(u, t.shape[:-1] + (num_samples,))
    else:
        # `u` is in [0, 1) --- it can be zero, but it can never be 1.
        u_max = eps + (1 - eps) / num_samples
        if num_samples > 1:
            max_jitter = (1 - u_max) / (num_samples - 1) - eps
        else:
            # The expression above equals (1 - eps) / num_samples - eps, but
            # divides by zero for a single sample; use that closed form so
            # one sample may jitter over the whole of [0, 1 - 2 * eps).
            max_jitter = 1 - 2 * eps
        d = 1 if single_jitter else num_samples
        strata = torch.linspace(0, 1 - u_max, num_samples, device=device)
        jitter = torch.rand(t.shape[:-1] + (d,), device=device) * max_jitter
        u = strata + jitter

    return invert_cdf(u, t, w_logits)
def sample_np(rand,
              t,
              w_logits,
              num_samples,
              single_jitter=False,
              deterministic_center=False):
    """NumPy version of `sample()`.

    Args:
        rand: truthy to jitter the samples with `np.random.rand`, falsy for
            deterministic `linspace` sampling.
        t: array of bin endpoint coordinates (last axis).
        w_logits: array of logits corresponding to bin weights.
        num_samples: int, the number of samples.
        single_jitter: bool, if True, use one jitter value per ray instead of
            one per sample.
        deterministic_center: bool, only used when `rand` is falsy; if True
            the linspace skips half a step at the front and back.

    Returns:
        The inverse CDF evaluated at the drawn positions, as returned by
        `invert_cdf_np`.
    """
    eps = np.finfo(np.float32).eps

    # Draw uniform samples.
    if not rand:
        if deterministic_center:
            pad = 1 / (2 * num_samples)
            u = np.linspace(pad, 1. - pad - eps, num_samples)
        else:
            u = np.linspace(0, 1. - eps, num_samples)
        u = np.broadcast_to(u, t.shape[:-1] + (num_samples,))
    else:
        # `u` is in [0, 1) --- it can be zero, but it can never be 1.
        u_max = eps + (1 - eps) / num_samples
        if num_samples > 1:
            max_jitter = (1 - u_max) / (num_samples - 1) - eps
        else:
            # The expression above equals (1 - eps) / num_samples - eps, but
            # divides by zero for a single sample; use that closed form so
            # one sample may jitter over the whole of [0, 1 - 2 * eps).
            max_jitter = 1 - 2 * eps
        d = 1 if single_jitter else num_samples
        strata = np.linspace(0, 1 - u_max, num_samples)
        jitter = np.random.rand(*t.shape[:-1], d) * max_jitter
        u = strata + jitter

    return invert_cdf_np(u, t, w_logits)
def sample_intervals(rand,
                     t,
                     w_logits,
                     num_samples,
                     single_jitter=False,
                     domain=(-torch.inf, torch.inf)):
    """Sample *intervals* (rather than points) from a step function.

    Args:
        rand: random number generator (or None for `linspace` sampling).
        t: [..., num_bins + 1], bin endpoint coordinates (must be sorted)
        w_logits: [..., num_bins], logits corresponding to bin weights
        num_samples: int, the number of intervals to sample.
        single_jitter: bool, if True, jitter every sample along each ray by
            the same amount in the inverse CDF. Otherwise, jitter each sample
            independently.
        domain: (minval, maxval), the range of valid values for `t`.

    Returns:
        t_samples: the interval fenceposts along the last dimension -- one
            leading fencepost, the midpoints between adjacent sampled centers,
            and one trailing fencepost.

    Raises:
        ValueError: if `num_samples` is not greater than 1.
    """
    if num_samples <= 1:
        raise ValueError(f'num_samples must be > 1, is {num_samples}.')

    # Draw the interval centers from the step function.
    centers = sample(rand, t, w_logits, num_samples, single_jitter,
                     deterministic_center=True)

    # Interior fenceposts sit halfway between neighbouring centers.
    mid = (centers[..., 1:] + centers[..., :-1]) / 2

    # The outermost fenceposts mirror the first/last midpoint about the
    # first/last center, then get clamped to the caller-provided domain.
    minval, maxval = domain
    first_center = centers[..., :1]
    last_center = centers[..., -1:]
    first = (2 * first_center - mid[..., :1]).clamp_min(minval)
    last = (2 * last_center - mid[..., -1:]).clamp_max(maxval)

    return torch.cat([first, mid, last], dim=-1)
def lossfun_distortion(t, w):
    """Compute iint w[i] w[j] |t[i] - t[j]| di dj.

    Args:
        t: tensor of interval endpoints along the last dimension.
        w: tensor of per-interval weights.

    Returns:
        The distortion loss, reduced over the last dimension.
    """
    lower, upper = t[..., :-1], t[..., 1:]

    # Cross term: weighted distances between all pairs of interval midpoints.
    centers = (upper + lower) / 2
    pair_dist = torch.abs(centers[..., :, None] - centers[..., None, :])
    per_interval = torch.sum(w[..., None, :] * pair_dist, dim=-1)
    loss_inter = torch.sum(w * per_interval, dim=-1)

    # Self term: each interval paired with itself.
    loss_intra = torch.sum(w ** 2 * (upper - lower), dim=-1) / 3

    return loss_inter + loss_intra
def interval_distortion(t0_lo, t0_hi, t1_lo, t1_hi):
    """Compute mean(abs(x-y); x in [t0_lo, t0_hi], y in [t1_lo, t1_hi]).

    Args:
        t0_lo: tensor, lower ends of the first intervals.
        t0_hi: tensor, upper ends of the first intervals.
        t1_lo: tensor, lower ends of the second intervals.
        t1_hi: tensor, upper ends of the second intervals.

    Returns:
        Tensor holding the midpoint distance where the two intervals are
        disjoint and the closed-form overlapping expression elsewhere.
    """
    # Disjoint intervals: the distance between their midpoints.
    d_disjoint = torch.abs((t1_lo + t1_hi) / 2 - (t0_lo + t0_hi) / 2)

    # Overlapping intervals: closed-form expression, built up term by term.
    cubic = (torch.minimum(t0_hi, t1_hi) ** 3 -
             torch.maximum(t0_lo, t1_lo) ** 3)
    cross = (t1_hi * t0_hi * torch.abs(t1_hi - t0_hi) +
             t1_lo * t0_lo * torch.abs(t1_lo - t0_lo) +
             t1_hi * t0_lo * (t0_lo - t1_hi) +
             t1_lo * t0_hi * (t1_lo - t0_hi))
    normalizer = 6 * (t0_hi - t0_lo) * (t1_hi - t1_lo)
    d_overlap = (2 * cubic + 3 * cross) / normalizer

    # Are the two intervals not overlapping?
    are_disjoint = (t0_lo > t1_hi) | (t1_lo > t0_hi)
    return torch.where(are_disjoint, d_disjoint, d_overlap)
def weighted_percentile(t, w, ps):
    """Compute the weighted percentiles of a step function. w's must sum to 1.

    Args:
        t: tensor of interval endpoints along the last dimension.
        w: tensor of per-interval weights, summing to 1 along the last
            dimension.
        ps: sequence of percentiles, expressed on a 0-100 scale.

    Returns:
        Tensor whose leading dimensions match those of the integrated weights
        and whose last dimension has `len(ps)` entries.
    """
    cw = integrate_weights(w)
    # Flatten every leading dimension so the interpolation sees 2-D inputs.
    cw_mat = cw.reshape([-1, cw.shape[-1]])
    t_mat = t.reshape([-1, t.shape[-1]])
    # Percentiles are rescaled to [0, 1] to index into the integrated weights.
    fractions = torch.tensor(ps, device=t.device) / 100
    # NOTE(review): `math` here must be the project module providing
    # `sorted_interp`; its import (`from internal import math`) is commented
    # out at the top of this file -- confirm it is in scope before calling.
    wprctile_mat = math.sorted_interp(fractions, cw_mat, t_mat)  # TODO
    return wprctile_mat.reshape(cw.shape[:-1] + (len(ps),))
def resample(t, tp, vp, use_avg=False):
    """Resample a step function defined by (tp, vp) into intervals t.

    Args:
        t: tensor with shape (..., n+1), the endpoints to resample into.
        tp: tensor with shape (..., m+1), the endpoints of the step function
            being resampled.
        vp: tensor with shape (..., m), the values of the step function being
            resampled.
        use_avg: bool, if False, return the sum of the step function for each
            interval in `t`. If True, return the average, weighted by the
            width of each interval in `t`. The denominator of that average is
            clamped from below by the machine epsilon of `t`'s dtype.

    Returns:
        v: tensor with shape (..., n), the values of the resampled step
            function.
    """
    eps = torch.finfo(t.dtype).eps
    if use_avg:
        wp = torch.diff(tp, dim=-1)
        v_numer = resample(t, tp, vp * wp, use_avg=False)
        v_denom = resample(t, tp, wp, use_avg=False)
        # Clamp so that empty target intervals do not divide by zero.
        return v_numer / v_denom.clamp_min(eps)

    acc = torch.cumsum(vp, dim=-1)
    # Pad in the dtype of the cumulative sum so that reduced-precision inputs
    # are not silently upcast by `torch.cat`; non-floating inputs keep the
    # previous behaviour of padding with the default floating dtype.
    pad_dtype = acc.dtype if acc.is_floating_point() else torch.get_default_dtype()
    leading_zero = torch.zeros(acc.shape[:-1] + (1,), dtype=pad_dtype,
                               device=acc.device)
    acc0 = torch.cat([leading_zero, acc], dim=-1)
    # NOTE(review): `math` here must be the project module providing
    # `sorted_interp`; its import (`from internal import math`) is commented
    # out at the top of this file -- confirm it is in scope before calling.
    acc0_resampled = math.sorted_interp(t, tp, acc0)  # TODO
    # Differences of the running total give the per-interval sums.
    return torch.diff(acc0_resampled, dim=-1)
def resample_np(t, tp, vp, use_avg=False):
    """NumPy version of `resample`.

    Args:
        t: array of endpoints to resample into (last axis).
        tp: array of endpoints of the step function being resampled.
        vp: array of values of the step function being resampled.
        use_avg: bool, if False return the per-interval sum; if True return
            the width-weighted average, with the denominator floored at the
            machine epsilon of `t`'s dtype.

    Returns:
        Array of resampled values, one per interval of `t`.
    """
    eps = np.finfo(t.dtype).eps
    if not use_avg:
        running = np.cumsum(vp, axis=-1)
        leading_zero = np.zeros(running.shape[:-1] + (1,))
        acc0 = np.concatenate([leading_zero, running], axis=-1)
        # `np.interp` is 1-D only; vectorize it over the leading dimensions.
        interp_rows = np.vectorize(np.interp, signature='(n),(m),(m)->(n)')
        acc0_resampled = interp_rows(t, tp, acc0)
        return np.diff(acc0_resampled, axis=-1)

    wp = np.diff(tp, axis=-1)
    v_numer = resample_np(t, tp, vp * wp, use_avg=False)
    v_denom = resample_np(t, tp, wp, use_avg=False)
    return v_numer / np.maximum(eps, v_denom)
def blur_stepfun(x, y, r):
    """Turn the step function (x, y) into a piecewise-linear one of radius r.

    NOTE(review): this appears to be a box blur of half-width `r` applied to
    the step function -- confirm against the caller's expectations.

    Args:
        x: tensor of step-function endpoints along the last dimension.
        y: tensor of per-interval values.
        r: blur radius; every endpoint is shifted by -r and by +r.

    Returns:
        (xr, yr): the sorted shifted endpoints, and the values at those
        endpoints obtained by integrating the jumps of `y` twice; `yr` starts
        at 0 and is clamped to be non-negative.
    """
    shifted = torch.cat([x - r, x + r], dim=-1)
    xr, order = torch.sort(shifted)
    # Jump of y at every endpoint (y padded with 0 outside), spread over 2r.
    pad = torch.zeros_like(y[..., :1])
    y1 = (torch.cat([y, pad], dim=-1) -
          torch.cat([pad, y], dim=-1)) / (2 * r)
    # A jump turns on at x - r and off at x + r; reorder to match `xr`.
    slope_changes = torch.cat([y1, -y1], dim=-1)
    y2 = slope_changes.take_along_dim(order[..., :-1], dim=-1)
    # First cumsum gives the slope on each segment, second gives the value.
    widths = xr[..., 1:] - xr[..., :-1]
    yr = torch.cumsum(widths * torch.cumsum(y2, dim=-1), dim=-1).clamp_min(0)
    yr = torch.cat([torch.zeros_like(yr[..., :1]), yr], dim=-1)
    return xr, yr