import os

import numpy as np
import torch
import torch.nn.functional as F
from scipy.spatial.transform import Rotation as Rot
from scipy.spatial.transform import Slerp


def build_patch_offset(h_patch_size):
    """
    build the integer pixel offsets of a square patch centred on a pixel

    :param h_patch_size: half size of the patch; the side length is 2 * h_patch_size + 1
    :return: [1, nb_pixels_patch, 2] integer tensor of (x, y) offsets, x varying fastest
    """
    offsets = torch.arange(-h_patch_size, h_patch_size + 1)
    # explicit indexing="ij" keeps the historical default and matches the other meshgrid calls in this file
    offset_y, offset_x = torch.meshgrid(offsets, offsets, indexing="ij")
    return torch.stack([offset_x, offset_y], dim=-1).view(1, -1, 2)  # nb_pixels_patch * 2


def gen_rays_from_single_image(H, W, image, intrinsic, c2w, depth=None, mask=None):
    """
    generate rays in world space for every pixel of one image (row-major pixel order)

    :param H: image height in pixels
    :param W: image width in pixels
    :param image: [3, H, W]
    :param intrinsic: camera intrinsic matrix; the upper-left [3,3] of its inverse is used
    :param c2w: camera-to-world matrix; only c2w[:3, :4] is read
    :param depth: optional depth map with H * W elements
    :param mask: optional mask with H * W elements; defaults to all ones
    :return: dict with 'rays_o', 'rays_v', 'rays_ndc_uv', 'rays_color', 'rays_mask',
             'rays_norm_XYZ_cam' and, only when depth is given, 'rays_depth'
    """
    device = image.device
    ys, xs = torch.meshgrid(torch.linspace(0, H - 1, H),
                            torch.linspace(0, W - 1, W), indexing="ij")
    p = torch.stack([xs, ys, torch.ones_like(ys)], dim=-1)  # H, W, 3

    # normalized ndc uv coordinates, (-1, 1)
    ndc_u = 2 * xs / (W - 1) - 1
    ndc_v = 2 * ys / (H - 1) - 1
    rays_ndc_uv = torch.stack([ndc_u, ndc_v], dim=-1).view(-1, 2).float().to(device)

    intrinsic_inv = torch.inverse(intrinsic)

    p = p.view(-1, 3).float().to(device)  # N_rays, 3
    # squeeze only the trailing matmul axis so that a single ray keeps its [1, 3] shape
    p = torch.matmul(intrinsic_inv[None, :3, :3], p[:, :, None]).squeeze(-1)  # N_rays, 3
    rays_v = p / torch.linalg.norm(p, ord=2, dim=-1, keepdim=True)  # N_rays, 3
    rays_v = torch.matmul(c2w[None, :3, :3], rays_v[:, :, None]).squeeze(-1)  # N_rays, 3
    rays_o = c2w[None, :3, 3].expand(rays_v.shape)  # N_rays, 3

    # reshape (rather than view) also accepts non-contiguous inputs
    color = image.permute(1, 2, 0).reshape(-1, 3)
    depth = depth.reshape(-1, 1) if depth is not None else None
    mask = mask.reshape(-1, 1) if mask is not None else torch.ones([H * W, 1]).to(device)

    sample = {
        'rays_o': rays_o,
        'rays_v': rays_v,
        'rays_ndc_uv': rays_ndc_uv,
        'rays_color': color,
        'rays_mask': mask,
        'rays_norm_XYZ_cam': p  # - XYZ_cam, before multiply depth
    }
    if depth is not None:
        sample['rays_depth'] = depth

    return sample


