| | import torch |
| |
|
| |
|
def transform_G_xyz(G, xyz, is_return_homo=False):
    """
    Apply a 4x4 transform to 3D points via homogeneous coordinates.

    :param G: Bx4x4 (or unbatched 4x4) transformation matrices
    :param xyz: Bx3xN (or unbatched 3xN) point coordinates
    :param is_return_homo: if True, return the homogeneous Bx4xN result
    :return: Bx3xN transformed points (Bx4xN when is_return_homo);
             unbatched inputs come back with a leading batch dim of 1
    """
    assert G.dim() == xyz.dim()
    # Promote unbatched inputs to a batch of one so one code path handles both.
    if G.dim() == 2:
        G_B44, xyz_B3N = G[None], xyz[None]
    else:
        G_B44, xyz_B3N = G, xyz
    # Append a row of ones: Bx3xN -> Bx4xN homogeneous coordinates.
    ones_B1N = torch.ones_like(xyz_B3N[:, 0:1, :])
    homo_B4N = torch.cat((xyz_B3N, ones_B1N), dim=1)
    out_B4N = G_B44 @ homo_B4N
    return out_B4N if is_return_homo else out_B4N[:, 0:3, :]
| |
|
| |
|
def gather_pixel_by_pxpy(img, pxpy):
    """
    Sample per-pixel values from an image at given pixel coordinates.

    :param img: BxCxHxW image tensor
    :param pxpy: Bx2xN pixel coordinates; row 0 is x (column), row 1 is y (row).
                 Floating-point coordinates are rounded to the nearest pixel;
                 integer coordinates are used as-is. Out-of-range coordinates
                 are clamped to the image border.
    :return: BxCxN tensor of gathered pixel values
    """
    with torch.no_grad():
        B, C, H, W = img.size()
        # BUGFIX: the original only defined pxpy_int when pxpy.dtype was
        # exactly float32 and raised NameError for every other dtype
        # (int64, float64, half). Handle all floating and integer dtypes.
        if pxpy.is_floating_point():
            pxpy_int = torch.round(pxpy).to(torch.int64)
        else:
            # NOTE: .to(torch.int64) on an int64 tensor is an aliasing no-op,
            # so we clamp out-of-place below to avoid mutating the caller's
            # tensor (the original clamped in place).
            pxpy_int = pxpy.to(torch.int64)
        # Clamp coordinates into the image so the gather indices are valid.
        px = pxpy_int[:, 0:1, :].clamp(min=0, max=W - 1)
        py = pxpy_int[:, 1:2, :].clamp(min=0, max=H - 1)
        # Flatten (x, y) into a linear index over the H*W pixel grid.
        pxpy_idx = px + W * py
        rgb = torch.gather(img.view(B, C, H * W), dim=2,
                           index=pxpy_idx.repeat(1, C, 1))
        return rgb
| |
|
| |
|
def uniformly_sample_disparity_from_bins(batch_size, disparity_np, device):
    """
    Draw one uniform random sample inside each disparity bin.

    In the disparity dimension the edges must run from large to small,
    i.e., depth from small (near) to large (far).

    :param batch_size: number of independent rows of samples to draw
    :param disparity_np: numpy array of S+1 bin edges (strictly first > last)
    :param device: torch device for the returned tensor
    :return: BxS float32 tensor of sampled disparities
    """
    assert disparity_np[0] > disparity_np[-1]
    num_bins = disparity_np.shape[0] - 1

    edges = torch.from_numpy(disparity_np).to(dtype=torch.float32, device=device)
    # Per-bin start and (negative) width, broadcast to the batch.
    starts = edges[:-1].unsqueeze(0).repeat(batch_size, 1)
    widths = (edges[1:] - edges[:-1]).unsqueeze(0).repeat(batch_size, 1)

    # Jitter each bin start by a uniform fraction of its width.
    noise = torch.rand((batch_size, num_bins), dtype=torch.float32, device=device)
    return starts + widths * noise
