# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
from pathlib import Path

import cv2
import igl
import numpy as np
import torch

from pytorch3d.structures import join_meshes_as_batch, join_meshes_as_scene

from third_partys.Video_Depth_Anything.video_depth_anything.video_depth import VideoDepthAnything
from utils.quat_utils import quat_inverse, quat_multiply, normalize_quaternion

logger = logging.getLogger(__name__)


class DepthModule:
    def __init__(self, encoder='vitl', device='cuda', input_size=518, fp32=False):
        """
        Initialize the depth module with Video Depth Anything.

        Args:
            encoder: 'vitl' or 'vits'
            device: device to run the model on (falls back to CPU if CUDA is unavailable)
            input_size: input size for the model
            fp32: whether to use float32 for inference
        """
        self.device = device if torch.cuda.is_available() else 'cpu'
        self.input_size = input_size
        self.fp32 = fp32
        # Model configurations for the supported backbones
        model_configs = {
            'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
            'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
        }
        if encoder not in model_configs:
            raise ValueError(f"Unsupported encoder '{encoder}'; expected one of {list(model_configs)}")
        # Load the pretrained Video Depth Anything checkpoint and freeze it
        self.video_depth_model = VideoDepthAnything(**model_configs[encoder])
        self.video_depth_model.load_state_dict(
            torch.load(f'./third_partys/Video_Depth_Anything/ckpt/video_depth_anything_{encoder}.pth', map_location='cpu'),
            strict=True
        )
        self.video_depth_model = self.video_depth_model.to(self.device).eval()
        for param in self.video_depth_model.parameters():
            param.requires_grad = False

    def get_depth_maps(self, frames, target_fps=30):
        """
        Run Video Depth Anything on a sequence of video frames and return the
        per-frame depth maps.
        """
depths, _ = self.video_depth_model.infer_video_depth(
frames,
target_fps,
input_size=self.input_size,
device=self.device,
fp32=self.fp32
)
return depths
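

# Example usage for DepthModule (a minimal sketch; assumes `frames` is a
# (T, H, W, 3) uint8 RGB array, as produced by a standard video reader, and
# that the checkpoint path above exists):
#   depth_module = DepthModule(encoder='vitl')
#   depths = depth_module.get_depth_maps(frames, target_fps=30)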


def save_depth_as_images(depth_maps, output_dir='./depth_images'):
    """Save per-frame depth maps as 8-bit grayscale PNGs, normalized per frame."""
    os.makedirs(output_dir, exist_ok=True)
    for i, depth_map in enumerate(depth_maps):
        if torch.is_tensor(depth_map):
            depth_map = depth_map.detach().cpu().numpy()
        valid_mask = (depth_map > 0)
        if not valid_mask.any():
            continue
        valid_min = depth_map[valid_mask].min()
        valid_max = depth_map[valid_mask].max()
        # Map valid depths to [0, 255]; invalid pixels stay black
        normalized = np.zeros_like(depth_map)
        normalized[valid_mask] = 255.0 * (depth_map[valid_mask] - valid_min) / max(valid_max - valid_min, 1e-8)
        depth_img = normalized.astype(np.uint8)
        cv2.imwrite(os.path.join(output_dir, f'depth_{i:04d}.png'), depth_img)
    print(f"Saved {len(depth_maps)} depth images to {output_dir}")


def compute_visibility_mask_igl(ray_origins, ray_dirs, distances, verts, faces, distance_tolerance=1e-6, for_vertices=False):
    """
    Compute a per-point visibility mask by casting rays against the mesh with IGL.

    A point counts as occluded when the ray from the camera toward it hits
    other faces strictly before reaching the point itself.
    """
    num_rays = ray_origins.shape[0]
    visibility_mask = np.ones(num_rays, dtype=bool)
    for i in range(num_rays):
        ray_origin = ray_origins[i].reshape(1, 3)
        ray_dir = ray_dirs[i].reshape(1, 3)
        intersections = igl.ray_mesh_intersect(ray_origin, ray_dir, verts, faces)
        if intersections:
            # Count hits strictly before the target point (h[4] is the hit distance along the ray)
            count = sum(1 for h in intersections if h[4] < distances[i] - distance_tolerance)
            if for_vertices:
                # A surface vertex should be hit exactly once, by the face that
                # contains it; count == 0 means the ray missed the mesh and
                # count > 1 means other faces block the view.
                if count != 1:
                    visibility_mask[i] = False
            else:
                # Joints lie inside the mesh, so the ray may legitimately cross
                # the near surface before reaching them; more than two hits
                # means the joint is occluded by other geometry.
                if count > 2:
                    visibility_mask[i] = False
    return visibility_mask
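

# Example usage (a minimal sketch; assumes rays are cast from the camera center
# `cam_center` toward each joint, with `dists` the camera-to-joint distances):
#   dirs = joints - cam_center                              # (J, 3)
#   dists = np.linalg.norm(dirs, axis=-1)
#   dirs = dirs / dists[:, None]
#   origins = np.broadcast_to(cam_center, dirs.shape).copy()
#   vis = compute_visibility_mask_igl(origins, dirs, dists, verts, faces)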


def compute_reprojection_loss(renderer, vis_mask, predicted_joints, tracked_joints_2d, image_size):
    """
    Compute the reprojection loss between predicted 3D joints and tracked 2D
    joints, averaged over visible joints and over frames.
    """
if predicted_joints.dim() != 3:
raise ValueError(f"predicted_joints must be 3D tensor, got shape {predicted_joints.shape}")
B, J, _ = predicted_joints.shape
device = predicted_joints.device
# Project 3D joints to 2D screen coordinates
projected = renderer.camera.transform_points_screen(
predicted_joints,
image_size=[image_size, image_size]
)
projected_2d = projected[..., :2] # (B, J, 2)
# Convert and validate tracked joints
if not isinstance(tracked_joints_2d, torch.Tensor):
tracked_joints_2d = torch.from_numpy(tracked_joints_2d).float()
tracked_joints_2d = tracked_joints_2d.to(device)
if tracked_joints_2d.dim() == 2:
tracked_joints_2d = tracked_joints_2d.unsqueeze(0).expand(B, -1, -1)
vis_mask = vis_mask.to(device).float()
num_visible = vis_mask.sum()
if num_visible == 0:
# No visible joints - return zero loss
return torch.tensor(0.0, device=device, requires_grad=True)
squared_diff = (projected_2d - tracked_joints_2d).pow(2).sum(dim=-1) # (B, J)
vis_mask_expanded = vis_mask.unsqueeze(0) # (1, J)
masked_loss = squared_diff * vis_mask_expanded # (B, J)
per_frame_loss = masked_loss.sum(dim=1) / num_visible # (B,)
final_loss = per_frame_loss.mean() # scalar
return final_loss
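

# Shape conventions (a minimal sketch of a call):
#   predicted_joints: (B, J, 3) predicted 3D joints
#   tracked_joints_2d: (J, 2) or (B, J, 2) pixel coordinates
#   vis_mask: (J,) visibility flags, e.g. from compute_visibility_mask_igl
#   loss = compute_reprojection_loss(renderer, vis_mask, predicted_joints,
#                                    tracked_joints_2d, image_size=512)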


def geodesic_loss(q1, q2, eps=1e-6):
    """
    Compute the geodesic distance between batches of quaternions, used for the
    rotation smoothness loss.
    """
    q1_norm = normalize_quaternion(q1, eps=eps)
    q2_norm = normalize_quaternion(q2, eps=eps)
    # q and -q encode the same rotation, so flip q2 onto q1's hemisphere
    dot_product = (q1_norm * q2_norm).sum(dim=-1, keepdim=True)
    q2_corrected = torch.where(dot_product < 0, -q2_norm, q2_norm)
    inner_product = (q1_norm * q2_corrected).sum(dim=-1)
    # Clamp to the valid arccos domain to avoid NaNs from rounding
    inner_product_clamped = torch.clamp(inner_product, min=-1.0 + eps, max=1.0 - eps)
    # theta = 2 * arccos(|<q1, q2>|) is the angle of the relative rotation
    theta = 2.0 * torch.arccos(torch.abs(inner_product_clamped))
    return theta
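

# Quick check (a minimal sketch): identical quaternions give theta ~ 0, while
# orthogonal unit quaternions give theta ~ pi:
#   qa = torch.tensor([[1.0, 0.0, 0.0, 0.0]])
#   qb = torch.tensor([[0.0, 0.0, 0.0, 1.0]])
#   geodesic_loss(qa, qa)  # ~0
#   geodesic_loss(qa, qb)  # ~pi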


def root_motion_reg(root_quats, root_pos):
    """Temporal smoothness regularizers for root translation and rotation."""
    pos_reg = ((root_pos[1:] - root_pos[:-1]) ** 2).mean()
    rot_reg = (geodesic_loss(root_quats[1:], root_quats[:-1]) ** 2).mean()
    return pos_reg, rot_reg

def joint_motion_coherence(quats_normed, parent_idx):
"""
Compute joint motion coherence loss to enforce smooth relative motion between parent-child joints.
"""
coherence_loss = 0
for j, parent in enumerate(parent_idx):
if parent != -1: # Skip root joint
parent_rot = quats_normed[:, parent] # (T, 4)
child_rot = quats_normed[:, j] # (T, 4)
# Compute relative rotation of child w.r.t. parent's local frame
# local_rot = parent_rot^(-1) * child_rot
local_rot = quat_multiply(quat_inverse(parent_rot), child_rot)
local_rot_velocity = local_rot[1:] - local_rot[:-1] # (T-1, 4)
coherence_loss += local_rot_velocity.pow(2).mean()
return coherence_loss


def read_flo_file(file_path):
    """
    Read optical flow from a Middlebury .flo file.
    """
    with open(file_path, 'rb') as f:
        # .flo layout: float32 sanity magic (202021.25), int32 width, int32
        # height, then h * w * 2 interleaved float32 (u, v) components
        magic = np.fromfile(f, np.float32, count=1)
        if len(magic) == 0 or magic[0] != 202021.25:
            raise ValueError(f'Invalid .flo file format: magic number {magic}')
        w = np.fromfile(f, np.int32, count=1)[0]
        h = np.fromfile(f, np.int32, count=1)[0]
        data = np.fromfile(f, np.float32, count=2 * w * h)
        flow = data.reshape(h, w, 2)
    return flow


def load_optical_flows(flow_dir, num_frames):
    """
    Load the sequence of optical flow files, one per consecutive frame pair.
    """
    flow_dir = Path(flow_dir)
    flows = []
    for i in range(num_frames - 1):
        flow_path = flow_dir / f'flow_{i:04d}.flo'
        if not flow_path.exists():
            raise ValueError(f"Missing flow file: {flow_path}")
        flows.append(read_flo_file(flow_path))
    return np.stack(flows, axis=0)


def rasterize_vertex_flow(flow_vertices, meshes, image_size, renderer, eps=1e-8):
    """
    Rasterize per-vertex flow to a dense flow field using barycentric interpolation.
    """
    B = flow_vertices.shape[0]
device = flow_vertices.device
if isinstance(image_size, int):
H = W = image_size
else:
H, W = image_size
batch_meshes = join_meshes_as_batch([join_meshes_as_scene(m) for m in meshes]).to(device)
fragments = renderer.renderer.rasterizer(batch_meshes)
pix_to_face = fragments.pix_to_face # (B, H, W, K)
bary_coords = fragments.bary_coords # (B, H, W, K, 3)
flow_scene_list = []
for mesh_idx in range(B):
mesh = meshes[mesh_idx]
V_mesh = mesh.verts_packed().shape[0]
if V_mesh > flow_vertices.shape[1]:
raise ValueError(f"Mesh {mesh_idx} has {V_mesh} vertices but flow has {flow_vertices.shape[1]}")
flow_scene_list.append(flow_vertices[mesh_idx, :V_mesh])
flow_vertices_scene = torch.cat(flow_scene_list, dim=0).to(device)
faces_scene = batch_meshes.faces_packed()
flow_pred = torch.zeros(B, H, W, 2, device=device)
valid = pix_to_face[..., 0] >= 0
for b in range(B):
b_valid = valid[b] # (H,W)
if torch.count_nonzero(b_valid) == 0:
print(f"No valid pixels found for batch {b}")
continue
valid_indices = torch.nonzero(b_valid, as_tuple=True)
h_indices, w_indices = valid_indices
face_idxs = pix_to_face[b, h_indices, w_indices, 0] # (N,)
bary = bary_coords[b, h_indices, w_indices, 0] # (N,3)
max_face_idx = faces_scene.shape[0] - 1
if face_idxs.max() > max_face_idx:
raise RuntimeError(f"Face index {face_idxs.max()} exceeds max {max_face_idx}")
face_verts = faces_scene[face_idxs] # (N, 3)
f0, f1, f2 = face_verts.unbind(-1) # Each (N,)
max_vert_idx = flow_vertices_scene.shape[0] - 1
if max(f0.max(), f1.max(), f2.max()) > max_vert_idx:
raise RuntimeError(f"Vertex index exceeds flow_vertices_scene size {max_vert_idx}")
v0_flow = flow_vertices_scene[f0] # (N, 2)
v1_flow = flow_vertices_scene[f1] # (N, 2)
v2_flow = flow_vertices_scene[f2] # (N, 2)
# Interpolate using barycentric coordinates
b0, b1, b2 = bary.unbind(-1) # Each (N,)
# Ensure barycentric coordinates sum to 1 (numerical stability)
bary_sum = b0 + b1 + b2
b0 = b0 / (bary_sum + eps)
b1 = b1 / (bary_sum + eps)
b2 = b2 / (bary_sum + eps)
flow_interpolated = (
b0.unsqueeze(-1) * v0_flow +
b1.unsqueeze(-1) * v1_flow +
b2.unsqueeze(-1) * v2_flow
) # (N, 2)
# Update flow prediction
flow_pred[b, h_indices, w_indices] = flow_interpolated
return flow_pred
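

# Example usage (a minimal sketch; `renderer` must expose the PyTorch3D
# rasterizer at renderer.renderer.rasterizer, as used above):
#   vertex_flow = proj_next - proj_curr                    # (B, V, 2)
#   flow_img = rasterize_vertex_flow(vertex_flow, meshes, (H, W), renderer)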


def calculate_flow_loss(flow_dir, device, mask, renderer, model):
    """
    Compute a Charbonnier loss between rendered mesh flow and precomputed
    optical flow.
    """
    if device is None:
        device = mask.device
    T = mask.shape[0]
    H, W = mask.shape[1:3]
    flows_np = load_optical_flows(flow_dir, T)
    flow_gt = torch.from_numpy(flows_np).float().to(device)  # (T-1, H, W, 2)
    # The mask covers T frames but the flow only T-1 frame pairs; align them
    flow_mask = mask[1:] if mask.shape[0] == flow_gt.shape[0] + 1 else mask
    vertices = model.deformed_vertices[0]  # (T, V, 3)
    # Project vertices at consecutive frames to get per-vertex 2D flow
    proj_t = renderer.project_points(vertices[:-1])  # (T-1, V, 2) in pixels
    proj_tp = renderer.project_points(vertices[1:])
    vertex_flow = proj_tp - proj_t  # (T-1, V, 2) Δx, Δy
    # Only the T-1 source frames are rasterized
    meshes = [model.get_mesh(t) for t in range(T - 1)]
    flow_pred = rasterize_vertex_flow(vertex_flow, meshes, (H, W), renderer)  # (T-1, H, W, 2)
    eps = 1e-3
    diff = (flow_pred - flow_gt) * flow_mask.unsqueeze(-1)  # (T-1, H, W, 2)
    # Charbonnier penalty: sqrt(||diff||^2 + eps^2), a smooth L1 surrogate
    loss = torch.sqrt(diff.pow(2).sum(dim=-1) + eps**2)
    loss = loss.sum() / (flow_mask.sum() + 1e-6)
    return loss


def normalize_depth_from_reference(depth_maps, reference_idx=0, invalid_value=-1.0, invert=False, eps=1e-8):
    """
    Normalize depth maps to [0, 1] using robust statistics of a reference frame.
    """
    if depth_maps.dim() != 3:
        raise ValueError(f"Expected depth_maps with 3 dimensions, got {depth_maps.dim()}")
    reference_depth = depth_maps[reference_idx]
    valid_mask = (
        (reference_depth != invalid_value) &
        (reference_depth > eps) &  # Exclude very small positive values
        torch.isfinite(reference_depth)  # Exclude inf/nan
    )
    if not valid_mask.any():
        raise ValueError(f"Reference frame {reference_idx} has no valid depth values")
    valid_values = reference_depth[valid_mask]
    # Use the 1st/99th percentiles as a robust depth range
    min_depth = torch.quantile(valid_values, 0.01)
    max_depth = torch.quantile(valid_values, 0.99)
    depth_range = max_depth - min_depth
    if depth_range < eps:
        logger.warning(f"Very small depth range ({depth_range:.6f}), using fallback normalization")
        min_depth = valid_values.min()
        max_depth = valid_values.max()
        depth_range = (max_depth - min_depth).clamp(min=eps)
    scale = 1.0 / depth_range
    offset = -min_depth * scale
    all_valid_mask = (
        (depth_maps != invalid_value) &
        (depth_maps > eps) &
        torch.isfinite(depth_maps)
    )
    normalized_depths = torch.full_like(depth_maps, invalid_value)
    if all_valid_mask.any():
        normalized_values = depth_maps[all_valid_mask] * scale + offset
        if invert:
            normalized_values = 1.0 - normalized_values
        normalized_depths[all_valid_mask] = normalized_values
    return normalized_depths, scale.item(), offset.item()
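

# Quick check (a minimal sketch): valid pixels map to roughly [0, 1] using the
# reference frame's 1st-99th percentile range, invalid pixels keep
# `invalid_value`, and `scale`/`offset` reproduce the mapping as d * scale + offset:
#   depths = torch.rand(4, 8, 8) + 0.5
#   normed, scale, offset = normalize_depth_from_reference(depths)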


def compute_depth_loss_normalized(mono_depths, zbuf_depths, mask):
    """
    Compute an MSE loss between normalized rendered (z-buffer) depths and
    normalized monocular depth predictions.
    """
    device = zbuf_depths.device
    # Normalize both depth sources to a comparable [0, 1] range; the monocular
    # branch is inverted so both share the same near/far convention
    zbuf_norm, _, _ = normalize_depth_from_reference(zbuf_depths)
    mono_norm, _, _ = normalize_depth_from_reference(mono_depths, invert=True)
    valid_zbuf = (zbuf_norm >= 0) & (zbuf_norm <= 1)
    valid_mono = (mono_norm >= 0) & (mono_norm <= 1)
    if mask.dtype != torch.bool:
        mask = mask > 0.5
    combined_mask = mask & valid_zbuf & valid_mono
    num_valid = combined_mask.sum().item()
    if num_valid == 0:
        logger.warning("No valid pixels for depth loss computation")
        return torch.tensor(0.0, device=device, requires_grad=True)
    depth_diff = (zbuf_norm - mono_norm) * combined_mask.float()
    loss = (depth_diff ** 2).sum() / num_valid
    return loss
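

if __name__ == '__main__':
    # Minimal smoke test for the tensor-only losses (a sketch with random
    # inputs; the renderer- and mesh-dependent losses need real scene data).
    torch.manual_seed(0)
    # geodesic_loss: the angle between a quaternion and itself should be ~0
    q = normalize_quaternion(torch.randn(5, 4))
    assert geodesic_loss(q, q).max() < 1e-2
    # Depth loss between a depth map and its reciprocal (a disparity-like
    # signal) should be finite and small after normalization
    depth = torch.rand(3, 16, 16) + 0.5
    mask = torch.ones(3, 16, 16, dtype=torch.bool)
    loss = compute_depth_loss_normalized(1.0 / depth, depth, mask)
    print(f"smoke-test depth loss: {loss.item():.6f}")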