def gen_random_rays_from_single_image(H, W, N_rays, image, intrinsic, c2w, depth=None, mask=None,
                                      dilated_mask=None, importance_sample=False, h_patch_size=3):
    """
    generate random rays in world space, for a single image

    :param H: image height in pixels
    :param W: image width in pixels
    :param N_rays: number of rays to sample
    :param image: [3, H, W]
    :param intrinsic: [3,3]
    :param c2w: [4,4]
    :param depth: [H, W]
    :param mask: [H, W]
    :param dilated_mask: [H, W] mask used for importance sampling; falls back to mask when None
    :param importance_sample: if True and a mask is available, N_rays // 4 pixels are drawn
                              uniformly and the remaining ones from pixels where dilated_mask > 0;
                              without any mask the sampling falls back to uniform
    :param h_patch_size: half size of the colour patch cropped around every sampled pixel
    :return: dict with 'rays_o', 'rays_v', 'rays_ndc_uv', 'rays_color', 'rays_mask',
             'rays_norm_XYZ_cam', 'rays_patch_color', 'rays_patch_mask' and, only when
             depth is given, 'rays_depth'
    :raises ValueError: if importance sampling is requested but dilated_mask has no valid
                        pixel or cannot index an [H, W] grid
    """
    device = image.device
    if dilated_mask is None:
        dilated_mask = mask

    if importance_sample and dilated_mask is not None:
        # sample more pts in the valid mask regions; the two groups always add up to N_rays
        n_uniform = N_rays // 4
        n_masked = N_rays - n_uniform
        pixels_x_1 = torch.randint(low=0, high=W, size=[n_uniform])
        pixels_y_1 = torch.randint(low=0, high=H, size=[n_uniform])

        ys, xs = torch.meshgrid(torch.linspace(0, H - 1, H),
                                torch.linspace(0, W - 1, W), indexing="ij")
        p = torch.stack([xs, ys], dim=-1)  # H, W, 2

        try:
            p_valid = p[(dilated_mask > 0).to(p.device)]  # [num, 2]
        except IndexError as err:
            raise ValueError(
                "dilated_mask of shape {} cannot index an image grid of shape ({}, {})".format(
                    tuple(dilated_mask.shape), H, W)) from err
        if p_valid.shape[0] == 0:
            raise ValueError(
                "dilated_mask of shape {} has no valid (> 0) pixel to sample from".format(
                    tuple(dilated_mask.shape)))

        random_idx = torch.randint(low=0, high=p_valid.shape[0], size=[n_masked])
        p_select = p_valid[random_idx]  # [n_masked, 2]

        pixels_x = torch.cat([pixels_x_1, p_select[:, 0].to(torch.int64)], dim=0)
        pixels_y = torch.cat([pixels_y_1, p_select[:, 1].to(torch.int64)], dim=0)
    else:
        pixels_x = torch.randint(low=0, high=W, size=[N_rays])
        pixels_y = torch.randint(low=0, high=H, size=[N_rays])

    # indices are drawn on CPU (same random stream as before) and then moved next to the image
    pixels_x = pixels_x.to(device)
    pixels_y = pixels_y.to(device)

    # - crop patch from images
    offsets = build_patch_offset(h_patch_size).to(device)
    grid_patch = torch.stack([pixels_x, pixels_y], dim=-1).view(-1, 1, 2) + offsets.float()  # [N_pts, Npx, 2]
    patch_mask = (pixels_x > h_patch_size) * (pixels_x < (W - h_patch_size)) * \
                 (pixels_y > h_patch_size) * (pixels_y < H - h_patch_size)  # [N_pts]
    grid_patch_u = 2 * grid_patch[:, :, 0] / (W - 1) - 1
    grid_patch_v = 2 * grid_patch[:, :, 1] / (H - 1) - 1
    grid_patch_uv = torch.stack([grid_patch_u, grid_patch_v], dim=-1)  # [N_pts, Npx, 2]
    patch_color = F.grid_sample(image[None, :, :, :], grid_patch_uv[None, :, :, :], mode='bilinear',
                                padding_mode='zeros', align_corners=True)[0]  # [3, N_pts, Npx]
    patch_color = patch_color.permute(1, 2, 0).contiguous()

    # normalized ndc uv coordinates, (-1, 1)
    ndc_u = 2 * pixels_x / (W - 1) - 1
    ndc_v = 2 * pixels_y / (H - 1) - 1
    rays_ndc_uv = torch.stack([ndc_u, ndc_v], dim=-1).view(-1, 2).float().to(device)

    image = image.permute(1, 2, 0)  # H ,W, C
    color = image[(pixels_y, pixels_x)]  # N_rays, 3

    if mask is not None:
        mask = mask[(pixels_y, pixels_x)]  # N_rays
        patch_mask = patch_mask * mask  # N_rays
        mask = mask.view(-1, 1)
    else:
        mask = torch.ones([N_rays, 1], device=device)

    if depth is not None:
        depth = depth[(pixels_y, pixels_x)]  # N_rays
        depth = depth.view(-1, 1)

    intrinsic_inv = torch.inverse(intrinsic)

    p = torch.stack([pixels_x, pixels_y, torch.ones_like(pixels_y)], dim=-1).float().to(device)  # N_rays, 3
    # squeeze only the trailing matmul axis so that N_rays == 1 keeps its [1, 3] shape
    p = torch.matmul(intrinsic_inv[None, :3, :3], p[:, :, None]).squeeze(-1)  # N_rays, 3
    rays_v = p / torch.linalg.norm(p, ord=2, dim=-1, keepdim=True)  # N_rays, 3
    rays_v = torch.matmul(c2w[None, :3, :3], rays_v[:, :, None]).squeeze(-1)  # N_rays, 3
    rays_o = c2w[None, :3, 3].expand(rays_v.shape)  # N_rays, 3

    sample = {
        'rays_o': rays_o,
        'rays_v': rays_v,
        'rays_ndc_uv': rays_ndc_uv,
        'rays_color': color,
        'rays_mask': mask,
        'rays_norm_XYZ_cam': p,  # - XYZ_cam, before multiply depth,
        'rays_patch_color': patch_color,
        'rays_patch_mask': patch_mask.view(-1, 1)
    }

    if depth is not None:
        sample['rays_depth'] = depth

    return sample


