# Hugging Face Spaces page header (scraped): Running on Zero
from collections import Counter, defaultdict
from itertools import combinations
from pathlib import Path

import numpy as np
import torch
import trimesh

from particulate.visualization_utils import (
    get_3D_arrow_on_points,
    create_arrow,
    create_ring,
    create_textured_mesh_parts,
    ARROW_COLOR_REVOLUTE,
    ARROW_COLOR_PRISMATIC,
)
from particulate.articulation_utils import plucker_to_axis_point
from particulate.export_utils import export_animated_glb_file
from partfield_utils import obtain_partfield_feats, get_partfield_model

# Point-sampling options read by `sample_points` / `prepare_inputs`:
#   sharp_point_ratio - fraction of the requested points drawn along sharp edges
#   normalize_points  - re-centre the sampled points on their bounding box and
#                       scale them so the longest bounding-box side becomes 1
DATA_CONFIG = dict(
    sharp_point_ratio=0.5,
    normalize_points=True,
)


def sharp_sample_pointcloud(mesh, num_points: int = 8192):
    """Sample points along the sharp edges of a mesh, proportionally to edge length.

    An edge is "sharp" when it is shared by at least two faces and some pair of
    those faces has non-degenerate normals whose angle lies strictly between 30
    and 150 degrees.

    Args:
        mesh: mesh object exposing ``vertices``, ``faces`` and ``face_normals``
            (e.g. a ``trimesh.Trimesh``).
        num_points: number of points to sample.

    Returns:
        samples: ``(num_points, 3)`` positions (``(0, 3)`` if no sharp edge exists).
        normals: ``(num_points, 3)`` per-sample normal: the mean of the two face
            normals that made the edge sharp (not re-normalised).
        edge_indices: ``(num_points,)`` index of the sharp edge each sample lies
            on (indexes ``sharp_edge_faces``).
        sharp_edge_faces: list with, for every sharp edge, all faces adjacent to it.
    """
    V = mesh.vertices
    N = mesh.face_normals
    F = mesh.faces

    # Map every undirected edge, keyed as (min_vertex_id, max_vertex_id), to the
    # faces that use it. Insertion order fixes the sharp-edge numbering below.
    edge_to_faces = defaultdict(list)
    for face_idx, face in enumerate(F):
        for a, b in ((face[0], face[1]), (face[1], face[2]), (face[2], face[0])):
            edge_key = (a, b) if a <= b else (b, a)
            edge_to_faces[edge_key].append(face_idx)

    cos_30 = np.cos(np.radians(30))    # ~ 0.866
    cos_150 = np.cos(np.radians(150))  # ~ -0.866
    sharp_edges = []
    sharp_edge_normals = []
    sharp_edge_faces = []  # all adjacent faces of each sharp edge
    for edge_key, face_indices in edge_to_faces.items():
        if len(face_indices) < 2:
            continue  # boundary edge: no dihedral angle to measure
        # The first qualifying face pair decides sharpness and the edge normal.
        for fi, fj in combinations(face_indices, 2):
            n1 = N[fi]
            n2 = N[fj]
            dot_product = np.dot(n1, n2)
            if (cos_150 < dot_product < cos_30
                    and np.linalg.norm(n1) > 1e-8 and np.linalg.norm(n2) > 1e-8):
                sharp_edges.append(edge_key)
                sharp_edge_normals.append((n1 + n2) / 2)
                sharp_edge_faces.append(face_indices)
                break

    if not sharp_edges:
        samples = np.zeros((0, 3), dtype=np.float64)
        normals = np.zeros((0, 3), dtype=np.float64)
        edge_indices = np.zeros((0,), dtype=np.int32)
        return samples, normals, edge_indices, sharp_edge_faces

    edges = np.asarray(sharp_edges, dtype=np.int32)  # (num_sharp_edges, 2)
    sharp_verts_a = V[edges[:, 0]]
    sharp_verts_b = V[edges[:, 1]]
    sharp_edge_normals = np.asarray(sharp_edge_normals, dtype=np.float64)

    weights = np.linalg.norm(sharp_verts_b - sharp_verts_a, axis=-1)
    total_length = np.sum(weights)
    if total_length > 0:
        weights = weights / total_length
    else:
        # Every sharp edge has zero length: 0/0 would give NaN weights, so pick
        # edges uniformly instead.
        weights = np.full(len(weights), 1.0 / len(weights))

    random_number = np.random.rand(num_points)
    w = np.random.rand(num_points, 1)
    # Round-off can leave cumsum()[-1] slightly below 1.0; a draw above it would
    # make searchsorted return len(weights) (out of range), so clip to the last edge.
    index = np.minimum(np.searchsorted(weights.cumsum(), random_number), len(weights) - 1)
    samples = w * sharp_verts_a[index] + (1 - w) * sharp_verts_b[index]
    normals = sharp_edge_normals[index]
    return samples, normals, index, sharp_edge_faces


