JiantaoLin
new
2fe3da0
raw
history blame
3.75 kB
# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
#
# Modified by Jiale Xu
# The modifications are subject to the same license as the original.
"""
The renderer is a module that takes in rays, decides where to sample along each
ray, and computes pixel colors using the volume rendering equation.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from . import math_utils
def generate_planes():
    """
    Build the axis matrices that define the three feature planes.

    Every plane is described by a 3x3 matrix whose rows are the three
    vectors forming the "axes" of that plane. Nothing here is specific to
    three planes or to axis-aligned orientations; this is simply the
    default set.

    The third plane follows the corrected axis ordering discussed in
    https://github.com/NVlabs/eg3d/issues/67

    Returns a float32 tensor of shape (3, 3, 3).
    """
    x_axis = [1, 0, 0]
    y_axis = [0, 1, 0]
    z_axis = [0, 0, 1]
    plane_axes = [
        [x_axis, y_axis, z_axis],
        [x_axis, z_axis, y_axis],
        [z_axis, y_axis, x_axis],
    ]
    return torch.tensor(plane_axes, dtype=torch.float32)
def project_onto_planes(planes, coordinates):
    """
    Project batches of 3D points onto a set of planes and return the
    resulting 2D in-plane coordinates.

    planes:      plane axes of shape (n_planes, 3, 3)
    coordinates: points of shape (N, M, 3)
    returns:     projections of shape (N*n_planes, M, 2)
    """
    batch, num_points, _ = coordinates.shape
    num_planes, _, _ = planes.shape
    flat_batch = batch * num_planes

    # Pair every batch item with every plane: (N, M, 3) -> (N*n_planes, M, 3).
    tiled_points = coordinates.unsqueeze(1)
    tiled_points = tiled_points.expand(-1, num_planes, -1, -1)
    tiled_points = tiled_points.reshape(flat_batch, num_points, 3)

    # Inverted plane matrices, repeated across the batch in the same
    # (batch-major, plane-minor) order as the points above.
    tiled_inverses = torch.linalg.inv(planes).unsqueeze(0)
    tiled_inverses = tiled_inverses.expand(batch, -1, -1, -1)
    tiled_inverses = tiled_inverses.reshape(flat_batch, 3, 3)

    # Only the first two components of the transformed point are kept.
    return torch.bmm(tiled_points, tiled_inverses)[..., :2]
def sample_from_planes(plane_axes, plane_features, coordinates, mode='bilinear', padding_mode='zeros', box_warp=None):
    """
    Sample per-plane features at the projections of 3D points.

    plane_axes:     plane axes of shape (n_planes, 3, 3)
    plane_features: feature maps of shape (N, n_planes, C, H, W)
    coordinates:    points of shape (N, M, 3)
    mode:           interpolation mode forwarded to grid_sample
    padding_mode:   must be 'zeros' (anything else fails the assertion)
    box_warp:       optional scale; coordinates are multiplied by
                    2/box_warp, so a box of side length box_warp centred
                    on the origin maps onto grid_sample's [-1, 1] range.
                    When None, coordinates are used unscaled, i.e. they
                    are taken to be normalised already.
    returns:        sampled features of shape (N, n_planes, M, C)
    """
    assert padding_mode == 'zeros'
    N, n_planes, C, H, W = plane_features.shape
    _, M, _ = coordinates.shape
    # reshape (rather than view) so non-contiguous feature tensors are accepted.
    plane_features = plane_features.reshape(N * n_planes, C, H, W)
    dtype = plane_features.dtype

    # The default box_warp=None used to crash on `2/None`; only rescale
    # when a box size was actually supplied.
    if box_warp is not None:
        coordinates = (2 / box_warp) * coordinates  # add specific box bounds

    # (N*n_planes, M, 2) -> (N*n_planes, 1, M, 2): a one-row sampling grid.
    projected_coordinates = project_onto_planes(plane_axes, coordinates).unsqueeze(1)
    sampled = torch.nn.functional.grid_sample(
        plane_features,
        projected_coordinates.to(dtype),  # grid must match the feature dtype
        mode=mode,
        padding_mode=padding_mode,
        align_corners=False,
    )
    # (N*n_planes, C, 1, M) -> (N*n_planes, M, 1, C) -> (N, n_planes, M, C)
    output_features = sampled.permute(0, 3, 2, 1).reshape(N, n_planes, M, C)
    return output_features
def sample_from_3dgrid(grid, coordinates):
    """
    Look up features from a dense 3D feature grid at the given points.

    coordinates: shape (batch_size, num_points_per_batch, 3)
    grid:        5-D feature volume with a leading batch dimension of 1
                 and channels in dimension 1 (a leading dimension equal
                 to batch_size also works)
    returns:     sampled features of shape
                 (batch_size, num_points_per_batch, feature_channels)
    """
    batch_size, _, n_dims = coordinates.shape

    # Share the one grid across the batch without copying it.
    batched_grid = grid.expand(batch_size, -1, -1, -1, -1)
    # Lay the points out as a 1 x 1 x num_points sampling volume.
    sample_locations = coordinates.reshape(batch_size, 1, 1, -1, n_dims)

    raw = torch.nn.functional.grid_sample(
        batched_grid,
        sample_locations,
        mode='bilinear',
        padding_mode='zeros',
        align_corners=False,
    )

    # raw is (batch, channels, 1, 1, num_points); move channels last and
    # collapse the three sampling dimensions into a single point axis.
    out_batch, channels, dim_a, dim_b, dim_c = raw.shape
    channels_last = raw.permute(0, 4, 3, 2, 1)
    return channels_last.reshape(out_batch, dim_a * dim_b * dim_c, channels)