"""
SAM 3D Body (3DB) Mesh Alignment Utilities

Aligns 3DB meshes with SAM 3D Object outputs by matching them to the MoGe
point-cloud scale.
"""
import os
import json

import numpy as np
import torch
import trimesh
from PIL import Image

from pytorch3d.structures import Meshes
from pytorch3d.renderer import PerspectiveCameras, RasterizationSettings, MeshRasterizer, TexturesVertex

from moge.model.v1 import MoGeModel


def load_3db_mesh(mesh_path, device='cuda'):
    """Load a 3DB mesh and convert it from OpenGL to PyTorch3D coordinates."""
    mesh = trimesh.load(mesh_path)
    vertices = np.array(mesh.vertices)
    faces = np.array(mesh.faces)

    # OpenGL (+X right, +Z toward the viewer) -> PyTorch3D (+X left, +Z into
    # the screen): negate the X and Z axes.
    vertices[:, 0] *= -1
    vertices[:, 2] *= -1

    vertices = torch.from_numpy(vertices).float().to(device)
    faces = torch.from_numpy(faces).long().to(device)
    return vertices, faces
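
# The axis flip above is its own inverse, which is why process_3db_alignment()
# can return to OpenGL coordinates by negating X and Z a second time:
#     v = np.array([[1.0, 2.0, 3.0]])
#     v[:, 0] *= -1; v[:, 2] *= -1   # OpenGL -> PyTorch3D: [-1, 2, -3]
#     v[:, 0] *= -1; v[:, 2] *= -1   # PyTorch3D -> OpenGL: back to [1, 2, 3]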


def get_moge_pointcloud(image_tensor, device='cuda'):
    """Generate a MoGe point cloud from an image tensor."""
    # Note: this loads the pretrained model on every call; cache it if the
    # function is invoked repeatedly (see the sketch below).
    moge_model = MoGeModel.from_pretrained("Ruicheng/moge-vitl").to(device)
    moge_model.eval()
    with torch.no_grad():
        moge_output = moge_model.infer(image_tensor)
    return moge_output
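
# Reloading the ViT-L checkpoint per call is expensive. A minimal sketch of a
# cached variant, assuming only the standard library's functools:
#     from functools import lru_cache
#
#     @lru_cache(maxsize=1)
#     def _get_moge_model(device='cuda'):
#         model = MoGeModel.from_pretrained("Ruicheng/moge-vitl").to(device)
#         model.eval()
#         return model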


def denormalize_intrinsics(norm_K, height, width):
    """Convert normalized intrinsics to absolute pixel coordinates."""
    fx_norm, fy_norm = norm_K[0, 0], norm_K[1, 1]
    cx_norm, cy_norm = norm_K[0, 2], norm_K[1, 2]

    fx_abs = fx_norm * width
    fy_abs = fy_norm * height
    cx_abs = cx_norm * width
    cy_abs = cy_norm * height
    # Enforce square pixels: reuse the vertical focal length for both axes,
    # overriding fx_norm * width (downstream code assumes a single focal length).
    fx_abs = fy_abs

    return np.array([
        [fx_abs, 0.0, cx_abs],
        [0.0, fy_abs, cy_abs],
        [0.0, 0.0, 1.0]
    ])
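
# A quick check of the denormalization arithmetic (values are illustrative):
#     K_norm = np.array([[1.25, 0.0, 0.5], [0.0, 1.0, 0.5], [0.0, 0.0, 1.0]])
#     K_abs = denormalize_intrinsics(K_norm, height=480, width=640)
#     # fy = 1.0 * 480 = 480; the square-pixel override sets fx = fy = 480;
#     # cx = 0.5 * 640 = 320, cy = 0.5 * 480 = 240, so
#     # K_abs == [[480, 0, 320], [0, 480, 240], [0, 0, 1]].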


def crop_mesh_with_mask(vertices, faces, focal_length, mask, device='cuda'):
    """Crop mesh vertices to only those visible inside the mask."""
    # Rasterize the mesh with a simple pinhole camera whose principal point is
    # the image center, then keep only the faces that land on masked pixels.
    textures = TexturesVertex(verts_features=torch.ones_like(vertices)[None])
    mesh = Meshes(verts=[vertices], faces=[faces], textures=textures)

    H, W = mask.shape[-2:]
    fx = fy = focal_length
    cx, cy = W / 2.0, H / 2.0

    camera = PerspectiveCameras(
        focal_length=((fx, fy),),
        principal_point=((cx, cy),),
        image_size=((H, W),),
        in_ndc=False, device=device
    )

    raster_settings = RasterizationSettings(
        image_size=(H, W), blur_radius=0.0, faces_per_pixel=1,
        cull_backfaces=False, bin_size=0,
    )

    rasterizer = MeshRasterizer(cameras=camera, raster_settings=raster_settings)
    fragments = rasterizer(mesh)

    # pix_to_face is -1 wherever no face was rendered.
    face_indices = fragments.pix_to_face[0, ..., 0]
    visible_mask = (mask > 0) & (face_indices >= 0)
    visible_face_ids = face_indices[visible_mask]

    visible_faces = faces[visible_face_ids]
    visible_vert_ids = torch.unique(visible_faces)
    verts_cropped = vertices[visible_vert_ids]

    return verts_cropped, visible_mask
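
# Illustrative call (shapes only; the tensors here are hypothetical):
#     mask = torch.ones(512, 512, device='cuda')              # (H, W) float mask
#     verts, vis = crop_mesh_with_mask(vertices, faces, 1000.0, mask)
#     # verts: (M, 3) vertices belonging to faces rendered onto masked pixels
#     # vis:   (H, W) bool mask of pixels that are both masked and mesh-covered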


def extract_target_points(pointmap, visible_mask):
    """Extract target points from the MoGe pointmap using the visible mask."""
    target_points = pointmap[visible_mask.bool()]

    # Drop non-finite points first: MoGe marks invalid pixels with NaN/inf,
    # and a single NaN would poison the depth quantile below.
    finite_mask = torch.isfinite(target_points).all(dim=1)
    target_points = target_points[finite_mask]
    if target_points.shape[0] == 0:
        return target_points

    # MoGe uses an OpenCV-style camera (+X right, +Y down, +Z forward); flip
    # X and Y to match the PyTorch3D convention used by the mesh.
    target_points[:, 0] *= -1
    target_points[:, 1] *= -1

    # Trim far-depth outliers; a wider depth range gets a more aggressive cut.
    z_range = torch.max(target_points[:, 2]) - torch.min(target_points[:, 2])
    if z_range > 6.0:
        thresh = 0.90
    elif z_range > 2.0:
        thresh = 0.93
    else:
        thresh = 0.95

    depth_quantile = torch.quantile(target_points[:, 2], thresh)
    target_points = target_points[target_points[:, 2] <= depth_quantile]

    return target_points


def align_mesh_to_pointcloud(vertices, target_points):
    """Align mesh vertices to a target point cloud via uniform scale and translation."""
    if target_points.shape[0] == 0:
        print("[WARNING] No target points for alignment!")
        return vertices, torch.tensor(1.0, device=vertices.device), torch.zeros(3, device=vertices.device)

    # Scale: match the vertical (Y) extent of the mesh to the target points.
    height_src = torch.max(vertices[:, 1]) - torch.min(vertices[:, 1])
    height_tgt = torch.max(target_points[:, 1]) - torch.min(target_points[:, 1])
    scale_factor = height_tgt / height_src

    vertices_scaled = vertices * scale_factor

    # Translation: move the scaled mesh centroid onto the target centroid.
    center_src = torch.mean(vertices_scaled, dim=0)
    center_tgt = torch.mean(target_points, dim=0)
    translation = center_tgt - center_src

    vertices_aligned = vertices_scaled + translation
    return vertices_aligned, scale_factor, translation
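
# Self-check of the recovery (synthetic data): if the target is an exact
# scaled-and-shifted copy of the source, the estimate reproduces the transform.
#     src = torch.randn(100, 3)
#     tgt = src * 2.0 + torch.tensor([0.1, -0.3, 5.0])
#     aligned, s, t = align_mesh_to_pointcloud(src, tgt)
#     # s == 2.0, t == [0.1, -0.3, 5.0], aligned == tgt (up to float error)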


def load_mask_for_alignment(mask_path):
    """Load a mask image as a float numpy array in [0, 1]."""
    mask = Image.open(mask_path).convert('L')
    mask_array = np.array(mask) / 255.0
    return mask_array


def load_focal_length_from_json(json_path):
    """Load focal length from JSON file."""
    try:
        with open(json_path, 'r') as f:
            data = json.load(f)
        focal_length = data.get('focal_length')
        if focal_length is None:
            raise ValueError("'focal_length' key not found in JSON file")
        print(f"[INFO] Loaded focal length from {json_path}: {focal_length}")
        return focal_length
    except Exception as e:
        print(f"[ERROR] Failed to load focal length from {json_path}: {e}")
        raise
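
# Expected JSON layout (minimal example; the value is illustrative):
#     {"focal_length": 1234.5}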


def process_3db_alignment(mesh_path, mask_path, image_path, device='cuda', focal_length_json_path=None):
    """Complete pipeline for aligning a 3DB mesh to the MoGe scale."""
    print("[INFO] Processing alignment...")

    # Load the mesh (converted to PyTorch3D coordinates).
    vertices, faces = load_3db_mesh(mesh_path, device)

    # Load the image as a (3, H, W) float tensor in [0, 1].
    image = Image.open(image_path).convert('RGB')
    image_tensor = torch.from_numpy(np.array(image)).float().permute(2, 0, 1) / 255.0
    image_tensor = image_tensor.to(device)

    # Load the mask and resize it to the image resolution if needed.
    H, W = image_tensor.shape[1:]
    mask = load_mask_for_alignment(mask_path)
    if mask.shape != (H, W):
        mask = Image.fromarray((mask * 255).astype(np.uint8))
        mask = mask.resize((W, H), Image.NEAREST)
        mask = np.array(mask) / 255.0
    mask = torch.from_numpy(mask).float().to(device)

    print("[INFO] Generating MoGe point cloud...")
    moge_output = get_moge_pointcloud(image_tensor, device)

    # Focal length: prefer the JSON override, otherwise use MoGe's estimate.
    if focal_length_json_path is not None:
        focal_length = load_focal_length_from_json(focal_length_json_path)
    else:
        intrinsics = denormalize_intrinsics(moge_output['intrinsics'].cpu().numpy(), H, W)
        focal_length = intrinsics[1, 1]
        print(f"[INFO] Using computed focal length from MoGe: {focal_length}")

    print("[INFO] Cropping mesh with mask...")
    verts_cropped, visible_mask = crop_mesh_with_mask(vertices, faces, focal_length, mask, device)

    print("[INFO] Extracting target points...")
    target_points = extract_target_points(moge_output['points'], visible_mask)

    if target_points.shape[0] == 0:
        print("[ERROR] No valid target points found!")
        return None

    print("[INFO] Aligning mesh to point cloud...")
    aligned_vertices, scale_factor, translation = align_mesh_to_pointcloud(verts_cropped, target_points)

    # Apply the transform estimated on the cropped mesh to the full mesh.
    full_aligned_vertices = (vertices * scale_factor) + translation

    # Convert back from PyTorch3D to OpenGL coordinates for export.
    final_vertices_opengl = full_aligned_vertices.cpu().numpy()
    final_vertices_opengl[:, 0] *= -1
    final_vertices_opengl[:, 2] *= -1

    results = {
        'aligned_vertices_opengl': final_vertices_opengl,
        'faces': faces.cpu().numpy(),
        'scale_factor': scale_factor.item(),
        'translation': translation.cpu().numpy(),
        'focal_length': focal_length,
        'target_points_count': target_points.shape[0],
        'cropped_vertices_count': verts_cropped.shape[0]
    }

    print(f"[INFO] Alignment completed - Scale: {scale_factor.item():.4f}, Target points: {target_points.shape[0]}")
    return results


def process_and_save_alignment(mesh_path, mask_path, image_path, output_dir, device='cuda', focal_length_json_path=None):
    """
    Complete pipeline for processing 3DB alignment and saving the result.

    Args:
        mesh_path: Path to input 3DB mesh (.ply)
        mask_path: Path to mask image (.png)
        image_path: Path to input image (.jpg)
        output_dir: Directory to save the aligned mesh
        device: Device to use ('cuda' or 'cpu')
        focal_length_json_path: Optional path to focal length JSON file

    Returns:
        tuple: (success: bool, output_mesh_path: str or None, result_info: dict or None)
    """
    try:
        print("[INFO] Starting 3DB mesh alignment pipeline...")

        os.makedirs(output_dir, exist_ok=True)

        result = process_3db_alignment(
            mesh_path=mesh_path,
            mask_path=mask_path,
            image_path=image_path,
            device=device,
            focal_length_json_path=focal_length_json_path
        )

        if result is not None:
            # Export the aligned mesh in the original OpenGL coordinate frame.
            output_mesh_path = os.path.join(output_dir, 'human_aligned.ply')
            aligned_mesh = trimesh.Trimesh(
                vertices=result['aligned_vertices_opengl'],
                faces=result['faces']
            )
            aligned_mesh.export(output_mesh_path)

            print(f"[INFO] SUCCESS! Saved aligned mesh to: {output_mesh_path}")
            return True, output_mesh_path, result
        else:
            print("[ERROR] Failed to process mesh alignment")
            return False, None, None

    except Exception as e:
        print(f"[ERROR] Exception during processing: {e}")
        import traceback
        traceback.print_exc()
        return False, None, None

    finally:
        print("[INFO] Processing complete!")


def visualize_meshes_interactive(aligned_mesh_path, dfy_mesh_path, output_dir=None, share=True, height=600):
    """
    Interactive Gradio-based 3D visualization of the aligned human and object meshes.

    Args:
        aligned_mesh_path: Path to aligned mesh PLY file
        dfy_mesh_path: Path to 3Dfy GLB file
        output_dir: Directory to save the combined GLB file (defaults to the directory of aligned_mesh_path)
        share: Whether to create a public shareable link (default: True)
        height: Height of the 3D viewer in pixels (default: 600)

    Returns:
        tuple: (demo, combined_glb_path) - Gradio demo object and path to the combined GLB file
    """
    import gradio as gr

    print("Loading meshes for interactive visualization...")

    try:
        aligned_mesh = trimesh.load(aligned_mesh_path)
        print(f"Loaded aligned mesh: {len(aligned_mesh.vertices)} vertices")

        # GLB files load as a Scene; collapse it to a single Trimesh.
        dfy_scene = trimesh.load(dfy_mesh_path)
        if isinstance(dfy_scene, trimesh.Scene):
            dfy_meshes = [geom for geom in dfy_scene.geometry.values() if hasattr(geom, 'vertices')]
            if len(dfy_meshes) == 1:
                dfy_mesh = dfy_meshes[0]
            elif len(dfy_meshes) > 1:
                dfy_mesh = trimesh.util.concatenate(dfy_meshes)
            else:
                raise ValueError("No valid meshes in GLB file")
        else:
            dfy_mesh = dfy_scene

        print(f"Loaded 3Dfy mesh: {len(dfy_mesh.vertices)} vertices")

        # Build a combined scene: red = aligned human, blue = 3Dfy object.
        scene = trimesh.Scene()

        aligned_copy = aligned_mesh.copy()
        aligned_copy.visual.vertex_colors = [255, 0, 0, 200]
        scene.add_geometry(aligned_copy, node_name="sam3d_aligned_human")

        dfy_copy = dfy_mesh.copy()
        dfy_copy.visual.vertex_colors = [0, 0, 255, 200]
        scene.add_geometry(dfy_copy, node_name="dfy_object")

        if output_dir is None:
            output_dir = os.path.dirname(aligned_mesh_path)
        os.makedirs(output_dir, exist_ok=True)

        combined_glb_path = os.path.join(output_dir, 'combined_scene.glb')
        scene.export(combined_glb_path)
        print(f"Exported combined scene to: {combined_glb_path}")

        with gr.Blocks() as demo:
            gr.Markdown("# 3D Mesh Alignment Visualization")
            gr.Markdown("**Red**: SAM 3D Body Aligned Human | **Blue**: 3Dfy Object")
            gr.Model3D(
                value=combined_glb_path,
                label="Combined 3D Scene (Interactive)",
                height=height
            )

        print("Launching interactive 3D viewer...")
        demo.launch(share=share)

        return demo, combined_glb_path

    except Exception as e:
        print(f"ERROR in visualization: {e}")
        import traceback
        traceback.print_exc()
        return None, None


def visualize_meshes_comparison(aligned_mesh_path, dfy_mesh_path, use_interactive=False):
    """
    Simple visualization of both meshes in a single 3D scatter plot.

    DEPRECATED: Use visualize_meshes_interactive() for better interactive visualization.

    Args:
        aligned_mesh_path: Path to aligned mesh PLY file
        dfy_mesh_path: Path to 3Dfy GLB file
        use_interactive: Whether to attempt the trimesh scene viewer (default: False)

    Returns:
        tuple: (aligned_mesh, dfy_mesh) trimesh objects, or (None, None) on failure
    """
    import matplotlib.pyplot as plt

    print("Loading meshes for visualization...")

    try:
        aligned_mesh = trimesh.load(aligned_mesh_path)
        print(f"Loaded aligned mesh: {len(aligned_mesh.vertices)} vertices")

        # GLB files load as a Scene; collapse it to a single Trimesh.
        dfy_scene = trimesh.load(dfy_mesh_path)
        if isinstance(dfy_scene, trimesh.Scene):
            dfy_meshes = [geom for geom in dfy_scene.geometry.values() if hasattr(geom, 'vertices')]
            if len(dfy_meshes) == 1:
                dfy_mesh = dfy_meshes[0]
            elif len(dfy_meshes) > 1:
                dfy_mesh = trimesh.util.concatenate(dfy_meshes)
            else:
                raise ValueError("No valid meshes in GLB file")
        else:
            dfy_mesh = dfy_scene

        print(f"Loaded 3Dfy mesh: {len(dfy_mesh.vertices)} vertices")

        # Scatter both vertex sets in one 3D plot for a quick visual check.
        fig = plt.figure(figsize=(12, 10))
        ax = fig.add_subplot(111, projection='3d')

        ax.scatter(dfy_mesh.vertices[:, 0],
                   dfy_mesh.vertices[:, 1],
                   dfy_mesh.vertices[:, 2],
                   c='blue', s=0.1, alpha=0.6, label='3Dfy Original')

        ax.scatter(aligned_mesh.vertices[:, 0],
                   aligned_mesh.vertices[:, 1],
                   aligned_mesh.vertices[:, 2],
                   c='red', s=0.1, alpha=0.6, label='SAM 3D Body Aligned')

        ax.set_title('Mesh Comparison: 3Dfy vs SAM 3D Body Aligned', fontsize=16, fontweight='bold')
        ax.set_xlabel('X')
        ax.set_ylabel('Y')
        ax.set_zlabel('Z')
        ax.legend()

        plt.tight_layout()
        plt.show()

        # Optionally open the interactive trimesh viewer with colored copies.
        if use_interactive:
            try:
                print("Creating trimesh scene...")
                scene = trimesh.Scene()

                aligned_copy = aligned_mesh.copy()
                aligned_copy.visual.vertex_colors = [255, 0, 0, 200]
                scene.add_geometry(aligned_copy, node_name="sam3d_aligned")

                dfy_copy = dfy_mesh.copy()
                dfy_copy.visual.vertex_colors = [0, 0, 255, 200]
                scene.add_geometry(dfy_copy, node_name="dfy_original")

                print("Opening interactive trimesh viewer...")
                scene.show()

            except Exception as e:
                print(f"Trimesh viewer not available: {e}")

        print("Visualization complete")
        return aligned_mesh, dfy_mesh

    except Exception as e:
        print(f"ERROR in visualization: {e}")
        import traceback
        traceback.print_exc()
        return None, None
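

if __name__ == '__main__':
    # Minimal end-to-end sketch. All paths below are placeholders and must be
    # swapped for real inputs before running.
    success, out_mesh, info = process_and_save_alignment(
        mesh_path='inputs/human_3db.ply',
        mask_path='inputs/human_mask.png',
        image_path='inputs/scene.jpg',
        output_dir='outputs',
    )
    if success:
        visualize_meshes_interactive(out_mesh, 'inputs/object.glb', share=False)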