def sample_points(mesh, num_points, sharp_point_ratio, at_least_one_point_per_face=False):
    """Sample a point cloud from ``mesh`` mixing sharp-edge and surface samples.

    ``int(num_points * sharp_point_ratio)`` points are drawn along sharp edges
    (via ``sharp_sample_pointcloud``) and the rest via ``mesh.sample``. When no
    sharp-edge points come back (and ``sharp_point_ratio > 0``), every point is
    drawn from the surface instead.

    Args:
        mesh: trimesh-like mesh (``vertices``, ``faces``, ``face_normals`` and
            ``sample(count, return_index=True)`` are used).
        num_points: total number of points requested.
        sharp_point_ratio: fraction of ``num_points`` taken from sharp edges.
        at_least_one_point_per_face: if True, the surface samples begin with one
            random point on every face (faces in random order), followed by the
            remaining surface samples.

    Returns:
        points: ``(num_points, 3)`` positions, sharp-edge samples first.
        normals: ``(num_points, 3)``; face normals for surface samples, averaged
            edge normals for sharp samples.
        sharp_flag: ``(num_points,)`` bool, True for sharp-edge samples.
        face_indices: ``(num_points,)`` face each point is attributed to; a sharp
            sample gets a random face adjacent to its edge.

    Raises:
        ValueError: if ``at_least_one_point_per_face`` is set and there are fewer
            surface samples than faces.
    """
    num_points_sharp_edges = int(num_points * sharp_point_ratio)
    num_points_uniform = num_points - num_points_sharp_edges
    points_sharp, normals_sharp, edge_indices, sharp_edge_faces = sharp_sample_pointcloud(
        mesh, num_points_sharp_edges
    )

    # If no sharp edges were found, sample all points uniformly.
    if len(points_sharp) == 0 and sharp_point_ratio > 0:
        print("Warning: No sharp edges found, sampling all points uniformly")
        num_points_uniform = num_points

    if at_least_one_point_per_face:
        num_faces = len(mesh.faces)
        if num_points_uniform < num_faces:
            raise ValueError(
                "Unable to sample at least one point per face: "
                f"{num_faces} faces > {num_points_uniform} uniform sample points"
            )
        face_perm = np.random.permutation(num_faces)

        # One uniformly distributed point per face (sqrt-barycentric sampling).
        # Drawing an (F, 2) block consumes the global RNG in the same order as
        # per-face (r1, r2) scalar draws would: face 0's r1, r2, then face 1's, ...
        r = np.random.random((num_faces, 2))
        sqrt_r1 = np.sqrt(r[:, :1])
        r2 = r[:, 1:]
        bary_u = 1 - sqrt_r1
        bary_v = sqrt_r1 * (1 - r2)
        bary_w = sqrt_r1 * r2
        corners = mesh.vertices[mesh.faces[face_perm]]  # (num_faces, 3 corners, 3)
        points_per_face = bary_u * corners[:, 0] + bary_v * corners[:, 1] + bary_w * corners[:, 2]
        normals_per_face = mesh.face_normals[face_perm]

        # Spend the rest of the surface budget on regular area-weighted samples.
        num_remaining_points = num_points_uniform - num_faces
        if num_remaining_points > 0:
            points_remaining, face_indices_remaining = mesh.sample(num_remaining_points, return_index=True)
            normals_remaining = mesh.face_normals[face_indices_remaining]
            points_uniform = np.concatenate([points_per_face, points_remaining], axis=0)
            normals_uniform = np.concatenate([normals_per_face, normals_remaining], axis=0)
            face_indices = np.concatenate([face_perm, face_indices_remaining], axis=0)
        else:
            points_uniform = points_per_face
            normals_uniform = normals_per_face
            face_indices = face_perm
    else:
        points_uniform, face_indices = mesh.sample(num_points_uniform, return_index=True)
        normals_uniform = mesh.face_normals[face_indices]

    points = np.concatenate([points_sharp, points_uniform], axis=0)
    normals = np.concatenate([normals_sharp, normals_uniform], axis=0)
    sharp_flag = np.concatenate([
        np.ones(len(points_sharp), dtype=np.bool_),
        np.zeros(len(points_uniform), dtype=np.bool_),
    ], axis=0)

    # Attribute each sharp sample to one randomly chosen face adjacent to its edge.
    sharp_face_indices = np.zeros(len(points_sharp), dtype=np.int32)
    for i, edge_idx in enumerate(edge_indices):
        sharp_face_indices[i] = np.random.choice(sharp_edge_faces[edge_idx])
    face_indices = np.concatenate([sharp_face_indices, face_indices], axis=0)
    return points, normals, sharp_flag, face_indices


