| |
| |
| |
| |
| |
|
|
| |
|
|
| import warnings |
| from typing import Optional, Tuple, Union |
|
|
| import torch |
| from pytorch3d.common.compat import meshgrid_ij |
| from pytorch3d.ops import padded_to_packed |
| from pytorch3d.renderer.cameras import CamerasBase |
| from pytorch3d.renderer.implicit.utils import HeterogeneousRayBundle, RayBundle |
| from torch.nn import functional as F |
|
|
|
|
| """ |
| This file defines three raysampling techniques: |
| - MultinomialRaysampler which can be used to sample rays from pixels of an image grid |
| - NDCMultinomialRaysampler which can be used to sample rays from pixels of an image grid, |
| which follows the pytorch3d convention for image grid coordinates |
| - MonteCarloRaysampler which randomly selects real-valued locations in the image plane |
| and emits rays from them |
| """ |
|
|
|
|
class MultinomialRaysampler(torch.nn.Module):
    """
    Samples a fixed number of points along rays which are regularly distributed
    in a batch of rectangular image grids. Points along each ray
    have uniformly-spaced z-coordinates between a predefined
    minimum and maximum depth.

    The raysampler first generates a 3D coordinate grid of the following form::

           / min_x, min_y, max_depth -------------- / max_x, min_y, max_depth
          /                                        /|
         /                                        / |     ^
        / min_depth                    min_depth /  |     |
        min_x ----------------------------- max_x   |     | image
        min_y                               min_y   |     | height
        |                                       |   |     |
        |                                       |   |     v
        |                                       |   |
        |                                       |   / max_x, max_y,     ^
        |                                       |  /  max_depth        /
        min_x                               max_y /                   / n_pts_per_ray
        max_y ----------------------------- max_x/ min_depth         v
          < --- image_width --- >

    In order to generate ray points, `MultinomialRaysampler` takes each 3D point of
    the grid (with coordinates `[x, y, depth]`) and unprojects it
    with `cameras.unproject_points([x, y, depth])`, where `cameras` are an
    additional input to the `forward` function.

    Note that this is a generic implementation that can support any image grid
    coordinate convention. For a raysampler which follows the PyTorch3D
    coordinate conventions please refer to `NDCMultinomialRaysampler`.
    As such, `NDCMultinomialRaysampler` is a special case of `MultinomialRaysampler`.

    Attributes:
        min_x: The leftmost x-coordinate of each ray's source pixel's center.
        max_x: The rightmost x-coordinate of each ray's source pixel's center.
        min_y: The topmost y-coordinate of each ray's source pixel's center.
        max_y: The bottommost y-coordinate of each ray's source pixel's center.
    """

    def __init__(
        self,
        *,
        min_x: float,
        max_x: float,
        min_y: float,
        max_y: float,
        image_width: int,
        image_height: int,
        n_pts_per_ray: int,
        min_depth: float,
        max_depth: float,
        n_rays_per_image: Optional[int] = None,
        n_rays_total: Optional[int] = None,
        unit_directions: bool = False,
        stratified_sampling: bool = False,
    ) -> None:
        """
        Args:
            min_x: The leftmost x-coordinate of each ray's source pixel's center.
            max_x: The rightmost x-coordinate of each ray's source pixel's center.
            min_y: The topmost y-coordinate of each ray's source pixel's center.
            max_y: The bottommost y-coordinate of each ray's source pixel's center.
            image_width: The horizontal size of the image grid.
            image_height: The vertical size of the image grid.
            n_pts_per_ray: The number of points sampled along each ray.
            min_depth: The minimum depth of a ray-point.
            max_depth: The maximum depth of a ray-point.
            n_rays_per_image: If given, this amount of rays are sampled from the grid.
                `n_rays_per_image` and `n_rays_total` cannot both be defined.
            n_rays_total: How many rays in total to sample from the cameras provided.
                The result is as if `n_rays_total` cameras were sampled with
                replacement from the cameras provided and for every camera one ray
                was sampled. If set, a HeterogeneousRayBundle with
                batch_size=n_rays_total is returned.
                `n_rays_per_image` and `n_rays_total` cannot both be defined.
            unit_directions: whether to normalize direction vectors in ray bundle.
            stratified_sampling: if True, performs stratified random sampling
                along the ray; otherwise takes ray points at deterministic offsets.
        """
        super().__init__()
        self._n_pts_per_ray = n_pts_per_ray
        self._min_depth = min_depth
        self._max_depth = max_depth
        self._n_rays_per_image = n_rays_per_image
        self._n_rays_total = n_rays_total
        self._unit_directions = unit_directions
        self._stratified_sampling = stratified_sampling
        self.min_x, self.max_x = min_x, max_x
        self.min_y, self.max_y = min_y, max_y

        # Grid of pixel-center coordinates of shape (image_height, image_width, 2),
        # with the last dimension ordered (x, y).
        y, x = meshgrid_ij(
            torch.linspace(min_y, max_y, image_height, dtype=torch.float32),
            torch.linspace(min_x, max_x, image_width, dtype=torch.float32),
        )
        _xy_grid = torch.stack([x, y], dim=-1)

        # The grid is fully determined by the constructor arguments, so it is
        # excluded from state_dict (persistent=False) but still moves with .to().
        self.register_buffer("_xy_grid", _xy_grid, persistent=False)

    def forward(
        self,
        cameras: CamerasBase,
        *,
        mask: Optional[torch.Tensor] = None,
        min_depth: Optional[float] = None,
        max_depth: Optional[float] = None,
        n_rays_per_image: Optional[int] = None,
        n_pts_per_ray: Optional[int] = None,
        stratified_sampling: Optional[bool] = None,
        n_rays_total: Optional[int] = None,
        **kwargs,
    ) -> Union[RayBundle, HeterogeneousRayBundle]:
        """
        Args:
            cameras: A batch of `batch_size` cameras from which the rays are emitted.
            mask: if given, the rays are sampled from the mask. Should be of size
                (batch_size, image_height, image_width).
            min_depth: The minimum depth of a ray-point.
            max_depth: The maximum depth of a ray-point.
            n_rays_per_image: If given, this amount of rays are sampled from the grid.
                `n_rays_per_image` and `n_rays_total` cannot both be defined.
            n_pts_per_ray: The number of points sampled along each ray.
            stratified_sampling: if set, overrides stratified_sampling provided
                in __init__.
            n_rays_total: How many rays in total to sample from the cameras provided.
                The result is as if `n_rays_total` cameras were sampled with
                replacement from the cameras provided and for every camera one ray
                was sampled. If set, a HeterogeneousRayBundle with
                batch_size=n_rays_total is returned.
                `n_rays_per_image` and `n_rays_total` cannot both be defined.
        Returns:
            A named tuple RayBundle or dataclass HeterogeneousRayBundle with the
            following fields:

            origins: A tensor of shape
                `(batch_size, s1, s2, 3)`
                denoting the locations of ray origins in the world coordinates.
            directions: A tensor of shape
                `(batch_size, s1, s2, 3)`
                denoting the directions of each ray in the world coordinates.
            lengths: A tensor of shape
                `(batch_size, s1, s2, n_pts_per_ray)`
                containing the z-coordinate (=depth) of each ray in world units.
            xys: A tensor of shape
                `(batch_size, s1, s2, 2)`
                containing the 2D image coordinates of each ray or,
                if mask is given, `(batch_size, n, 1, 2)`
            Here `s1, s2` refer to spatial dimensions.
            `(s1, s2)` refer to (highest priority first):
                - `(1, 1)` if `n_rays_total` is provided, (batch_size=n_rays_total)
                - `(n_rays_per_image, 1)` if `n_rays_per_image` is provided,
                - `(n, 1)` where n is the minimum cardinality of the mask
                  in the batch if `mask` is provided
                - `(image_height, image_width)` if nothing from above is satisfied

            `HeterogeneousRayBundle` has additional members:
                - camera_ids: tensor of shape (M,), where `M` is the number of unique
                  sampled cameras. It represents unique ids of sampled cameras.
                - camera_counts: tensor of shape (M,), where `M` is the number of
                  unique sampled cameras. Represents how many times each camera
                  from `camera_ids` was sampled.

            `HeterogeneousRayBundle` is returned if `n_rays_total` is provided else
            `RayBundle` is returned.
        """
        n_rays_total = n_rays_total or self._n_rays_total
        n_rays_per_image = n_rays_per_image or self._n_rays_per_image
        if (n_rays_total is not None) and (n_rays_per_image is not None):
            raise ValueError(
                "`n_rays_total` and `n_rays_per_image` cannot both be defined."
            )
        if n_rays_total:
            # Replace the camera batch (and mask) by the unique sampled cameras;
            # each of them then emits as many rays as the most-sampled camera, and
            # the surplus is dropped when packing the bundle at the end.
            (
                cameras,
                mask,
                camera_ids,
                camera_counts,
                n_rays_per_image,
            ) = _sample_cameras_and_masks(n_rays_total, cameras, mask)
        else:
            camera_ids: torch.LongTensor = torch.arange(len(cameras), dtype=torch.long)
            camera_counts = None  # only consumed when n_rays_total is set

        batch_size = cameras.R.shape[0]
        device = cameras.device

        # Expand the (H, W, 2) grid batch_size times to (B, H, W, 2).
        xy_grid = self._xy_grid.to(device).expand(batch_size, -1, -1, -1)

        if mask is not None and n_rays_per_image is None:
            # No ray count given: sample as many rays as the sparsest mask allows.
            n_rays_per_image = mask.sum(dim=(1, 2)).min().int().item()

        if n_rays_per_image is not None:
            if mask is not None:
                assert mask.shape == xy_grid.shape[:3], (
                    f"mask must have shape {tuple(xy_grid.shape[:3])}, "
                    f"got {tuple(mask.shape)}"
                )
                weights = mask.reshape(batch_size, -1)
            else:
                # torch.randperm would suit uniform weights, but it is neither
                # batched nor supports partial permutations.
                _, height, width, _ = xy_grid.shape
                weights = xy_grid.new_ones(batch_size, height * width)

            rays_idx = _safe_multinomial(weights, n_rays_per_image)[..., None].expand(
                -1, -1, 2
            )

            # Gather the selected pixels -> (B, n_rays_per_image, 1, 2).
            xy_grid = torch.gather(xy_grid.reshape(batch_size, -1, 2), 1, rays_idx)[
                :, :, None
            ]

        min_depth = min_depth if min_depth is not None else self._min_depth
        max_depth = max_depth if max_depth is not None else self._max_depth
        n_pts_per_ray = (
            n_pts_per_ray if n_pts_per_ray is not None else self._n_pts_per_ray
        )
        stratified_sampling = (
            stratified_sampling
            if stratified_sampling is not None
            else self._stratified_sampling
        )

        ray_bundle = _xy_to_ray_bundle(
            cameras,
            xy_grid,
            min_depth,
            max_depth,
            n_pts_per_ray,
            self._unit_directions,
            stratified_sampling,
        )

        return (
            _pack_ray_bundle(ray_bundle, camera_ids, camera_counts)
            if n_rays_total
            else ray_bundle
        )
