import numpy as np
from pathlib import Path
import torch
import trimesh
from particulate.visualization_utils import (
get_3D_arrow_on_points,
create_arrow,
create_ring,
create_textured_mesh_parts,
ARROW_COLOR_REVOLUTE,
ARROW_COLOR_PRISMATIC
)
from particulate.articulation_utils import plucker_to_axis_point
from particulate.export_utils import export_animated_glb_file
from partfield_utils import obtain_partfield_feats, get_partfield_model
DATA_CONFIG = {
'sharp_point_ratio': 0.5,
'normalize_points': True
}
def sharp_sample_pointcloud(mesh, num_points: int = 8192):
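    """
    Sample points along sharp edges of the mesh, i.e. edges whose adjacent
    face normals differ by between 30 and 150 degrees. Edges are sampled
    proportionally to their length; each sample carries the average of the
    two face normals that made the edge sharp.
    Returns:
        samples: sampled points [num_points, 3]
        normals: averaged edge normals per sample [num_points, 3]
        index: index of the sharp edge each sample was drawn from [num_points]
        sharp_edge_faces: list of adjacent face indices per sharp edge
    """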
V = mesh.vertices
N = mesh.face_normals
F = mesh.faces
# Build edge-to-faces mapping
# Each edge is represented as (min_vertex_id, max_vertex_id) to ensure consistent ordering
edge_to_faces = {}
for face_idx in range(len(F)):
face = F[face_idx]
# Get the three edges of this face
edges = [
(face[0], face[1]),
(face[1], face[2]),
(face[2], face[0])
]
for edge in edges:
# Normalize edge ordering (min vertex first)
edge_key = tuple(sorted(edge))
if edge_key not in edge_to_faces:
edge_to_faces[edge_key] = []
edge_to_faces[edge_key].append(face_idx)
# Identify sharp edges based on face normal angles and store their averaged normals
sharp_edges = []
sharp_edge_normals = []
sharp_edge_faces = [] # Store adjacent faces for each sharp edge
cos_30 = np.cos(np.radians(30)) # ≈ 0.866
cos_150 = np.cos(np.radians(150)) # ≈ -0.866
for edge_key, face_indices in edge_to_faces.items():
# Check if edge has >= 2 faces
if len(face_indices) < 2:
continue
# Check all pairs of face normals
is_sharp = False
for i in range(len(face_indices)):
for j in range(i + 1, len(face_indices)):
n1 = N[face_indices[i]]
n2 = N[face_indices[j]]
dot_product = np.dot(n1, n2)
# Check if angle is between 30 and 150 degrees
if cos_150 < dot_product < cos_30 and np.linalg.norm(n1) > 1e-8 and np.linalg.norm(n2) > 1e-8:
is_sharp = True
sharp_edges.append(edge_key)
averaged_normal = (n1 + n2) / 2
sharp_edge_normals.append(averaged_normal)
sharp_edge_faces.append(face_indices) # Store all adjacent faces
break
if is_sharp:
break
# Convert sharp edges to vertex arrays
edge_a = np.array([edge[0] for edge in sharp_edges], dtype=np.int32)
edge_b = np.array([edge[1] for edge in sharp_edges], dtype=np.int32)
sharp_edge_normals = np.array(sharp_edge_normals, dtype=np.float64)
# Handle the case where there are no sharp edges
if len(sharp_edges) == 0:
# Return empty arrays with appropriate shape
samples = np.zeros((0, 3), dtype=np.float64)
normals = np.zeros((0, 3), dtype=np.float64)
edge_indices = np.zeros((0,), dtype=np.int32)
return samples, normals, edge_indices, sharp_edge_faces
sharp_verts_a = V[edge_a]
sharp_verts_b = V[edge_b]
weights = np.linalg.norm(sharp_verts_b - sharp_verts_a, axis=-1)
weights /= np.sum(weights)
random_number = np.random.rand(num_points)
w = np.random.rand(num_points, 1)
    # Clip guards against float round-off pushing searchsorted past the last edge
    index = np.minimum(np.searchsorted(weights.cumsum(), random_number), len(sharp_edges) - 1)
samples = w * sharp_verts_a[index] + (1 - w) * sharp_verts_b[index]
normals = sharp_edge_normals[index] # Use the averaged face normal for each edge
return samples, normals, index, sharp_edge_faces
def sample_points(mesh, num_points, sharp_point_ratio, at_least_one_point_per_face=False):
"""Sample points from mesh using sharp edge and uniform sampling."""
num_points_sharp_edges = int(num_points * sharp_point_ratio)
num_points_uniform = num_points - num_points_sharp_edges
points_sharp, normals_sharp, edge_indices, sharp_edge_faces = sharp_sample_pointcloud(mesh, num_points_sharp_edges)
# If no sharp edges were found, sample all points uniformly
if len(points_sharp) == 0 and sharp_point_ratio > 0:
print(f"Warning: No sharp edges found, sampling all points uniformly")
num_points_uniform = num_points
if at_least_one_point_per_face:
num_faces = len(mesh.faces)
if num_points_uniform < num_faces:
            raise ValueError(
                "Unable to sample at least one point per face: "
                f"the mesh has {num_faces} faces but only {num_points_uniform} "
                "uniform samples were requested"
            )
# Get a random permutation of face indices
face_perm = np.random.permutation(num_faces)
# Sample one point from each face
points_per_face = []
for face_idx in face_perm:
# Sample one random point on this face using barycentric coordinates
r1, r2 = np.random.random(), np.random.random()
sqrt_r1 = np.sqrt(r1)
# Barycentric coordinates
u = 1 - sqrt_r1
v = sqrt_r1 * (1 - r2)
w = sqrt_r1 * r2
# Get vertices of the face
face = mesh.faces[face_idx]
vertices = mesh.vertices[face]
# Compute point using barycentric coordinates
point = u * vertices[0] + v * vertices[1] + w * vertices[2]
points_per_face.append(point)
points_per_face = np.array(points_per_face)
normals_per_face = mesh.face_normals[face_perm]
# Sample remaining points uniformly
num_remaining_points = num_points_uniform - num_faces
if num_remaining_points > 0:
points_remaining, face_indices_remaining = mesh.sample(num_remaining_points, return_index=True)
normals_remaining = mesh.face_normals[face_indices_remaining]
points_uniform = np.concatenate([points_per_face, points_remaining], axis=0)
normals_uniform = np.concatenate([normals_per_face, normals_remaining], axis=0)
face_indices = np.concatenate([face_perm, face_indices_remaining], axis=0)
else:
points_uniform = points_per_face
normals_uniform = normals_per_face
face_indices = face_perm
else:
points_uniform, face_indices = mesh.sample(num_points_uniform, return_index=True)
normals_uniform = mesh.face_normals[face_indices]
points = np.concatenate([points_sharp, points_uniform], axis=0)
normals = np.concatenate([normals_sharp, normals_uniform], axis=0)
sharp_flag = np.concatenate([
np.ones(len(points_sharp), dtype=np.bool_),
np.zeros(len(points_uniform), dtype=np.bool_)
], axis=0)
# For each sharp point, randomly select one of the adjacent faces from the edge
sharp_face_indices = np.zeros(len(points_sharp), dtype=np.int32)
for i, edge_idx in enumerate(edge_indices):
adjacent_faces = sharp_edge_faces[edge_idx]
# Randomly select one of the adjacent faces
sharp_face_indices[i] = np.random.choice(adjacent_faces)
face_indices = np.concatenate([
sharp_face_indices,
face_indices
], axis=0)
return points, normals, sharp_flag, face_indices
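# Example usage (a minimal sketch; "chair.glb" is a placeholder path):
#   mesh = trimesh.load("chair.glb", force="mesh")
#   pts, nrm, sharp, fidx = sample_points(
#       mesh, 2048, DATA_CONFIG['sharp_point_ratio'], at_least_one_point_per_face=True)
#   # Up to half the points hug sharp edges; the rest cover faces uniformly.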
def prepare_inputs(mesh, num_points_global: int = 40000, num_points_decode: int = 2048, device: str = "cuda"):
"""Prepare inputs from a mesh file for model inference."""
sharp_point_ratio = DATA_CONFIG['sharp_point_ratio']
all_points, _, _, _ = sample_points(mesh, num_points_global, sharp_point_ratio)
points, normals, sharp_flag, face_indices = sample_points(mesh, num_points_decode, sharp_point_ratio, at_least_one_point_per_face=True)
if DATA_CONFIG['normalize_points']:
bbmin = np.concatenate([all_points, points], axis=0).min(0)
bbmax = np.concatenate([all_points, points], axis=0).max(0)
center = (bbmin + bbmax) * 0.5
scale = 1.0 / (bbmax - bbmin).max()
all_points = (all_points - center) * scale
points = (points - center) * scale
all_points = torch.from_numpy(all_points).to(device).float().unsqueeze(0)
points = torch.from_numpy(points).to(device).float().unsqueeze(0)
normals = torch.from_numpy(normals).to(device).float().unsqueeze(0)
partfield_model = get_partfield_model(device=device)
feats = obtain_partfield_feats(partfield_model, all_points, points)
return dict(xyz=points, normals=normals, feats=feats), sharp_flag, face_indices
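# Note: `feats` holds per-point PartField features; their dimensionality
# depends on the backbone returned by get_partfield_model.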
def refine_part_ids_strict(mesh, face_part_ids):
"""
Refine face part IDs by treating each connected component (CC) in the mesh independently.
For each CC, all faces are labeled with the part ID that has the largest surface area in that CC.
Args:
mesh: trimesh object
face_part_ids: part ID for each face [num_faces]
Returns:
refined_face_part_ids: refined part ID for each face [num_faces]
"""
face_part_ids = face_part_ids.copy() # Don't modify the input
# Use trimesh's built-in connected components functionality
# mesh.face_adjacency gives pairs of face indices that share an edge
mesh_components = trimesh.graph.connected_components(
edges=mesh.face_adjacency,
nodes=np.arange(len(mesh.faces)),
min_len=1
)
# For each connected component, find the part ID with the largest surface area
for component in mesh_components:
if len(component) == 0:
continue
# Collect part IDs in this component and their surface areas
part_id_areas = {}
for face_idx in component:
part_id = face_part_ids[face_idx]
if part_id == -1:
continue # Skip unassigned faces
face_area = mesh.area_faces[face_idx]
if part_id not in part_id_areas:
part_id_areas[part_id] = 0.0
part_id_areas[part_id] += face_area
# Find the part ID with the largest area
if len(part_id_areas) == 0:
# No valid part IDs in this component, skip
continue
dominant_part_id = max(part_id_areas.keys(), key=lambda pid: part_id_areas[pid])
# Assign all faces in this component to the dominant part ID
for face_idx in component:
face_part_ids[face_idx] = dominant_part_id
return face_part_ids
def compute_part_components_for_mesh_cc(mesh, mesh_cc_faces, current_face_part_ids, face_adjacency_dict):
"""
Compute part-specific connected components for faces in this mesh CC.
Returns a list of dicts with 'faces', 'part_id', and 'area'.
Two faces are in the same component if:
- They have the same part ID
- They are connected through faces of the same part ID
"""
components = []
# Get unique part IDs in this mesh CC
unique_part_ids = np.unique(current_face_part_ids[mesh_cc_faces])
for part_id in unique_part_ids:
if part_id == -1:
continue
# Get faces in this mesh CC with this part ID
mask = current_face_part_ids[mesh_cc_faces] == part_id
faces_with_part = mesh_cc_faces[mask]
if len(faces_with_part) == 0:
continue
# Convert to set for faster lookup
faces_with_part_set = set(faces_with_part)
# Build edges between these faces (both must have same part ID and be adjacent)
edges_for_part = []
for face_i in faces_with_part:
for face_j in face_adjacency_dict[face_i]:
if face_j in faces_with_part_set:
edges_for_part.append([face_i, face_j])
if len(edges_for_part) == 0:
# Each face is its own component
for face_i in faces_with_part:
components.append({
'faces': np.array([face_i]),
'part_id': part_id,
'area': mesh.area_faces[face_i]
})
else:
# Find connected components
edges_for_part = np.array(edges_for_part)
comps = trimesh.graph.connected_components(
edges=edges_for_part,
nodes=faces_with_part,
min_len=1
)
for comp in comps:
comp_faces = np.array(list(comp))
components.append({
'faces': comp_faces,
'part_id': part_id,
'area': np.sum(mesh.area_faces[comp_faces])
})
return components
def refine_part_ids_nonstrict(mesh, face_part_ids):
"""
Refine face part IDs to ensure each part ID forms a single connected component.
For each part ID, if there are multiple disconnected components, the smaller
components (by surface area) are reassigned based on adjacent faces' part IDs.
This is done iteratively until convergence or max iterations.
Args:
mesh: trimesh object
xyz: sampled points on the mesh [num_points, 3]
part_ids: part IDs for each sampled point [num_points]
face_indices: which face each point lies on (-1 means on edge) [num_points]
face_part_ids: initial part ID for each face [num_faces]
Returns:
refined_face_part_ids: refined part ID for each face [num_faces]
"""
face_part_ids_final = face_part_ids.copy() # Don't modify the input
# Step 1: Find connected components of the original mesh (immutable structure)
mesh_components = trimesh.graph.connected_components(
edges=mesh.face_adjacency,
nodes=np.arange(len(mesh.faces)),
min_len=1
)
mesh_components = [np.array(list(comp)) for comp in mesh_components]
# Step 2: Build face adjacency dict (immutable structure)
face_adjacency_dict = {i: set() for i in range(len(mesh.faces))}
for face_i, face_j in mesh.face_adjacency:
face_adjacency_dict[face_i].add(face_j)
face_adjacency_dict[face_j].add(face_i)
# Step 3: Process each mesh CC independently
for mesh_cc_faces in mesh_components:
done = False
while not done:
comps = compute_part_components_for_mesh_cc(mesh, mesh_cc_faces, face_part_ids_final, face_adjacency_dict)
comps.sort(key=lambda c: c['area'])
part_id_areas = {}
for comp in comps:
pid = comp['part_id']
if pid not in part_id_areas:
part_id_areas[pid] = 0.0
part_id_areas[pid] += comp['area']
done = True
for comp_idx in range(len(comps)):
current_part_id = comps[comp_idx]['part_id']
if len([c for c in comps if c['part_id'] == current_part_id]) > 1:
done = False
# Find adjacent components
adjacent_part_ids = set()
current_faces_set = set(comps[comp_idx]['faces'])
for face_i in current_faces_set:
for face_j in face_adjacency_dict[face_i]:
if face_j in current_faces_set:
continue
adjacent_part_ids.add(face_part_ids_final[face_j])
                    # Guard against unassigned (-1) neighbors, which have no area entry
                    adjacent_part_ids = {pid for pid in adjacent_part_ids if pid in part_id_areas}
                    chosen_part_id = max(adjacent_part_ids, key=lambda pid: part_id_areas[pid])
comps[comp_idx]['part_id'] = chosen_part_id
face_part_ids_final[comps[comp_idx]['faces']] = chosen_part_id
break
return face_part_ids_final
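# Illustration (hypothetical scenario): if part 3 forms two islands inside one
# mesh CC, the smaller island is relabeled with the adjacent part ID whose
# total area in that CC is largest, and the loop re-runs until every part ID
# is a single connected component.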
def find_part_ids_for_faces(mesh, part_ids, face_indices, strict=False):
"""
Assign part IDs to each face in the mesh.
Args:
mesh: trimesh object
xyz: sampled points on the mesh [num_points, 3]
part_ids: part IDs for each sampled point [num_points]
face_indices: which face each point lies on (-1 means on edge) [num_points]
Returns:
face_part_ids: part ID for each face [num_faces]
"""
num_faces = len(mesh.faces)
face_part_ids = np.full(num_faces, -1, dtype=np.int32)
# Step 1: Assign part IDs to faces that have points on them
# For each face, collect all points that lie on it and use majority vote
face_to_points = {}
for point_idx, face_idx in enumerate(face_indices):
if face_idx == -1: # Point is on an edge, ignore it
continue
if face_idx not in face_to_points:
face_to_points[face_idx] = []
face_to_points[face_idx].append(part_ids[point_idx])
# Assign part IDs based on majority vote from points
for face_idx, point_part_ids in face_to_points.items():
# Use bincount to find the majority part ID
counts = np.bincount(point_part_ids)
majority_part_id = np.argmax(counts)
face_part_ids[face_idx] = majority_part_id
if strict:
return refine_part_ids_strict(mesh, face_part_ids)
else:
return refine_part_ids_nonstrict(mesh, face_part_ids)
@torch.no_grad()
def infer_single_asset(
mesh,
up_dir,
model,
num_points,
min_part_confidence=0.0
):
mesh_transformed = mesh.copy()
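    # Rotate so the asset's up axis maps to +Z; the "z"/"Z" case is the identity.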
if up_dir in ["x", "X"]:
rotation_matrix = np.array([[0, 0, -1], [0, 1, 0], [1, 0, 0]], dtype=np.float32)
mesh_transformed.vertices = mesh_transformed.vertices @ rotation_matrix.T
elif up_dir in ["-x", "-X"]:
rotation_matrix = np.array([[0, 0, 1], [0, 1, 0], [-1, 0, 0]], dtype=np.float32)
mesh_transformed.vertices = mesh_transformed.vertices @ rotation_matrix.T
elif up_dir in ["y", "Y"]:
rotation_matrix = np.array([[1, 0, 0], [0, 0, -1], [0, 1, 0]], dtype=np.float32)
mesh_transformed.vertices = mesh_transformed.vertices @ rotation_matrix.T
elif up_dir in ["-y", "-Y"]:
rotation_matrix = np.array([[1, 0, 0], [0, 0, 1], [0, -1, 0]], dtype=np.float32)
mesh_transformed.vertices = mesh_transformed.vertices @ rotation_matrix.T
elif up_dir in ["z", "Z"]:
pass
elif up_dir in ["-z", "-Z"]:
rotation_matrix = np.array([[1, 0, 0], [0, -1, 0], [0, 0, -1]], dtype=np.float32)
mesh_transformed.vertices = mesh_transformed.vertices @ rotation_matrix.T
else:
raise ValueError(f"Invalid up direction: {up_dir}")
# Normalize mesh to [-0.5, 0.5]^3 bounding box
bbox_min = mesh_transformed.vertices.min(axis=0)
bbox_max = mesh_transformed.vertices.max(axis=0)
center = (bbox_min + bbox_max) / 2
mesh_transformed.vertices -= center # Center the mesh
# Scale to fit in [-0.5, 0.5]^3
scale = (bbox_max - bbox_min).max() # Use the largest dimension
mesh_transformed.vertices /= scale
inputs, sharp_flag, face_indices = prepare_inputs(mesh_transformed, num_points_global=40000, num_points_decode=num_points)
    # Gradients are already disabled by the @torch.no_grad() decorator above.
    outputs = model.infer(
        xyz=inputs['xyz'],
        feats=inputs['feats'],
        normals=inputs['normals'],
        output_all_hyps=True,
        min_part_confidence=min_part_confidence
    )
return outputs, face_indices, mesh_transformed
def save_articulated_meshes(mesh, face_indices, outputs, output_path, strict, animation_frames: int = 50, hyp_idx: int = 0):
part_ids = outputs[hyp_idx]['part_ids']
motion_hierarchy = outputs[hyp_idx]['motion_hierarchy']
is_part_revolute = outputs[hyp_idx]['is_part_revolute']
is_part_prismatic = outputs[hyp_idx]['is_part_prismatic']
revolute_plucker = outputs[hyp_idx]['revolute_plucker']
revolute_range = outputs[hyp_idx]['revolute_range']
prismatic_axis = outputs[hyp_idx]['prismatic_axis']
prismatic_range = outputs[hyp_idx]['prismatic_range']
face_part_ids = find_part_ids_for_faces(
mesh,
part_ids,
face_indices,
strict=strict
)
unique_part_ids = np.unique(face_part_ids)
num_parts = len(unique_part_ids)
print(f"Found {num_parts} unique parts")
mesh_parts_original = [mesh.submesh([face_part_ids == part_id], append=True) for part_id in unique_part_ids]
mesh_parts_segmented = create_textured_mesh_parts([mp.copy() for mp in mesh_parts_original])
# Create axes
axes = []
for i, mesh_part in enumerate(mesh_parts_segmented):
part_id = unique_part_ids[i]
if is_part_revolute[part_id]:
axis, point = plucker_to_axis_point(revolute_plucker[part_id])
arrow_start, arrow_end = get_3D_arrow_on_points(axis, mesh_part.vertices, fixed_point=point, extension=0.2)
axes.append(create_arrow(arrow_start, arrow_end, color=ARROW_COLOR_REVOLUTE, radius=0.01, radius_tip=0.018))
# Add rings at arrow_start and arrow_end
arrow_dir = arrow_end - arrow_start
axes.append(create_ring(arrow_start, arrow_dir, major_radius=0.03, minor_radius=0.006, color=ARROW_COLOR_REVOLUTE))
axes.append(create_ring(arrow_end, arrow_dir, major_radius=0.03, minor_radius=0.006, color=ARROW_COLOR_REVOLUTE))
elif is_part_prismatic[part_id]:
axis = prismatic_axis[part_id]
arrow_start, arrow_end = get_3D_arrow_on_points(axis, mesh_part.vertices, extension=0.2)
axes.append(create_arrow(arrow_start, arrow_end, color=ARROW_COLOR_PRISMATIC, radius=0.01, radius_tip=0.018))
trimesh.Scene(mesh_parts_segmented + axes).export(Path(output_path) / "mesh_parts_with_axes.glb")
print("Exporting animated GLB files...")
try:
export_animated_glb_file(
mesh_parts_original,
unique_part_ids,
motion_hierarchy,
is_part_revolute,
is_part_prismatic,
revolute_plucker,
revolute_range,
prismatic_axis,
prismatic_range,
animation_frames,
str(Path(output_path) / "animated_textured.glb"),
include_axes=False,
axes_meshes=None
)
except Exception as e:
print(f"Error exporting animated.glb: {e}")
import traceback
traceback.print_exc()
return (
mesh_parts_original,
unique_part_ids,
motion_hierarchy,
is_part_revolute,
is_part_prismatic,
revolute_plucker,
revolute_range,
prismatic_axis,
prismatic_range
)
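# Example end-to-end usage (a sketch; model construction is project-specific,
# and `load_articulation_model` is a hypothetical placeholder, not a real API):
#
#   mesh = trimesh.load("asset.glb", force="mesh")
#   model = load_articulation_model(device="cuda")  # hypothetical
#   outputs, face_indices, mesh_t = infer_single_asset(
#       mesh, up_dir="y", model=model, num_points=2048)
#   save_articulated_meshes(mesh_t, face_indices, outputs,
#                           output_path="out", strict=False)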