def gen_random_rays_of_patch_from_single_image(H, W, N_rays, num_neighboring_pts, patch_size,
                                               image, intrinsic, c2w, depth=None, mask=None):
    """
    generate random rays in world space, for a single image
    sample rays from local patches: every randomly chosen center pixel is followed by
    num_neighboring_pts points drawn uniformly inside a window of radius patch_size // 2 around it

    :param H: image height in pixels
    :param W: image width in pixels
    :param N_rays: the number of center rays of patches
    :param num_neighboring_pts: number of extra points sampled around every center
    :param patch_size: centers are drawn from [patch_size, W - patch_size) x [patch_size, H - patch_size)
    :param image: [3, H, W]
    :param intrinsic: [3,3]
    :param c2w: [4,4]
    :param depth: [H, W]; when None the 'rays_depth' entry is omitted
    :param mask: [H, W]; when None 'rays_mask' is all ones
    :return: dict with 'rays_o', 'rays_v', 'rays_ndc_uv', 'rays_color', 'rays_mask' and, only
             when depth is given, 'rays_depth'; each holds N_rays * (1 + num_neighboring_pts)
             rows, the center ray of every patch coming first
    """
    device = image.device
    n_total = N_rays * (1 + num_neighboring_pts)
    patch_radius_max = patch_size // 2
    unit_u = 2 / (W - 1)
    unit_v = 2 / (H - 1)

    pixels_x_center = torch.randint(low=patch_size, high=W - patch_size, size=[N_rays])
    pixels_y_center = torch.randint(low=patch_size, high=H - patch_size, size=[N_rays])

    # normalized ndc uv coordinates, (-1, 1)
    ndc_u_center = 2 * pixels_x_center / (W - 1) - 1
    ndc_v_center = 2 * pixels_y_center / (H - 1) - 1
    ndc_uv_center = torch.stack([ndc_u_center, ndc_v_center], dim=-1).view(-1, 2).float().to(device)[:, None,
                    :]  # [N_rays, 1, 2]

    # uniform distribution of [0,1), drawn on CPU and moved next to the image afterwards
    shift_u = torch.rand([N_rays, num_neighboring_pts, 1])
    shift_v = torch.rand([N_rays, num_neighboring_pts, 1])
    shift_u = 2 * (shift_u - 0.5)  # mapping to [-1, 1)
    shift_v = 2 * (shift_v - 0.5)

    shift_uv = torch.cat([(shift_u * patch_radius_max) * unit_u,
                          (shift_v * patch_radius_max) * unit_v], dim=-1).to(device)  # [N_rays, num_npts, 2]
    neighboring_pts_uv = ndc_uv_center + shift_uv  # [N_rays, num_npts, 2]

    sampled_pts_uv = torch.cat([ndc_uv_center, neighboring_pts_uv], dim=1)  # concat the center point
    sample_grid = sampled_pts_uv[None, :, :, :]

    # sample the gts
    color = F.grid_sample(image[None, :, :, :], sample_grid, mode='bilinear',
                          align_corners=True)[0]  # [3, N_rays, num_npts]
    color = color.permute(1, 2, 0).contiguous().view(n_total, 3)

    if depth is not None:
        depth = F.grid_sample(depth[None, None, :, :], sample_grid, mode='bilinear',
                              align_corners=True)[0]  # [1, N_rays, num_npts]
        depth = depth.permute(1, 2, 0).contiguous().view(n_total, 1)

    if mask is not None:
        mask = F.grid_sample(mask[None, None, :, :].to(torch.float32), sample_grid, mode='nearest',
                             align_corners=True).to(torch.int64)[0]  # [1, N_rays, num_npts]
        mask = mask.permute(1, 2, 0).contiguous().view(n_total, 1)
    else:
        # same dtype as a sampled mask
        mask = torch.ones([n_total, 1], dtype=torch.int64, device=device)

    intrinsic_inv = torch.inverse(intrinsic)

    sampled_pts_uv = sampled_pts_uv.view(n_total, 2)

    # back from ndc to (sub-)pixel coordinates
    pixels_x = (sampled_pts_uv[:, 0] + 1) * (W - 1) / 2
    pixels_y = (sampled_pts_uv[:, 1] + 1) * (H - 1) / 2
    p = torch.stack([pixels_x, pixels_y, torch.ones_like(pixels_y)], dim=-1).float().to(device)  # N_rays*num_pts, 3
    p = torch.matmul(intrinsic_inv[None, :3, :3], p[:, :, None]).squeeze(-1)  # N_rays*num_pts, 3
    rays_v = p / torch.linalg.norm(p, ord=2, dim=-1, keepdim=True)  # N_rays*num_pts, 3
    rays_v = torch.matmul(c2w[None, :3, :3], rays_v[:, :, None]).squeeze(-1)  # N_rays*num_pts, 3
    rays_o = c2w[None, :3, 3].expand(rays_v.shape)  # N_rays*num_pts, 3

    sample = {
        'rays_o': rays_o,
        'rays_v': rays_v,
        'rays_ndc_uv': sampled_pts_uv,
        'rays_color': color,
        'rays_mask': mask,
    }
    if depth is not None:
        sample['rays_depth'] = depth

    return sample