|
|
|
|
class NDCMultinomialRaysampler(MultinomialRaysampler):
    """
    Samples a fixed number of points along rays which are regularly distributed
    in a batch of rectangular image grids. Points along each ray
    have uniformly-spaced z-coordinates between a predefined minimum and maximum depth.

    `NDCMultinomialRaysampler` follows the screen conventions of the `Meshes` and
    `Pointclouds` renderers. I.e. the pixel coordinates are in [-1, 1]x[-u, u] or
    [-u, u]x[-1, 1] where u >= 1 is the aspect ratio of the image.

    For the description of arguments, see the documentation to MultinomialRaysampler.
    """

    def __init__(
        self,
        *,
        image_width: int,
        image_height: int,
        n_pts_per_ray: int,
        min_depth: float,
        max_depth: float,
        n_rays_per_image: Optional[int] = None,
        n_rays_total: Optional[int] = None,
        unit_directions: bool = False,
        stratified_sampling: bool = False,
    ) -> None:
        """
        Derive the NDC pixel-center extents from the image size and delegate
        everything else to `MultinomialRaysampler.__init__`.

        The shorter image side spans [-1, 1] and the longer one spans
        [-aspect, aspect] with aspect = longer / shorter. The extents handed to
        the parent are the centers of the outermost pixels, half a pixel inside
        those bounds, and run from positive to negative values
        (min_x >= max_x, min_y >= max_y).
        """
        is_landscape = image_width >= image_height
        range_x = image_width / image_height if is_landscape else 1.0
        range_y = 1.0 if is_landscape else image_height / image_width

        # Each axis spans 2 * range over its pixel count, so half a pixel is
        # range / pixel_count.
        half_pixel_x = range_x / image_width
        half_pixel_y = range_y / image_height

        super().__init__(
            min_x=range_x - half_pixel_x,
            max_x=half_pixel_x - range_x,
            min_y=range_y - half_pixel_y,
            max_y=half_pixel_y - range_y,
            image_width=image_width,
            image_height=image_height,
            n_pts_per_ray=n_pts_per_ray,
            min_depth=min_depth,
            max_depth=max_depth,
            n_rays_per_image=n_rays_per_image,
            n_rays_total=n_rays_total,
            unit_directions=unit_directions,
            stratified_sampling=stratified_sampling,
        )
