Spaces:
Paused
Paused
# MIT License
# Copyright (c) 2022 Petr Kellnhofer
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import torch | |
def transform_vectors(matrix: torch.Tensor, vectors4: torch.Tensor) -> torch.Tensor:
    """
    Apply a matrix to a stack of row vectors.

    Every row ``v`` of ``vectors4`` is mapped to ``matrix @ v``; this is
    evaluated as ``vectors4 @ matrix^T`` so the vectors can stay row-major.

    Args:
        matrix: Tensor of shape ``[..., M, M]``. Leading batch dimensions are
            broadcast against those of ``vectors4`` by ``torch.matmul``.
        vectors4: Tensor of shape ``[..., N, M]`` holding one vector per row.

    Returns:
        Tensor of shape ``[..., N, M]`` with the transformed vectors.
    """
    if matrix.ndim < 2:
        # There is nothing to transpose for 0-D/1-D input; keep the original
        # behaviour (``Tensor.T`` is the identity there).
        return torch.matmul(vectors4, matrix)
    # Swap only the last two axes. ``Tensor.T`` reverses *all* axes and is
    # deprecated for tensors whose ndim != 2, which breaks batched matrices.
    return torch.matmul(vectors4, matrix.transpose(-2, -1))
def normalize_vecs(vectors: torch.Tensor) -> torch.Tensor:
    """
    Scale vectors to unit Euclidean length along the last dimension.

    Args:
        vectors: Floating-point tensor whose last dimension holds the vector
            components.

    Returns:
        Tensor of the same shape as ``vectors``. Zero-length vectors are
        returned as zeros instead of NaN.
    """
    norms = torch.norm(vectors, dim=-1, keepdim=True)
    # A zero-length vector would otherwise evaluate 0 / 0 = NaN. Clamping to
    # the smallest positive normal number of the dtype leaves every ordinary
    # vector untouched while keeping the division finite.
    safe_norms = norms.clamp_min(torch.finfo(norms.dtype).tiny)
    return vectors / safe_norms
def torch_dot(x: torch.Tensor, y: torch.Tensor):
    """
    Dot product of ``x`` and ``y`` taken over the last dimension.

    The operands are multiplied element-wise (with the usual broadcasting)
    and the products are summed along the final axis, which is dropped from
    the result.
    """
    products = x * y
    return torch.sum(products, dim=-1)
def get_ray_limits_box(rays_o: torch.Tensor, rays_d: torch.Tensor, box_side_length):
    """
    Author: Petr Kellnhofer
    Intersects rays with an axis-aligned cube centred at the origin whose
    edges have length ``box_side_length``
    (i.e. the volume [-box_side_length/2, box_side_length/2]^3).

    ``rays_o`` and ``rays_d`` hold ray origins and directions with 3
    components in the last dimension. Both are detached, so no gradients flow
    through the result.

    Returns ``(tmin, tmax)``: the distances along each ray at which it enters
    and leaves the box, each of shape ``[*rays_o.shape[:-1], 1]``.
    Rays that miss the box get ``tmin = -1`` and ``tmax = -2``.

    Note: direction components equal to zero produce infinite reciprocals.

    https://www.scratchapixel.com/lessons/3d-basic-rendering/minimal-ray-tracer-rendering-simple-shapes/ray-box-intersection
    """
    lead_shape = rays_o.shape[:-1]
    origins = rays_o.detach().reshape(-1, 3)
    directions = rays_d.detach().reshape(-1, 3)

    # Row 0 is the minimum corner of the box, row 1 the maximum corner.
    half_side = box_side_length / 2
    corners = torch.tensor(
        [[-1 * half_side] * 3, [1 * half_side] * 3],
        dtype=origins.dtype,
        device=origins.device,
    )
    hit = torch.ones(origins.shape[:-1], dtype=bool, device=origins.device)

    # Reciprocal directions are computed once and reused for every axis.
    inv_dir = 1 / directions
    # 1 where the ray travels in the negative direction along an axis, so the
    # "near" plane of that axis is the maximum corner rather than the minimum.
    negative = (inv_dir < 0).long()

    def slab(axis):
        # Entry / exit distances for the pair of box planes normal to `axis`.
        near_plane = corners.index_select(0, negative[..., axis])[..., axis]
        far_plane = corners.index_select(0, 1 - negative[..., axis])[..., axis]
        origin_c = origins[..., axis]
        inv_c = inv_dir[..., axis]
        return (near_plane - origin_c) * inv_c, (far_plane - origin_c) * inv_c

    # Start with the planes normal to X, then narrow with Y and Z.
    t_near, t_far = slab(0)
    for axis in (1, 2):
        axis_near, axis_far = slab(axis)
        # The ray misses when the intervals of two axes do not overlap.
        hit[torch.logical_or(t_near > axis_far, axis_near > t_far)] = False
        # Keep the overlap of the intervals seen so far.
        t_near = torch.max(t_near, axis_near)
        t_far = torch.min(t_far, axis_far)

    # Flag rays without an intersection using sentinel values.
    missed = torch.logical_not(hit)
    t_near[missed] = -1
    t_far[missed] = -2
    return t_near.reshape(*lead_shape, 1), t_far.reshape(*lead_shape, 1)
def linspace(start: torch.Tensor, stop: torch.Tensor, num: int):
    """
    Creates a tensor of shape [num, *start.shape] whose values are evenly spaced from start to stop, inclusive.
    Replicates the multi-dimensional behaviour of numpy.linspace in PyTorch.

    Args:
        start: Tensor with the first value of every sequence.
        stop: Tensor with the last value of every sequence.
        num: Number of samples along the new leading dimension. With
            ``num == 1`` the result is just ``start`` (as in numpy.linspace).

    Returns:
        Tensor of shape [num, *start.shape].
    """
    # Fractions from 0 to 1 in 'num' steps. The denominator is clamped to 1 so
    # that num == 1 yields the single fraction 0 instead of 0 / 0 = NaN.
    steps = torch.arange(num, dtype=torch.float32, device=start.device) / max(num - 1, 1)
    # Append one trailing singleton axis per dimension of 'start' so that
    # 'steps' broadcasts against it.
    # - 'steps.reshape([-1, *([1]*start.ndim)])' would be nicer, but TorchScript
    #   "cannot statically infer the expected size of a list in this context",
    #   hence the loop below.
    for _ in range(start.ndim):
        steps = steps.unsqueeze(-1)
    # The output starts at 'start' and increments until 'stop' in each dimension.
    out = start[None] + steps * (stop - start)[None]
    return out