def prepare_inputs(mesh, num_points_global: int = 40000, num_points_decode: int = 2048, device: str = "cuda"):
    """Build the model input dict for an already-loaded mesh object.

    Two point clouds are sampled with ``DATA_CONFIG['sharp_point_ratio']``: a
    "global" cloud of ``num_points_global`` points, used only as context for the
    PartField features, and a "decode" cloud of ``num_points_decode`` points with
    at least one point per face. When ``DATA_CONFIG['normalize_points']`` is set,
    both clouds are centred on the bounding box of their union and scaled so its
    longest side is 1.

    Returns:
        ``(inputs, sharp_flag, face_indices)``. ``inputs`` has ``xyz`` and
        ``normals`` (float tensors on ``device`` with a leading batch axis of size
        1) plus ``feats`` (whatever ``obtain_partfield_feats`` returns).
        ``sharp_flag`` and ``face_indices`` are the numpy arrays for the decode
        points.
    """
    ratio = DATA_CONFIG['sharp_point_ratio']
    # Sampling order matters for reproducibility: global cloud first, then decode cloud.
    global_xyz = sample_points(mesh, num_points_global, ratio)[0]
    decode_xyz, decode_normals, sharp_flag, face_indices = sample_points(
        mesh, num_points_decode, ratio, at_least_one_point_per_face=True
    )

    if DATA_CONFIG['normalize_points']:
        union = np.concatenate([global_xyz, decode_xyz], axis=0)
        lo = union.min(0)
        hi = union.max(0)
        center = (lo + hi) * 0.5
        inv_extent = 1.0 / (hi - lo).max()
        global_xyz = (global_xyz - center) * inv_extent
        decode_xyz = (decode_xyz - center) * inv_extent

    def as_batched_tensor(array):
        return torch.from_numpy(array).to(device).float().unsqueeze(0)

    global_xyz = as_batched_tensor(global_xyz)
    decode_xyz = as_batched_tensor(decode_xyz)
    decode_normals = as_batched_tensor(decode_normals)

    feats = obtain_partfield_feats(get_partfield_model(device=device), global_xyz, decode_xyz)
    return dict(xyz=decode_xyz, normals=decode_normals, feats=feats), sharp_flag, face_indices


def refine_part_ids_strict(mesh, face_part_ids):
    """Give every connected component of the mesh a single part ID.

    Components are computed over ``mesh.face_adjacency`` (faces sharing an edge).
    Within each component, the part ID covering the largest total face area is
    written to every face of that component; faces labelled -1 take no part in
    the vote. A component whose faces are all -1 is left unchanged.

    Args:
        mesh: trimesh object.
        face_part_ids: part ID for each face [num_faces]; -1 means unassigned.

    Returns:
        refined_face_part_ids: refined copy of ``face_part_ids`` (input untouched).
    """
    refined = face_part_ids.copy()
    components = trimesh.graph.connected_components(
        edges=mesh.face_adjacency,
        nodes=np.arange(len(mesh.faces)),
        min_len=1,
    )

    for component in components:
        # Accumulate surface area per part ID; insertion order decides ties in max().
        area_by_part = {}
        for face in component:
            label = refined[face]
            if label != -1:
                area_by_part[label] = area_by_part.get(label, 0.0) + mesh.area_faces[face]

        if not area_by_part:
            continue  # nothing labelled in this component

        winner = max(area_by_part, key=area_by_part.get)
        for face in component:
            refined[face] = winner

    return refined