|
|
|
|
class MonteCarloRaysampler(torch.nn.Module):
    """
    Samples a fixed number of pixels within denoted xy bounds uniformly at random.
    For each pixel, a fixed number of points is sampled along its ray at
    uniformly-spaced z-coordinates such that the z-coordinates range between a
    predefined minimum and maximum depth.

    For practical purposes, this is similar to MultinomialRaysampler without a mask,
    however sampling at real-valued locations bypassing replacement checks may be
    faster.
    """

    def __init__(
        self,
        min_x: float,
        max_x: float,
        min_y: float,
        max_y: float,
        n_rays_per_image: int,
        n_pts_per_ray: int,
        min_depth: float,
        max_depth: float,
        *,
        n_rays_total: Optional[int] = None,
        unit_directions: bool = False,
        stratified_sampling: bool = False,
    ) -> None:
        """
        Args:
            min_x: The smallest x-coordinate of each ray's source pixel.
            max_x: The largest x-coordinate of each ray's source pixel.
            min_y: The smallest y-coordinate of each ray's source pixel.
            max_y: The largest y-coordinate of each ray's source pixel.
            n_rays_per_image: The number of rays randomly sampled in each camera.
                `n_rays_per_image` and `n_rays_total` cannot both be defined.
            n_pts_per_ray: The number of points sampled along each ray.
            min_depth: The minimum depth of each ray-point.
            max_depth: The maximum depth of each ray-point.
            n_rays_total: How many rays in total to sample from the cameras provided.
                The result is as if `n_rays_total` cameras were sampled with
                replacement from the cameras provided and for every camera one ray
                was sampled. If set, a HeterogeneousRayBundle with
                batch_size=n_rays_total is returned.
                `n_rays_per_image` and `n_rays_total` cannot both be defined.
            unit_directions: whether to normalize direction vectors in ray bundle.
            stratified_sampling: if True, performs stratified sampling in n_pts_per_ray
                bins for each ray; otherwise takes n_pts_per_ray deterministic points
                on each ray with uniform offsets.
        """
        super().__init__()
        self._min_x = min_x
        self._max_x = max_x
        self._min_y = min_y
        self._max_y = max_y
        self._n_rays_per_image = n_rays_per_image
        self._n_pts_per_ray = n_pts_per_ray
        self._min_depth = min_depth
        self._max_depth = max_depth
        self._n_rays_total = n_rays_total
        self._unit_directions = unit_directions
        self._stratified_sampling = stratified_sampling

    def forward(
        self,
        cameras: CamerasBase,
        *,
        stratified_sampling: Optional[bool] = None,
        **kwargs,
    ) -> Union[RayBundle, HeterogeneousRayBundle]:
        """
        Args:
            cameras: A batch of `batch_size` cameras from which the rays are emitted.
            stratified_sampling: if set, overrides stratified_sampling provided
                in __init__.
        Returns:
            A named tuple `RayBundle` or dataclass `HeterogeneousRayBundle` with the
            following fields:

            origins: A tensor of shape
                `(batch_size, n_rays_per_image, 3)`
                denoting the locations of ray origins in the world coordinates.
            directions: A tensor of shape
                `(batch_size, n_rays_per_image, 3)`
                denoting the directions of each ray in the world coordinates.
            lengths: A tensor of shape
                `(batch_size, n_rays_per_image, n_pts_per_ray)`
                containing the z-coordinate (=depth) of each ray in world units.
            xys: A tensor of shape
                `(batch_size, n_rays_per_image, 2)`
                containing the 2D image coordinates of each ray.
            If `n_rays_total` is provided `batch_size=n_rays_total` and
            `n_rays_per_image=1` and `HeterogeneousRayBundle` is returned else
            `RayBundle` is returned.

            `HeterogeneousRayBundle` has additional members:
                - camera_ids: tensor of shape (M,), where `M` is the number of unique
                  sampled cameras. It represents unique ids of sampled cameras.
                - camera_counts: tensor of shape (M,), where `M` is the number of
                  unique sampled cameras. Represents how many times each camera
                  from `camera_ids` was sampled.
        Raises:
            ValueError: if not exactly one of `n_rays_total` / `n_rays_per_image`
                was given to the constructor.
        """
        if (
            sum(x is not None for x in [self._n_rays_total, self._n_rays_per_image])
            != 1
        ):
            raise ValueError(
                "Exactly one of `self.n_rays_total` and `self.n_rays_per_image` "
                "must be given."
            )

        if self._n_rays_total:
            # Sample cameras with replacement; each unique camera then emits as
            # many rays as the most-sampled one, trimmed when packing below.
            (
                cameras,
                _,
                camera_ids,
                camera_counts,
                n_rays_per_image,
            ) = _sample_cameras_and_masks(self._n_rays_total, cameras, None)
        else:
            camera_ids: torch.LongTensor = torch.arange(len(cameras), dtype=torch.long)
            n_rays_per_image = self._n_rays_per_image

        batch_size = cameras.R.shape[0]
        device = cameras.device

        # Uniformly random image xy coords of shape (batch_size, n_rays_per_image, 2):
        # one (batch, n, 1) draw per axis, scaled to [low, high) and concatenated.
        rays_xy = torch.cat(
            [
                torch.rand(
                    size=(batch_size, n_rays_per_image, 1),
                    dtype=torch.float32,
                    device=device,
                )
                * (high - low)
                + low
                for low, high in (
                    (self._min_x, self._max_x),
                    (self._min_y, self._max_y),
                )
            ],
            dim=2,
        )

        stratified_sampling = (
            stratified_sampling
            if stratified_sampling is not None
            else self._stratified_sampling
        )

        ray_bundle = _xy_to_ray_bundle(
            cameras,
            rays_xy,
            self._min_depth,
            self._max_depth,
            self._n_pts_per_ray,
            self._unit_directions,
            stratified_sampling,
        )

        return (
            _pack_ray_bundle(ray_bundle, camera_ids, camera_counts)
            if self._n_rays_total
            else ray_bundle
        )