| |
|
| |
|
def uniformly_sample_disparity_from_linspace_bins(batch_size, num_bins, start, end, device):
    """
    Draw one uniform random sample inside each of num_bins equal-width bins
    spanning [start, end].

    In the disparity dimension it has to be from large to small,
    i.e., depth from small (near) to large (far).

    :param batch_size: number of independent rows of samples to draw
    :param num_bins: number of bins S between start and end
    :param start: largest disparity (near depth)
    :param end: smallest disparity (far depth)
    :param device: torch device for the returned tensor
    :return: BxS float32 tensor of sampled disparities
    """
    assert start > end

    # S+1 evenly spaced edges, descending from start to end.
    edges = torch.linspace(start, end, num_bins + 1, dtype=torch.float32, device=device)
    step = edges[1] - edges[0]  # constant (negative) bin width
    starts = edges[:-1].unsqueeze(0).repeat(batch_size, 1)

    # Jitter each bin start by a uniform fraction of the bin width.
    noise = torch.rand((batch_size, num_bins), dtype=torch.float32, device=device)
    return starts + step * noise
| |
|
| |
|
def sample_pdf(values, weights, N_samples):
    """
    draw samples from distribution approximated by values and weights.
    the probability distribution can be denoted as weights = p(values)
    :param values: Bx1xNxS
    :param weights: Bx1xNxS
    :param N_samples: number of sample to draw
    :return: Bx1xNxN_samples samples drawn by inverse-transform sampling
    """
    B, N, S = weights.size(0), weights.size(2), weights.size(3)
    assert values.size() == (B, 1, N, S)

    # Build S+1 bin edges around the S values: midpoints between neighbouring
    # values, with the first and last value reused as the outermost edges.
    bin_edges = (values[:, :, :, 1:] + values[:, :, :, :-1]) * 0.5
    bin_edges = torch.cat((values[:, :, :, 0:1],
                           bin_edges,
                           values[:, :, :, -1:]), dim=3)

    # Normalize weights into a pdf (eps guards against an all-zero row),
    # then prepend 0 so the cdf has S+1 entries aligned with the bin edges.
    pdf = weights / (torch.sum(weights, dim=3, keepdim=True) + 1e-5)
    cdf = torch.cumsum(pdf, dim=3)
    cdf = torch.cat((torch.zeros((B, 1, N, 1), dtype=cdf.dtype, device=cdf.device),
                     cdf), dim=3)

    # Uniform samples in [0, 1) to be pushed through the inverse CDF.
    u = torch.rand((B, 1, N, N_samples), dtype=weights.dtype, device=weights.device)

    # For each u, locate the enclosing cdf interval; clamps keep both index
    # tensors within [0, S], valid for the S+1-long cdf / bin_edges.
    cdf_idx = torch.searchsorted(cdf, u, right=True)
    cdf_idx_lower = torch.clamp(cdf_idx-1, min=0)
    cdf_idx_upper = torch.clamp(cdf_idx, max=S)

    # Gather lower and upper bounds in a single call by concatenating the two
    # index tensors along dim 3, then split them back into a trailing dim of
    # size 2: [..., k, 0] = lower bound, [..., k, 1] = upper bound.
    cdf_idx_lower_upper = torch.cat((cdf_idx_lower, cdf_idx_upper), dim=3)
    cdf_bounds_N2 = torch.gather(cdf, index=cdf_idx_lower_upper, dim=3)
    cdf_bounds = torch.stack((cdf_bounds_N2[..., 0:N_samples], cdf_bounds_N2[..., N_samples:]), dim=4)
    bin_bounds_N2 = torch.gather(bin_edges, index=cdf_idx_lower_upper, dim=3)
    bin_bounds = torch.stack((bin_bounds_N2[..., 0:N_samples], bin_bounds_N2[..., N_samples:]), dim=4)

    # Fractional position of u inside its cdf interval, used to linearly
    # interpolate between the corresponding bin edges.
    cdf_intervals = cdf_bounds[:, :, :, :, 1] - cdf_bounds[:, :, :, :, 0]
    bin_intervals = bin_bounds[:, :, :, :, 1] - bin_bounds[:, :, :, :, 0]
    u_cdf_lower = u - cdf_bounds[:, :, :, :, 0]
    # Clamp the denominator for numerical safety; degenerate (near-zero
    # probability) intervals fall back to the bin midpoint (t = 0.5).
    t = u_cdf_lower / torch.clamp(cdf_intervals, min=1e-5)
    t = torch.where(cdf_intervals <= 1e-4,
                    torch.full_like(u_cdf_lower, 0.5),
                    t)

    samples = bin_bounds[:, :, :, :, 0] + t*bin_intervals
    return samples
| |
|