def gen_random_rays_from_batch_images(H, W, N_rays, images, intrinsics, c2ws, depths=None, masks=None):
    """
    sample N_rays random rays from every image of a batch and stack the results

    :param H: image height in pixels
    :param W: image width in pixels
    :param N_rays: number of rays sampled per image
    :param images: [B,3,H,W]
    :param intrinsics: [B, 3, 3]
    :param c2ws: [B, 4, 4]
    :param depths: [B,H,W]
    :param masks: [B,H,W]
    :return: dict with 'rays_o', 'rays_v', 'rays_color' stacked along a new batch axis;
             'rays_depth' / 'rays_mask' are stacked too, or None when depths / masks is None
    """
    assert len(images.shape) == 4

    per_image = [
        gen_random_rays_from_single_image(H, W, N_rays, images[i], intrinsics[i], c2ws[i],
                                          depth=depths[i] if depths is not None else None,
                                          mask=masks[i] if masks is not None else None)
        for i in range(images.shape[0])
    ]

    def _stack(key):
        return torch.stack([rays[key] for rays in per_image], dim=0)

    sample = {
        'rays_o': _stack('rays_o'),  # [batch, N_rays, 3]
        'rays_v': _stack('rays_v'),
        'rays_color': _stack('rays_color'),
        'rays_depth': _stack('rays_depth') if depths is not None else None,
        'rays_mask': _stack('rays_mask') if masks is not None else None
    }

    return sample