|
|
|
|
| |
def GridRaysampler(
    min_x: float,
    max_x: float,
    min_y: float,
    max_y: float,
    image_width: int,
    image_height: int,
    n_pts_per_ray: int,
    min_depth: float,
    max_depth: float,
) -> "MultinomialRaysampler":
    """
    GridRaysampler has been DEPRECATED. Use MultinomialRaysampler instead.
    Preserving GridRaysampler for backward compatibility.

    Emits a `PendingDeprecationWarning` attributed to the caller's line and
    returns a `MultinomialRaysampler` built from the given arguments; all other
    `MultinomialRaysampler` options keep their defaults.
    """
    warnings.warn(
        "GridRaysampler is deprecated,\n"
        "Use MultinomialRaysampler instead.\n"
        "GridRaysampler will be removed in future releases.",
        PendingDeprecationWarning,
        # Attribute the warning to the code calling GridRaysampler, not this module.
        stacklevel=2,
    )

    return MultinomialRaysampler(
        min_x=min_x,
        max_x=max_x,
        min_y=min_y,
        max_y=max_y,
        image_width=image_width,
        image_height=image_height,
        n_pts_per_ray=n_pts_per_ray,
        min_depth=min_depth,
        max_depth=max_depth,
    )
|
|
|
|
| |
def NDCGridRaysampler(
    image_width: int,
    image_height: int,
    n_pts_per_ray: int,
    min_depth: float,
    max_depth: float,
) -> "NDCMultinomialRaysampler":
    """
    NDCGridRaysampler has been DEPRECATED. Use NDCMultinomialRaysampler instead.
    Preserving NDCGridRaysampler for backward compatibility.

    Emits a `PendingDeprecationWarning` attributed to the caller's line and
    returns an `NDCMultinomialRaysampler` built from the given arguments; all
    other options keep their defaults.
    """
    warnings.warn(
        "NDCGridRaysampler is deprecated,\n"
        "Use NDCMultinomialRaysampler instead.\n"
        "NDCGridRaysampler will be removed in future releases.",
        PendingDeprecationWarning,
        # Attribute the warning to the code calling NDCGridRaysampler.
        stacklevel=2,
    )

    return NDCMultinomialRaysampler(
        image_width=image_width,
        image_height=image_height,
        n_pts_per_ray=n_pts_per_ray,
        min_depth=min_depth,
        max_depth=max_depth,
    )