def compute_part_components_for_mesh_cc(mesh, mesh_cc_faces, current_face_part_ids, face_adjacency_dict):
    """Split one mesh connected component (CC) into per-part connected pieces.

    Faces labelled -1 are ignored. Two faces belong to the same piece when they
    share a part ID and are linked by a chain of edge-adjacent faces that all
    carry that part ID.

    Args:
        mesh: trimesh object (only ``area_faces`` is read).
        mesh_cc_faces: array of face indices forming one mesh CC (boolean-masked
            below, so it must be a numpy array).
        current_face_part_ids: part ID of every face in the whole mesh.
        face_adjacency_dict: face index -> set of edge-adjacent face indices.

    Returns:
        List of dicts with keys ``'faces'`` (array of face indices),
        ``'part_id'`` and ``'area'`` (summed face area).
    """
    pieces = []
    cc_labels = current_face_part_ids[mesh_cc_faces]

    for part_id in np.unique(cc_labels):
        if part_id == -1:
            continue
        members = mesh_cc_faces[cc_labels == part_id]
        if len(members) == 0:
            continue

        member_set = set(members)
        # Adjacency pairs restricted to faces carrying this part ID.
        links = [
            [face_i, face_j]
            for face_i in members
            for face_j in face_adjacency_dict[face_i]
            if face_j in member_set
        ]

        if not links:
            # No two faces of this part touch: each face is a piece on its own.
            pieces.extend(
                {'faces': np.array([face_i]), 'part_id': part_id, 'area': mesh.area_faces[face_i]}
                for face_i in members
            )
            continue

        groups = trimesh.graph.connected_components(
            edges=np.array(links),
            nodes=members,
            min_len=1,
        )
        for group in groups:
            group_faces = np.array(list(group))
            pieces.append({
                'faces': group_faces,
                'part_id': part_id,
                'area': np.sum(mesh.area_faces[group_faces]),
            })

    return pieces


def refine_part_ids_nonstrict(mesh, face_part_ids):
    """Make every part ID form one connected piece inside each mesh component.

    For each connected component of the mesh (over ``mesh.face_adjacency``), the
    smallest piece (by area) of any part ID that is split into several pieces is
    relabelled to the neighbouring part ID with the largest total area in that
    component. This repeats until no part ID is split. It also stops when every
    remaining split piece borders only unassigned (-1) faces, because such a
    piece has no label to adopt.

    Args:
        mesh: trimesh object.
        face_part_ids: initial part ID for each face [num_faces]; -1 = unassigned.

    Returns:
        refined_face_part_ids: refined copy of ``face_part_ids`` (input untouched).
    """
    face_part_ids_final = face_part_ids.copy()

    # Step 1: connected components of the original mesh (fixed during refinement).
    mesh_components = trimesh.graph.connected_components(
        edges=mesh.face_adjacency,
        nodes=np.arange(len(mesh.faces)),
        min_len=1,
    )
    mesh_components = [np.array(list(comp)) for comp in mesh_components]

    # Step 2: face -> set of edge-adjacent faces (fixed during refinement).
    face_adjacency_dict = {i: set() for i in range(len(mesh.faces))}
    for face_i, face_j in mesh.face_adjacency:
        face_adjacency_dict[face_i].add(face_j)
        face_adjacency_dict[face_j].add(face_i)

    # Step 3: merge pieces one at a time per mesh CC. Each merge strictly reduces
    # the number of pieces, so the loop terminates.
    for mesh_cc_faces in mesh_components:
        while _merge_smallest_split_piece(mesh, mesh_cc_faces, face_part_ids_final, face_adjacency_dict):
            pass

    return face_part_ids_final


def _merge_smallest_split_piece(mesh, mesh_cc_faces, face_part_ids, face_adjacency_dict):
    """Relabel (in place) the smallest piece of a split part in one mesh CC; return True if one was relabelled."""
    comps = compute_part_components_for_mesh_cc(mesh, mesh_cc_faces, face_part_ids, face_adjacency_dict)
    comps.sort(key=lambda c: c['area'])  # smallest pieces first

    part_id_areas = {}
    for comp in comps:
        part_id_areas[comp['part_id']] = part_id_areas.get(comp['part_id'], 0.0) + comp['area']
    # Count pieces per part once instead of rescanning the list for every piece.
    pieces_per_part = Counter(comp['part_id'] for comp in comps)

    for comp in comps:
        if pieces_per_part[comp['part_id']] < 2:
            continue  # this part is already a single piece
        comp_faces = set(comp['faces'])
        neighbour_ids = set()
        for face_i in comp_faces:
            for face_j in face_adjacency_dict[face_i]:
                if face_j in comp_faces:
                    continue
                neighbour_id = face_part_ids[face_j]
                # -1 faces are never counted in part_id_areas; looking one up
                # there would raise KeyError, so skip them.
                if neighbour_id != -1:
                    neighbour_ids.add(neighbour_id)
        if not neighbour_ids:
            continue  # only unassigned neighbours: nothing to merge into
        chosen_part_id = max(neighbour_ids, key=part_id_areas.__getitem__)
        face_part_ids[comp['faces']] = chosen_part_id
        return True
    return False


