| from typing import Tuple, Dict, Set | |
| import torch | |
| from torch import Tensor | |
| from jaxtyping import Float, Bool, UInt8, Int32 | |
def inside_image(
    pts2d: Float[Tensor, "n 2"],
    image_size: Tuple[int, ...]
) -> Bool[Tensor, " n"]:
    """Return a mask marking which 2D points fall inside the image bounds.

    A point ``(x, y)`` counts as inside when ``0 <= x < W`` and
    ``0 <= y < H``. The interval is closed at 0 and open at the far edge,
    so a point with ``x == W`` or ``y == H`` is outside.

    Args:
        pts2d: Points whose last dimension holds ``(x, y)``.
        image_size: Image size; must unpack to exactly ``(H, W)``
            (a tuple of any other length raises ``ValueError``).

    Returns:
        Boolean tensor, ``True`` where the corresponding point lies inside
        the image.
    """
    H, W = image_size
    px, py = pts2d.unbind(-1)
    # Elementwise comparisons combined with ``&`` yield a bool tensor, hence
    # the ``Bool`` (not ``Float``) return annotation.
    return (0 <= px) & (px < W) & (0 <= py) & (py < H)
def get_uv_grid(
    image_size: Tuple[int, int],
    dtype=torch.float32,
    device=None,
) -> Float[Tensor, "h w 2"]:
    """Build a dense grid of integer pixel coordinates.

    Args:
        image_size: ``(H, W)`` of the image.
        dtype: dtype of the returned tensor.
        device: Device on which to allocate the grid. ``None`` uses PyTorch's
            current default device, which matches the previous behavior.

    Returns:
        Tensor of shape ``(H, W, 2)`` with ``grid[y, x] == (x, y)``.
    """
    H, W = image_size
    # With indexing="xy" the first input varies along columns and the second
    # along rows, so both outputs have shape (H, W): xs[y, x] == x, ys[y, x] == y.
    xs, ys = torch.meshgrid(
        torch.arange(W, device=device),
        torch.arange(H, device=device),
        indexing="xy",
    )
    return torch.stack((xs, ys), dim=-1).to(dtype)
def persp_project(xyz: Tensor) -> Tuple[Tensor, Tensor]:
    """Perspective-divide 3D points by their depth.

    Args:
        xyz: Points with ``(x, y, z)`` in the last dimension. Any number of
            leading (batch) dimensions is allowed, e.g. ``(N, 3)`` or
            ``(B, N, 3)``.

    Returns:
        ``(uv, z)``:

        - ``uv = xyz[..., :2] / z`` has shape ``(..., 2)``.
        - ``z = xyz[..., 2:]`` keeps a trailing singleton dimension ``(..., 1)``.

        There is no guard against ``z == 0``; such points produce inf/nan.
    """
    # Slice the last axis (not axis 1) so batched inputs work. For (N, 3)
    # input this is identical to the previous ``[:, ...]`` slicing.
    z = xyz[..., 2:]
    uv = xyz[..., :2] / z
    return uv, z