|
|
|
|
| def _safe_multinomial(input: torch.Tensor, num_samples: int) -> torch.Tensor: |
| """ |
| Wrapper around torch.multinomial that attempts sampling without replacement |
| when possible, otherwise resorts to sampling with replacement. |
| |
| Args: |
| input: tensor of shape [B, n] containing non-negative values; |
| rows are interpreted as unnormalized event probabilities |
| in categorical distributions. |
| num_samples: number of samples to take. |
| |
| Returns: |
| LongTensor of shape [B, num_samples] containing |
| values from {0, ..., n - 1} where the elements [i, :] of row i make |
| (1) if there are num_samples or more non-zero values in input[i], |
| a random subset of the indices of those values, with |
| probabilities proportional to the values in input[i, :]. |
| |
| (2) if not, a random sample with replacement of the indices of |
| those values, with probabilities proportional to them. |
| This sample might not contain all the indices of the |
| non-zero values. |
| Behavior undetermined if there are no non-zero values in a whole row |
| or if there are negative values. |
| """ |
| try: |
| res = torch.multinomial(input, num_samples, replacement=False) |
| except RuntimeError: |
| |
| res = torch.multinomial(input, num_samples, replacement=True) |
| no_repl = (input > 0.0).sum(dim=-1) >= num_samples |
| res[no_repl] = torch.multinomial(input[no_repl], num_samples, replacement=False) |
| return res |
|
|
| |
| |
| repl = (input > 0.0).sum(dim=-1) < num_samples |
| if repl.any(): |
| res[repl] = torch.multinomial(input[repl], num_samples, replacement=True) |
|
|
| return res |
|
|
|
|
def _xy_to_ray_bundle(
    cameras: CamerasBase,
    xy_grid: torch.Tensor,
    min_depth: float,
    max_depth: float,
    n_pts_per_ray: int,
    unit_directions: bool,
    stratified_sampling: bool = False,
) -> RayBundle:
    """
    Turn image-plane coordinates `xy_grid` of shape `(batch_size, ..., 2)` into rays.

    Every xy location receives `n_pts_per_ray` depths evenly spaced between
    `min_depth` and `max_depth`. The location is unprojected by `cameras` at
    depths 1 and 2 (with `from_ndc=True`): the difference of the two points is
    the ray direction, and stepping back one direction from the depth-1 point
    gives the origin.

    Args:
        cameras: cameras object representing a batch of cameras.
        xy_grid: torch.tensor grid of image xy coords.
        min_depth: The minimum depth of each ray-point.
        max_depth: The maximum depth of each ray-point.
        n_pts_per_ray: The number of points sampled along each ray.
        unit_directions: whether to normalize direction vectors in ray bundle.
        stratified_sampling: if True, performs stratified sampling in n_pts_per_ray
            bins for each ray; otherwise takes n_pts_per_ray deterministic points
            on each ray with uniform offsets.

    Returns:
        RayBundle(origins, directions, lengths, xys) where origins/directions have
        shape `(batch_size, ..., 3)`, lengths `(batch_size, ..., n_pts_per_ray)`
        and xys is `xy_grid` itself.
    """
    batch_size = xy_grid.shape[0]
    spatial_size = xy_grid.shape[1:-1]
    n_rays = spatial_size.numel()

    # Per-ray depths, shape (batch_size, n_rays, n_pts_per_ray).
    if n_pts_per_ray > 0:
        depth_values = torch.linspace(
            min_depth,
            max_depth,
            n_pts_per_ray,
            dtype=xy_grid.dtype,
            device=xy_grid.device,
        )
        rays_zs = depth_values[None, None].expand(batch_size, n_rays, n_pts_per_ray)
        if stratified_sampling:
            rays_zs = _jiggle_within_stratas(rays_zs)
    else:
        rays_zs = xy_grid.new_empty((0,))

    # Two copies of every xy location, the first n_rays at depth 1 and the
    # remaining n_rays at depth 2 -> (batch_size, 2 * n_rays, 3).
    xy_twice = (
        xy_grid.view(batch_size, 1, n_rays, 2)
        .expand(batch_size, 2, n_rays, 2)
        .reshape(batch_size, 2 * n_rays, 2)
    )
    plane_depths = torch.cat(
        (
            xy_grid.new_ones(batch_size, n_rays, 1),
            2.0 * xy_grid.new_ones(batch_size, n_rays, 1),
        ),
        dim=1,
    )
    points_to_unproject = torch.cat((xy_twice, plane_depths), dim=-1)

    world_points = cameras.unproject_points(points_to_unproject, from_ndc=True)
    near_points = world_points[:, :n_rays]
    far_points = world_points[:, n_rays:]

    ray_dirs = far_points - near_points
    ray_origins = near_points - ray_dirs
    if unit_directions:
        ray_dirs = F.normalize(ray_dirs, dim=-1)

    return RayBundle(
        ray_origins.view(batch_size, *spatial_size, 3),
        ray_dirs.view(batch_size, *spatial_size, 3),
        rays_zs.view(batch_size, *spatial_size, n_pts_per_ray),
        xy_grid,
    )