def find_part_ids_for_faces(mesh, part_ids, face_indices, strict=False):
    """Assign a part ID to every face of ``mesh`` from per-point part labels.

    Each face first takes the majority part ID among the sampled points
    attributed to it. Ties go to the smallest ID (``np.argmax`` over
    ``np.bincount``). Faces with no points start at -1. The labels are then
    refined with ``refine_part_ids_strict`` when ``strict`` is True, otherwise
    with ``refine_part_ids_nonstrict``.

    Args:
        mesh: trimesh object.
        part_ids: part ID for each sampled point [num_points]; ``np.bincount``
            requires these to be non-negative integers.
        face_indices: face each point is attributed to [num_points]; entries
            equal to -1 are ignored.
        strict: choose the refinement strategy (see above).

    Returns:
        face_part_ids: part ID for each face [num_faces].
    """
    face_part_ids = np.full(len(mesh.faces), -1, dtype=np.int32)

    # Collect the point labels that vote for each face.
    votes = {}
    for point_idx, face_idx in enumerate(face_indices):
        if face_idx != -1:
            votes.setdefault(face_idx, []).append(part_ids[point_idx])

    for face_idx, face_votes in votes.items():
        face_part_ids[face_idx] = np.argmax(np.bincount(face_votes))

    refine = refine_part_ids_strict if strict else refine_part_ids_nonstrict
    return refine(mesh, face_part_ids)


def infer_single_asset(
    mesh,
    up_dir,
    model,
    num_points,
    min_part_confidence=0.0
):
    """Re-orient a copy of ``mesh`` to Z-up, normalise it, and run ``model.infer``.

    ``up_dir`` names the axis that points up in the input mesh: one of ``x``,
    ``-x``, ``y``, ``-y``, ``z``, ``-z`` (upper or lower case). The copy is
    rotated so that axis maps to +Z, then centred on its bounding box and divided
    by its longest bounding-box side, so it fits in [-0.5, 0.5]^3.

    Returns:
        ``(outputs, face_indices, mesh_transformed)``: the result of
        ``model.infer`` (all hypotheses), the face index of every decoded point,
        and the transformed mesh copy.

    Raises:
        ValueError: if ``up_dir`` is not one of the directions above.
    """
    mesh_transformed = mesh.copy()

    # Row-major rotation matrices mapping the given up axis onto +Z
    # (None: already Z-up, leave the vertices untouched).
    to_z_up = {
        "x": [[0, 0, -1], [0, 1, 0], [1, 0, 0]],
        "-x": [[0, 0, 1], [0, 1, 0], [-1, 0, 0]],
        "y": [[1, 0, 0], [0, 0, -1], [0, 1, 0]],
        "-y": [[1, 0, 0], [0, 0, 1], [0, -1, 0]],
        "z": None,
        "-z": [[1, 0, 0], [0, -1, 0], [0, 0, -1]],
    }
    direction = up_dir.lower() if isinstance(up_dir, str) else None
    if direction not in to_z_up:
        raise ValueError(f"Invalid up direction: {up_dir}")

    rotation = to_z_up[direction]
    if rotation is not None:
        rotation_matrix = np.array(rotation, dtype=np.float32)
        mesh_transformed.vertices = mesh_transformed.vertices @ rotation_matrix.T

    # Centre on the bounding box, then scale the longest side to 1 -> [-0.5, 0.5]^3.
    lo = mesh_transformed.vertices.min(axis=0)
    hi = mesh_transformed.vertices.max(axis=0)
    mesh_transformed.vertices -= (lo + hi) / 2
    mesh_transformed.vertices /= (hi - lo).max()

    inputs, _sharp_flag, face_indices = prepare_inputs(
        mesh_transformed, num_points_global=40000, num_points_decode=num_points
    )
    with torch.no_grad():
        outputs = model.infer(
            xyz=inputs['xyz'],
            feats=inputs['feats'],
            normals=inputs['normals'],
            output_all_hyps=True,
            min_part_confidence=min_part_confidence,
        )
    return outputs, face_indices, mesh_transformed