def gen_rays_between(c2w_0, c2w_1, intrinsic, ratio, H, W, resolution_level=1):
    """
    generate the rays of a camera interpolated between two poses

    The interpolation happens on the inverted (world-to-camera) poses: rotations are
    slerped and translations blended linearly, then the result is inverted back.

    :param c2w_0: [4,4] camera-to-world pose returned for ratio = 0
    :param c2w_1: [4,4] camera-to-world pose returned for ratio = 1
    :param intrinsic: camera intrinsic matrix; only intrinsic[:3, :3] is used
    :param ratio: interpolation weight between the two poses
    :param H: image height in pixels
    :param W: image width in pixels
    :param resolution_level: the ray grid has (H // resolution_level) x (W // resolution_level) rays
    :return: (c2w, rays_o, rays_v): the interpolated [4,4] float32 pose on the device of c2w_0
             and the ray origins / directions, each [N_rays, 3] in row-major pixel order
    """
    device = c2w_0.device
    level = resolution_level
    tx = torch.linspace(0, W - 1, W // level)
    ty = torch.linspace(0, H - 1, H // level)
    pixels_x, pixels_y = torch.meshgrid(tx, ty, indexing="ij")
    p = torch.stack([pixels_x, pixels_y, torch.ones_like(pixels_y)], dim=-1).to(device)  # W, H, 3

    intrinsic_inv = torch.inverse(intrinsic[:3, :3])
    # squeeze only the trailing matmul axis so that a 1-pixel-wide grid keeps its shape
    p = torch.matmul(intrinsic_inv[None, None, :3, :3], p[:, :, :, None]).squeeze(-1)  # W, H, 3
    rays_v = p / torch.linalg.norm(p, ord=2, dim=-1, keepdim=True)  # W, H, 3

    # interpolate in world-to-camera space
    pose_0 = np.linalg.inv(c2w_0.detach().cpu().numpy())
    pose_1 = np.linalg.inv(c2w_1.detach().cpu().numpy())
    key_rots = Rot.from_matrix(np.stack([pose_0[:3, :3], pose_1[:3, :3]]))
    slerp = Slerp([0, 1], key_rots)

    pose = np.diag([1.0, 1.0, 1.0, 1.0]).astype(np.float32)
    pose[:3, :3] = slerp(ratio).as_matrix()
    pose[:3, 3] = ((1.0 - ratio) * pose_0 + ratio * pose_1)[:3, 3]
    pose = np.linalg.inv(pose)

    c2w = torch.from_numpy(pose).to(device)
    # follow the device of the input poses instead of unconditionally requiring CUDA
    rot = torch.from_numpy(pose[:3, :3]).to(device)
    trans = torch.from_numpy(pose[:3, 3]).to(device)

    rays_v = torch.matmul(rot[None, None, :3, :3], rays_v[:, :, :, None]).squeeze(-1)  # W, H, 3
    rays_o = trans[None, None, :3].expand(rays_v.shape)  # W, H, 3

    # transpose to H, W, 3 so that the flattened rays are in row-major pixel order
    return c2w, rays_o.transpose(0, 1).contiguous().view(-1, 3), rays_v.transpose(0, 1).contiguous().view(-1, 3)