|
|
|
|
| def _jiggle_within_stratas(bin_centers: torch.Tensor) -> torch.Tensor: |
| """ |
| Performs sampling of 1 point per bin given the bin centers. |
| |
| More specifically, it replaces each point's value `z` |
| with a sample from a uniform random distribution on |
| `[z - delta_-, z + delta_+]`, where `delta_-` is half of the difference |
| between `z` and the previous point, and `delta_+` is half of the difference |
| between the next point and `z`. For the first and last items, the |
| corresponding boundary deltas are assumed zero. |
| |
| Args: |
| `bin_centers`: The input points of size (..., N); the result is broadcast |
| along all but the last dimension (the rows). Each row should be |
| sorted in ascending order. |
| |
| Returns: |
| a tensor of size (..., N) with the locations jiggled within stratas/bins. |
| """ |
| |
| mids = 0.5 * (bin_centers[..., 1:] + bin_centers[..., :-1]) |
| upper = torch.cat((mids, bin_centers[..., -1:]), dim=-1) |
| lower = torch.cat((bin_centers[..., :1], mids), dim=-1) |
| |
| jiggled = lower + (upper - lower) * torch.rand_like(lower) |
| return jiggled |
|
|
|
|
| def _sample_cameras_and_masks( |
| n_samples: int, cameras: CamerasBase, mask: Optional[torch.Tensor] = None |
| ) -> Tuple[ |
| CamerasBase, |
| Optional[torch.Tensor], |
| torch.LongTensor, |
| torch.LongTensor, |
| torch.LongTensor, |
| ]: |
| """ |
| Samples n_rays_total cameras and masks and returns them in a form |
| (camera_idx, count), where count represents number of times the same camera |
| has been sampled. |
| |
| Args: |
| n_samples: how many camera and mask pairs to sample |
| cameras: A batch of `batch_size` cameras from which the rays are emitted. |
| mask: Optional. Should be of size (batch_size, image_height, image_width). |
| Returns: |
| tuple of a form (sampled_cameras, sampled_masks, unique_sampled_camera_ids, |
| number_of_times_each_sampled_camera_has_been_sampled, |
| max_number_of_times_camera_has_been_sampled, |
| ) |
| """ |
| sampled_ids = torch.randint( |
| 0, |
| len(cameras), |
| size=(n_samples,), |
| dtype=torch.long, |
| ) |
| unique_ids, counts = torch.unique(sampled_ids, return_counts=True) |
| |
| return ( |
| cameras[unique_ids], |
| mask[unique_ids] if mask is not None else None, |
| unique_ids, |
| counts, |
| torch.max(counts), |
| ) |
|
|
|
|
| |
def _pack_ray_bundle(
    ray_bundle: RayBundle, camera_ids: torch.LongTensor, camera_counts: torch.LongTensor
) -> HeterogeneousRayBundle:
    """
    Flatten a padded bundle of shape [n_cameras, max(rays_per_camera), ...] into
    a packed bundle of shape [total_num_rays, 1, ...].

    Args:
        ray_bundle: A ray_bundle to pack
        camera_ids: Unique ids of cameras that were sampled
        camera_counts: how many rays to keep from each camera; count i belongs
            to row i of the ray_bundle and is the number of its rays that get
            packed.
    Returns:
        HeterogeneousRayBundle where batch_size=sum(camera_counts) and
        n_rays_per_image=1. Its `camera_counts` lives on the bundle's device.
    """
    counts = camera_counts.to(ray_bundle.origins.device)

    # Packed offset at which each camera's rays start (exclusive prefix sum).
    running_total = torch.cumsum(counts, dim=0, dtype=torch.long)
    first_idxs: torch.LongTensor = torch.cat(
        (counts.new_zeros((1,), dtype=torch.long), running_total[:-1])
    )
    num_inputs = int(counts.sum())

    def _pack(padded: torch.Tensor) -> torch.Tensor:
        # Keep the leading rays of each row, then restore a singleton ray dim.
        return padded_to_packed(padded, first_idxs, num_inputs)[:, None]

    return HeterogeneousRayBundle(
        origins=_pack(ray_bundle.origins),
        directions=_pack(ray_bundle.directions),
        lengths=_pack(ray_bundle.lengths),
        xys=_pack(ray_bundle.xys),
        camera_ids=camera_ids,
        camera_counts=counts,
    )
|
|