def save_articulated_meshes(mesh, face_indices, outputs, output_path, strict, animation_frames: int = 50, hyp_idx: int = 0):
    """Segment ``mesh`` with one predicted hypothesis and export GLB files.

    Writes ``mesh_parts_with_axes.glb`` (segmented parts plus joint-axis arrows,
    with rings for revolute joints) and ``animated_textured.glb`` (animated
    parts) into ``output_path``. A failure while writing the animated file is
    printed with its traceback and does not abort the function.

    Args:
        mesh: trimesh object the predictions refer to.
        face_indices: face index of every decoded point.
        outputs: per-hypothesis prediction dicts from ``model.infer``.
        output_path: directory the GLB files are written into (presumably it
            must already exist; nothing here creates it).
        strict: forwarded to ``find_part_ids_for_faces``.
        animation_frames: number of frames in the animated export.
        hyp_idx: which hypothesis in ``outputs`` to use.

    Returns:
        ``(mesh_parts_original, unique_part_ids, motion_hierarchy,
        is_part_revolute, is_part_prismatic, revolute_plucker, revolute_range,
        prismatic_axis, prismatic_range)``.
    """
    hyp = outputs[hyp_idx]
    part_ids = hyp['part_ids']
    motion_hierarchy = hyp['motion_hierarchy']
    is_part_revolute = hyp['is_part_revolute']
    is_part_prismatic = hyp['is_part_prismatic']
    revolute_plucker = hyp['revolute_plucker']
    revolute_range = hyp['revolute_range']
    prismatic_axis = hyp['prismatic_axis']
    prismatic_range = hyp['prismatic_range']
    output_dir = Path(output_path)

    face_part_ids = find_part_ids_for_faces(mesh, part_ids, face_indices, strict=strict)
    unique_part_ids = np.unique(face_part_ids)
    print(f"Found {len(unique_part_ids)} unique parts")

    mesh_parts_original = [mesh.submesh([face_part_ids == part_id], append=True) for part_id in unique_part_ids]
    mesh_parts_segmented = create_textured_mesh_parts([mp.copy() for mp in mesh_parts_original])

    # Joint-axis visualisation: arrow (plus end rings) for revolute parts,
    # arrow only for prismatic parts.
    axes = []
    for part_id, mesh_part in zip(unique_part_ids, mesh_parts_segmented):
        if is_part_revolute[part_id]:
            axis, point = plucker_to_axis_point(revolute_plucker[part_id])
            arrow_start, arrow_end = get_3D_arrow_on_points(axis, mesh_part.vertices, fixed_point=point, extension=0.2)
            axes.append(create_arrow(arrow_start, arrow_end, color=ARROW_COLOR_REVOLUTE, radius=0.01, radius_tip=0.018))
            arrow_dir = arrow_end - arrow_start
            axes.append(create_ring(arrow_start, arrow_dir, major_radius=0.03, minor_radius=0.006, color=ARROW_COLOR_REVOLUTE))
            axes.append(create_ring(arrow_end, arrow_dir, major_radius=0.03, minor_radius=0.006, color=ARROW_COLOR_REVOLUTE))
        elif is_part_prismatic[part_id]:
            axis = prismatic_axis[part_id]
            arrow_start, arrow_end = get_3D_arrow_on_points(axis, mesh_part.vertices, extension=0.2)
            axes.append(create_arrow(arrow_start, arrow_end, color=ARROW_COLOR_PRISMATIC, radius=0.01, radius_tip=0.018))

    trimesh.Scene(mesh_parts_segmented + axes).export(output_dir / "mesh_parts_with_axes.glb")

    print("Exporting animated GLB files...")
    animated_path = output_dir / "animated_textured.glb"
    try:
        export_animated_glb_file(
            mesh_parts_original,
            unique_part_ids,
            motion_hierarchy,
            is_part_revolute,
            is_part_prismatic,
            revolute_plucker,
            revolute_range,
            prismatic_axis,
            prismatic_range,
            animation_frames,
            str(animated_path),
            include_axes=False,
            axes_meshes=None,
        )
    except Exception as e:
        # Best effort: the static GLB is already written, so report and continue.
        # The message names the file actually being written.
        print(f"Error exporting {animated_path.name}: {e}")
        import traceback
        traceback.print_exc()

    return (
        mesh_parts_original,
        unique_part_ids,
        motion_hierarchy,
        is_part_revolute,
        is_part_prismatic,
        revolute_plucker,
        revolute_range,
        prismatic_axis,
        prismatic_range,
    )