# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from einops import rearrange
from cosmos_predict1.diffusion.inference.forward_warp_utils_pytorch import (
forward_warp,
reliable_depth_mask_range_batch,
unproject_points,
)
from cosmos_predict1.diffusion.inference.camera_utils import align_depth


class Cache3D_Base:
def __init__(
self,
input_image,
input_depth,
input_w2c,
input_intrinsics,
input_mask=None,
input_format=None,
input_points=None,
weight_dtype=torch.float32,
is_depth=True,
device="cuda",
filter_points_threshold=1.0,
foreground_masking=False,
):
"""
input_image: Tensor with varying dimensions.
input_format: List of dimension labels corresponding to input_image's dimensions.
E.g., ['B', 'C', 'H', 'W'], ['B', 'F', 'C', 'H', 'W'], etc.
"""
self.weight_dtype = weight_dtype
self.is_depth = is_depth
self.device = device
self.filter_points_threshold = filter_points_threshold
self.foreground_masking = foreground_masking
        if input_format is None:
            assert input_image.dim() == 4, "input_format must be given unless input_image is 4-D (B, C, H, W)"
            input_format = ["B", "C", "H", "W"]
# Map dimension names to their indices in input_image
format_to_indices = {dim: idx for idx, dim in enumerate(input_format)}
input_shape = input_image.shape
if input_mask is not None:
input_image = torch.cat([input_image, input_mask], dim=format_to_indices.get("C"))
        # B (batch), F (frame), N (buffer): kept separate during warping; F is only
        # broadcast to match the target w2cs.
        # V (view): the axis that gets aggregated (by concatenation or DUSt3R-style fusion).
        # Sizes of the canonical axes; absent axes default to 1 (H and W are required).
        B = input_shape[format_to_indices["B"]] if "B" in format_to_indices else 1  # batch
        F = input_shape[format_to_indices["F"]] if "F" in format_to_indices else 1  # frame
        N = input_shape[format_to_indices["N"]] if "N" in format_to_indices else 1  # buffer
        V = input_shape[format_to_indices["V"]] if "V" in format_to_indices else 1  # view
        H = input_shape[format_to_indices["H"]] if "H" in format_to_indices else None
        W = input_shape[format_to_indices["W"]] if "W" in format_to_indices else None
# Desired dimension order
desired_dims = ["B", "F", "N", "V", "C", "H", "W"]
# Build permute order based on input_format
permute_order = []
for dim in desired_dims:
idx = format_to_indices.get(dim)
if idx is not None:
permute_order.append(idx)
else:
# Placeholder for dimensions to be added later
permute_order.append(None)
# Remove None values for permute operation
permute_indices = [idx for idx in permute_order if idx is not None]
input_image = input_image.permute(*permute_indices)
# Insert dimensions of size 1 where necessary
for i, idx in enumerate(permute_order):
if idx is None:
input_image = input_image.unsqueeze(i)
# Now input_image has the shape B x F x N x V x C x H x W
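        # Example: input_format=['B', 'F', 'C', 'H', 'W'] with shape (1, 8, 3, 480, 704)
        # permutes to the matching axes and unsqueezes the missing N and V, giving
        # a tensor of shape (1, 8, 1, 1, 3, 480, 704).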
        if input_mask is not None:
            # Split the concatenated mask back off the channel axis (image assumed 3-channel).
            self.input_image = input_image[:, :, :, :, :3]
            self.input_mask = input_image[:, :, :, :, 3:].to("cpu")
else:
self.input_mask = None
self.input_image = input_image
self.input_image = self.input_image.to(weight_dtype).to("cpu")
if input_points is not None:
self.input_points = input_points.reshape(B, F, N, V, H, W, 3).to("cpu")
self.input_depth = None
else:
            # Sanitize depth: replace NaNs and clamp to a finite working range.
            input_depth = torch.nan_to_num(input_depth, nan=100)
            input_depth = torch.clamp(input_depth, min=0, max=100)
            if weight_dtype == torch.float16:
                # Tighter bound under float16 to keep downstream arithmetic in range.
                input_depth = torch.clamp(input_depth, max=70)
self.input_points = (
self._compute_input_points(
input_depth.reshape(-1, 1, H, W),
input_w2c.reshape(-1, 4, 4),
input_intrinsics.reshape(-1, 3, 3),
)
.to(weight_dtype)
.reshape(B, F, N, V, H, W, 3)
.to("cpu")
)
self.input_depth = input_depth
        if self.filter_points_threshold < 1.0 and input_depth is not None:
            # Mask out pixels whose depth fails the reliability test.
            input_depth = input_depth.reshape(-1, 1, H, W)
            depth_mask = reliable_depth_mask_range_batch(
                input_depth, ratio_thresh=self.filter_points_threshold
            ).reshape(B, F, N, V, 1, H, W)
            if self.input_mask is None:
                self.input_mask = depth_mask.to("cpu")
            else:
                self.input_mask = self.input_mask * depth_mask.to(self.input_mask.device)
        self.boundary_mask = None
        if foreground_masking:
            # Inverse of the reliable-depth mask; consumed by forward_warp for foreground masking.
            input_depth = input_depth.reshape(-1, 1, H, W)
            depth_mask = reliable_depth_mask_range_batch(input_depth)
            self.boundary_mask = (~depth_mask).reshape(B, F, N, V, 1, H, W).to("cpu")

    def _compute_input_points(self, input_depth, input_w2c, input_intrinsics):
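        """Unproject per-pixel depth into world-space points.

        Output has one 3-vector per pixel; the caller reshapes it to
        (B, F, N, V, H, W, 3).
        """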
input_points = unproject_points(
input_depth,
input_w2c,
input_intrinsics,
is_depth=self.is_depth,
)
return input_points

    def update_cache(self):
        raise NotImplementedError

    def input_frame_count(self) -> int:
        return self.input_image.shape[1]

    def render_cache(self, target_w2cs, target_intrinsics, render_depth=False, start_frame_idx=0):
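        """Forward-warp the cached points into the target views.

        Shapes (as used in this file): target_w2cs is (B, F_target, 4, 4) and
        target_intrinsics is (B, F_target, 3, 3). Returns warped pixels of shape
        (B, F_target, N, C, H, W) (or depth of shape (B, F_target, N, H, W) when
        render_depth=True) together with validity masks.
        """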
bs, F_target, _, _ = target_w2cs.shape
B, F, N, V, C, H, W = self.input_image.shape
assert bs == B
target_w2cs = target_w2cs.reshape(B, F_target, 1, 4, 4).expand(B, F_target, N, 4, 4).reshape(-1, 4, 4)
target_intrinsics = (
target_intrinsics.reshape(B, F_target, 1, 3, 3).expand(B, F_target, N, 3, 3).reshape(-1, 3, 3)
)
        first_images = rearrange(
            self.input_image[:, start_frame_idx : start_frame_idx + F_target].expand(B, F_target, N, V, C, H, W),
            "B F N V C H W -> (B F N) V C H W",
        ).to(self.device)
        first_points = rearrange(
            self.input_points[:, start_frame_idx : start_frame_idx + F_target].expand(B, F_target, N, V, H, W, 3),
            "B F N V H W C -> (B F N) V H W C",
        ).to(self.device)
        first_masks = (
            rearrange(
                self.input_mask[:, start_frame_idx : start_frame_idx + F_target].expand(B, F_target, N, V, 1, H, W),
                "B F N V C H W -> (B F N) V C H W",
            ).to(self.device)
            if self.input_mask is not None
            else None
        )
        # Boundary masks follow the same slicing and device placement as the other inputs.
        boundary_masks = (
            rearrange(
                self.boundary_mask[:, start_frame_idx : start_frame_idx + F_target].expand(B, F_target, N, V, 1, H, W),
                "B F N V C H W -> (B F N) V C H W",
            ).to(self.device)
            if self.boundary_mask is not None
            else None
        )
        if first_images.shape[1] == 1:  # single-view (V == 1) fast path
            warp_chunk_size = 2  # warp a couple of batch items at a time to bound peak memory
rendered_warp_images = []
rendered_warp_masks = []
rendered_warp_depth = []
rendered_warped_flows = []
first_images = first_images.squeeze(1)
first_points = first_points.squeeze(1)
first_masks = first_masks.squeeze(1) if first_masks is not None else None
for i in range(0, first_images.shape[0], warp_chunk_size):
(
rendered_warp_images_chunk,
rendered_warp_masks_chunk,
rendered_warp_depth_chunk,
rendered_warped_flows_chunk,
) = forward_warp(
first_images[i : i + warp_chunk_size],
mask1=first_masks[i : i + warp_chunk_size] if first_masks is not None else None,
depth1=None,
transformation1=None,
transformation2=target_w2cs[i : i + warp_chunk_size],
intrinsic1=target_intrinsics[i : i + warp_chunk_size],
intrinsic2=target_intrinsics[i : i + warp_chunk_size],
render_depth=render_depth,
world_points1=first_points[i : i + warp_chunk_size],
foreground_masking=self.foreground_masking,
                    boundary_mask=boundary_masks[i : i + warp_chunk_size, 0, 0] if boundary_masks is not None else None,
)
rendered_warp_images.append(rendered_warp_images_chunk)
rendered_warp_masks.append(rendered_warp_masks_chunk)
rendered_warp_depth.append(rendered_warp_depth_chunk)
rendered_warped_flows.append(rendered_warped_flows_chunk)
rendered_warp_images = torch.cat(rendered_warp_images, dim=0)
rendered_warp_masks = torch.cat(rendered_warp_masks, dim=0)
if render_depth:
rendered_warp_depth = torch.cat(rendered_warp_depth, dim=0)
rendered_warped_flows = torch.cat(rendered_warped_flows, dim=0)
        else:
            raise NotImplementedError("Multi-view (V > 1) warping is not implemented.")
        if render_depth:
            pixels = rearrange(rendered_warp_depth, "(b f n) h w -> b f n h w", b=bs, f=F_target, n=N)
        else:
            pixels = rearrange(rendered_warp_images, "(b f n) c h w -> b f n c h w", b=bs, f=F_target, n=N)
        masks = rearrange(rendered_warp_masks, "(b f n) c h w -> b f n c h w", b=bs, f=F_target, n=N)
        return pixels, masks


class Cache3D_Buffer(Cache3D_Base):
    """Rolling 3D cache: newest frame first along the buffer (N) axis, with optional
    noise augmentation applied to warped pixels at render time."""

    def __init__(self, frame_buffer_max=0, noise_aug_strength=0, generator=None, **kwargs):
super().__init__(**kwargs)
self.frame_buffer_max = frame_buffer_max
self.noise_aug_strength = noise_aug_strength
self.generator = generator

    def update_cache(
        self,
        new_image,
        new_depth,
        new_w2c,
        new_mask=None,
        new_intrinsics=None,
        depth_alignment=True,
        alignment_method="non_rigid",
    ):
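        """Align the new frame's depth to the cached geometry, unproject it, and
        push it into the rolling 3D buffer (newest first).

        alignment_method is "rigid" (global fit) or "non_rigid" (iterative,
        ARAP-regularized); see the align_depth calls below.
        """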
new_image = new_image.to(self.weight_dtype).to(self.device)
new_depth = new_depth.to(self.weight_dtype).to(self.device)
new_w2c = new_w2c.to(self.weight_dtype).to(self.device)
if new_intrinsics is not None:
new_intrinsics = new_intrinsics.to(self.weight_dtype).to(self.device)
new_depth = torch.nan_to_num(new_depth, nan=1e4)
new_depth = torch.clamp(new_depth, min=0, max=1e4)
        if depth_alignment:
            # Render the cache's depth at the new pose to serve as the alignment target.
            target_depth, target_mask = self.render_cache(
                new_w2c.unsqueeze(1), new_intrinsics.unsqueeze(1), render_depth=True
            )
            target_depth, target_mask = target_depth[:, :, 0], target_mask[:, :, 0]
if alignment_method == "rigid":
new_depth = (
align_depth(
new_depth.squeeze(),
target_depth.squeeze(),
target_mask.bool().squeeze(),
)
.reshape_as(new_depth)
.detach()
)
elif alignment_method == "non_rigid":
with torch.enable_grad():
new_depth = (
align_depth(
new_depth.squeeze(),
target_depth.squeeze(),
target_mask.bool().squeeze(),
k=new_intrinsics.squeeze(),
c2w=torch.inverse(new_w2c.squeeze()),
alignment_method="non_rigid",
num_iters=100,
lambda_arap=0.1,
smoothing_kernel_size=3,
)
.reshape_as(new_depth)
.detach()
)
            else:
                raise NotImplementedError(f"Unknown alignment_method: {alignment_method!r}")
new_points = unproject_points(new_depth, new_w2c, new_intrinsics, is_depth=self.is_depth).cpu()
new_image = new_image.cpu()
if self.filter_points_threshold < 1.0:
B, F, N, V, C, H, W = self.input_image.shape
new_depth = new_depth.reshape(-1, 1, H, W)
            depth_mask = reliable_depth_mask_range_batch(
                new_depth, ratio_thresh=self.filter_points_threshold
            ).reshape(B, 1, H, W)
if new_mask is None:
new_mask = depth_mask.to("cpu")
else:
new_mask = new_mask * depth_mask.to(new_mask.device)
if new_mask is not None:
new_mask = new_mask.cpu()
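        # Rolling buffer along the N axis: while below capacity, new frames are
        # prepended (e.g. frame_buffer_max=4 yields [t3, t2, t1, t0] after three
        # updates); once full, the newest slot is overwritten in place.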
if self.frame_buffer_max > 1: # newest frame first
if self.input_image.shape[2] < self.frame_buffer_max:
self.input_image = torch.cat([new_image[:, None, None, None], self.input_image], 2)
self.input_points = torch.cat([new_points[:, None, None, None], self.input_points], 2)
if self.input_mask is not None:
self.input_mask = torch.cat([new_mask[:, None, None, None], self.input_mask], 2)
else:
self.input_image[:, :, 0] = new_image[:, None, None]
self.input_points[:, :, 0] = new_points[:, None, None]
if self.input_mask is not None:
self.input_mask[:, :, 0] = new_mask[:, None, None]
        else:
            # Single-frame cache: replace the stored frame outright (the mask is not retained here).
            self.input_image = new_image[:, None, None, None]
            self.input_points = new_points[:, None, None, None]

    def render_cache(
self,
target_w2cs,
target_intrinsics,
render_depth: bool = False,
start_frame_idx: int = 0, # For consistency with Cache4D
):
assert start_frame_idx == 0, "start_frame_idx must be 0 for Cache3D_Buffer"
output_device = target_w2cs.device
target_w2cs = target_w2cs.to(self.weight_dtype).to(self.device)
target_intrinsics = target_intrinsics.to(self.weight_dtype).to(self.device)
pixels, masks = super().render_cache(
target_w2cs, target_intrinsics, render_depth
)
if not render_depth:
noise = torch.randn(pixels.shape, generator=self.generator, device=pixels.device, dtype=pixels.dtype)
per_buffer_noise = (
torch.arange(start=pixels.shape[2] - 1, end=-1, step=-1, device=pixels.device)
* self.noise_aug_strength
)
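            # The newest buffer slot (index 0) receives the strongest augmentation:
            # with N = 3 and noise_aug_strength = s, the per-slot multipliers are [2s, s, 0].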
pixels = pixels + noise * per_buffer_noise.reshape(1, 1, -1, 1, 1, 1) # B, F, N, C, H, W
return pixels.to(output_device), masks.to(output_device)


class Cache4D(Cache3D_Base):
    """Multi-frame (4D) cache whose contents are fixed at construction time."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def update_cache(self, **kwargs):
        raise NotImplementedError("Cache4D is read-only after construction.")

    def render_cache(self, target_w2cs, target_intrinsics, render_depth=False, start_frame_idx=0):
        return super().render_cache(target_w2cs, target_intrinsics, render_depth, start_frame_idx)
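

if __name__ == "__main__":
    # Minimal smoke-test sketch (not part of the library API): builds a single-frame
    # cache from synthetic RGB-D data and re-renders it from the source pose.
    # Assumptions: a CUDA device is available (the cache defaults to device="cuda"),
    # and the resolution/intrinsics below are illustrative values only.
    device = "cuda"
    image = torch.rand(1, 3, 480, 704, device=device)  # B, C, H, W in [0, 1]
    depth = torch.full((1, 1, 480, 704), 2.0, device=device)  # constant 2 m plane
    w2c = torch.eye(4, device=device)[None]  # identity camera pose
    intrinsics = torch.tensor(
        [[352.0, 0.0, 352.0], [0.0, 352.0, 240.0], [0.0, 0.0, 1.0]], device=device
    )[None]
    cache = Cache3D_Buffer(
        frame_buffer_max=4,
        input_image=image,
        input_depth=depth,
        input_w2c=w2c,
        input_intrinsics=intrinsics,
    )
    # Warping back to the source pose should reproduce the input (up to resampling).
    warped, mask = cache.render_cache(w2c[:, None], intrinsics[:, None])
    print(warped.shape, mask.shape)  # expected: (1, 1, 1, C, 480, 704) and its mask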