from typing import Dict, Optional
from jaxtyping import Float
import numpy as np
import torch
from torch import Tensor
import torch.nn.functional as F
from einops import rearrange
# `splatting_function` is used in `forward_warper` below; assumed to come from
# the `splatting` package (https://github.com/hperrot/splatting).
from splatting import splatting_function
def sph2cart(
azi: Float[Tensor, 'B'],
ele: Float[Tensor, 'B'],
r: Float[Tensor, 'B']
) -> Float[Tensor, 'B 3']:
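    """Convert spherical coordinates (azimuth, elevation, radius) to Cartesian."""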
# z-up, y-right, x-back
rcos = r * torch.cos(ele)
pos_cart = torch.stack([
rcos * torch.cos(azi),
rcos * torch.sin(azi),
r * torch.sin(ele)
], dim=1)
return pos_cart
def get_viewport_matrix(
width: int,
height: int,
    batch_size: int = 1,
    device: Optional[torch.device] = None,
) -> Float[Tensor, 'B 4 4']:
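    """Viewport transform mapping NDC coordinates to pixel coordinates of a
    width x height screen."""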
N = torch.tensor(
[[width/2, 0, 0, width/2],
[0, height/2, 0, height/2],
[0, 0, 1/2, 1/2],
[0, 0, 0, 1]],
dtype=torch.float32,
device=device
)[None].repeat(batch_size, 1, 1)
return N
def get_projection_matrix(
fovy: Float[Tensor, 'B'],
aspect_wh: float,
near: float,
far: float
) -> Float[Tensor, 'B 4 4']:
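    """Perspective projection from vertical FOV (radians), width/height aspect
    ratio, and near/far planes, with y negated to match nvdiffrast's output."""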
batch_size = fovy.shape[0]
    # Allocate on fovy's device so the indexed assignments below work on GPU tensors.
    proj_mtx = torch.zeros(batch_size, 4, 4, dtype=torch.float32, device=fovy.device)
proj_mtx[:, 0, 0] = 1.0 / (torch.tan(fovy / 2.0) * aspect_wh)
proj_mtx[:, 1, 1] = -1.0 / torch.tan(
fovy / 2.0
) # add a negative sign here as the y axis is flipped in nvdiffrast output
proj_mtx[:, 2, 2] = -(far + near) / (far - near)
proj_mtx[:, 2, 3] = -2.0 * far * near / (far - near)
proj_mtx[:, 3, 2] = -1.0
return proj_mtx
def camera_lookat(
eye: Float[Tensor, 'B 3'],
target: Float[Tensor, 'B 3'],
up: Float[Tensor, 'B 3']
) -> Float[Tensor, 'B 4 4']:
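    """Build batched 4x4 world-to-camera (view) matrices for cameras at `eye`
    looking at `target` with the given `up` direction."""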
    B = eye.shape[0]
    # Build an orthonormal camera basis from the viewing direction and up vector.
    f = F.normalize(eye - target, dim=-1)
    l = F.normalize(torch.linalg.cross(up, f), dim=-1)
    u = F.normalize(torch.linalg.cross(f, l), dim=-1)
    R = torch.stack((l, u, f), dim=1)  # B 3 3
    M_R = torch.eye(4, dtype=torch.float32, device=eye.device)[None].repeat((B, 1, 1))
    M_R[..., :3, :3] = R
    T = -eye
    M_T = torch.eye(4, dtype=torch.float32, device=eye.device)[None].repeat((B, 1, 1))
    M_T[..., :3, 3] = T
    # View matrix: translate the eye to the origin, then rotate into camera axes.
    return (M_R @ M_T).to(dtype=torch.float32)
def focal_length_to_fov(
    focal_length: float,
    sensor_length: float = 24.
) -> float:
    # Vertical FOV in radians; 24 mm is the height of a full-frame sensor.
    return 2 * np.arctan(sensor_length / focal_length / 2.)
def forward_warper(
image: Float[Tensor, 'B C H W'],
screen: Float[Tensor, 'B (H W) 2'],
pcd: Float[Tensor, 'B (H W) 4'],
mvp_mtx: Float[Tensor, 'B 4 4'],
viewport_mtx: Float[Tensor, 'B 4 4'],
alpha: float = 0.5
) -> Dict[str, Tensor]:
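    """Forward-warp `image` into a target view: project the per-pixel point
    cloud `pcd` by `mvp_mtx`, map it to pixel coordinates with `viewport_mtx`,
    and softmax-splat the pixels along the induced flow. Returns the warped
    image, a hole mask (1 where nothing splatted), and the per-pixel
    correspondence map.
    """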
H, W = image.shape[2:4]
# Projection.
points_c = pcd @ mvp_mtx.mT
points_ndc = points_c / points_c[..., 3:4]
# To screen.
coords_new = points_ndc @ viewport_mtx.mT
    # Mask invalid pixels (behind the viewport) by pushing them far off-screen;
    # use a smaller magnitude for lower-precision dtypes to avoid overflow.
    invalid = coords_new[..., 2] <= 0
    coords_new[invalid] = -1000000 if coords_new.dtype == torch.float32 else -1e+4
new_z = points_c[..., 2:3]
flow = coords_new[..., :2] - screen[..., :2]
    ## Importance: depth-derived weights, min-max normalized to [-10, 0]
    ## so the softmax splatting below stays numerically stable.
    importance = alpha / new_z
    importance -= importance.amin((1, 2), keepdim=True)
    importance /= importance.amax((1, 2), keepdim=True) + 1e-6
    importance = importance * 10 - 10
## Rearrange.
importance = rearrange(importance, 'b (h w) c -> b c h w', h=H, w=W)
flow = rearrange(flow, 'b (h w) c -> b c h w', h=H, w=W)
# Splatting.
warped = splatting_function('softmax', image, flow, importance, eps=1e-6)
## mask is 1 where there is no splat
mask = (warped == 0.).all(dim=1, keepdim=True).to(image.dtype)
flow2 = rearrange(coords_new[..., :2], 'b (h w) c -> b c h w', h=H, w=W)
output = dict(
warped=warped,
mask=mask,
correspondence=flow2
)
return output
def convert_left_to_right(left_embed, disparity, left_image, random_ratio=None):
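    """Warp left features/image to the right view by per-pixel disparity
    (loop-based reference; see `convert_left_to_right_torch` for the
    vectorized equivalent).
    """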
    # Get the height and width from the left embedding
    _, height, width = left_embed.shape
    # Initialize outputs; mask starts all-invalid (1 = invalid, 0 = valid)
    right_embed = torch.ones_like(left_embed)
    converted_right_image = torch.ones_like(left_image)
    mask = torch.ones((height, width), dtype=torch.uint8, device=left_embed.device)
    # Round the disparity and convert to int
    disparity_rounded = torch.round(disparity).squeeze(0).long()  # [h, w]
    # Scatter each left pixel to column x - disparity in the right view
    for y in range(height):
        for x in range(width):
            new_x = x - disparity_rounded[y, x]
            if 0 <= new_x < width:
                right_embed[:, y, new_x] = left_embed[:, y, x]
                converted_right_image[:, y, new_x] = left_image[:, y, x]
                mask[y, new_x] = 0  # Mark as valid in the mask
    # Optionally drop a random fraction of pixels
    if random_ratio is not None:
        # random_mask is 1 where a pixel is kept (probability 1 - random_ratio)
        # and 0 where it is dropped
        random_mask = torch.bernoulli(
            torch.full((height, width), 1 - random_ratio, device=left_embed.device)
        ).byte()
        # A pixel stays valid only if it was valid and kept (mask uses 1 = invalid),
        # so invalidate the dropped pixels rather than the kept ones
        mask = mask | (1 - random_mask)
    # Apply the final mask to right_embed, converted_right_image, and disparity
    right_embed[:, mask == 1] = 255  # Masked-out locations become 255 in the embed
    converted_right_image[:, mask == 1] = 255  # ... and in the converted image
    disparity[:, mask == 1] = 0  # Masked-out locations get zero disparity
    return right_embed, mask, converted_right_image, disparity
def convert_left_to_right_torch(left_embed, disparity, left_image, random_ratio=None):
"""
Convert left features to right features based on disparity values.
    Args:
        left_embed (torch.Tensor): [c, h, w] tensor of left feature embeddings.
        disparity (torch.Tensor): [1, h, w] tensor of disparity values.
        left_image (torch.Tensor): [3, h, w] tensor of the left image.
        random_ratio (float, optional): accepted for parity with
            `convert_left_to_right` but unused in this vectorized version.
Returns:
right_embed (torch.Tensor): [c, h, w] tensor for the right feature embeddings.
mask (torch.Tensor): [h, w] binary mask (1 = invalid, 0 = valid).
converted_right_image (torch.Tensor): [3, h, w] tensor for the right image.
disparity (torch.Tensor): [1, h, w] tensor for the disparity.
"""
    # Get the height and width from the left embedding
    _, height, width = left_embed.shape
    # Initialize outputs; mask starts all-invalid (1 = invalid, 0 = valid)
    right_embed = torch.zeros_like(left_embed)
    converted_right_image = -torch.ones_like(left_image)
    mask = torch.ones((height, width), dtype=torch.uint8, device=left_embed.device)
    # Round the disparity and convert to int
    disparity_rounded = torch.round(disparity).squeeze(0).long()  # [h, w]
    # Process each column, scattering all rows at once
    for x in range(width):
        new_x = x - disparity_rounded[:, x]
        valid_indices = (new_x >= 0) & (new_x < width)
valid_new_x = new_x[valid_indices]
valid_y = torch.arange(height, device=left_embed.device)[valid_indices]
right_embed[:, valid_y, valid_new_x] = left_embed[:, valid_y, x]
converted_right_image[:, valid_y, valid_new_x] = left_image[:, valid_y, x]
mask[valid_y, valid_new_x] = 0 # Mark as valid in the mask
return right_embed, mask, converted_right_image, disparity
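

if __name__ == '__main__':
    # Minimal smoke test (an illustrative sketch, not part of the original
    # pipeline): composes the camera helpers into a model-view-projection
    # matrix and forward-warps a random image through a flat dummy point
    # cloud. Assumes the `splatting` package is installed; all shapes and
    # camera values below are arbitrary.
    B, H, W = 1, 32, 32
    eye = sph2cart(azi=torch.zeros(B), ele=torch.zeros(B), r=torch.full((B,), 2.0))
    view_mtx = camera_lookat(
        eye=eye,
        target=torch.zeros(B, 3),
        up=torch.tensor([[0., 0., 1.]]).repeat(B, 1)  # z-up, matching sph2cart
    )
    fovy = torch.full((B,), focal_length_to_fov(24.))
    proj_mtx = get_projection_matrix(fovy, aspect_wh=W / H, near=0.1, far=100.)
    mvp_mtx = proj_mtx @ view_mtx
    viewport_mtx = get_viewport_matrix(W, H, batch_size=B)
    # One homogeneous point per pixel, all on the z = 0 plane.
    xs, ys = torch.meshgrid(torch.arange(W), torch.arange(H), indexing='xy')
    screen = torch.stack([xs, ys], dim=-1).reshape(1, H * W, 2).float()
    pcd = torch.cat([
        screen / torch.tensor([W, H]) - 0.5,  # x, y in [-0.5, 0.5)
        torch.zeros(1, H * W, 1),             # z = 0
        torch.ones(1, H * W, 1)               # homogeneous w = 1
    ], dim=-1)
    out = forward_warper(torch.rand(B, 3, H, W), screen, pcd, mvp_mtx, viewport_mtx)
    print({k: tuple(v.shape) for k, v in out.items()})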