|
from src.utils.typing_utils import * |
|
|
|
import os |
|
import numpy as np |
|
import trimesh |
|
import torch |
|
|
|
def normalize_mesh( |
|
mesh: Union[trimesh.Trimesh, trimesh.Scene], |
|
scale: float = 2.0, |
|
): |
|
|
|
|
|
bbox = mesh.bounding_box |
|
translation = -bbox.centroid |
|
scale = scale / bbox.primitive.extents.max() |
|
mesh.apply_translation(translation) |
|
mesh.apply_scale(scale) |
|
return mesh |
|
|
|
def remove_overlapping_vertices(mesh: trimesh.Trimesh, reserve_material: bool = False): |
|
if not isinstance(mesh, trimesh.Trimesh): |
|
raise ValueError("Input mesh is not a trimesh.Trimesh object.") |
|
vertices = mesh.vertices |
|
faces = mesh.faces |
|
unique_vertices, index_map, inverse_map = np.unique( |
|
vertices, axis=0, return_index=True, return_inverse=True |
|
) |
|
clean_faces = inverse_map[faces] |
|
clean_mesh = trimesh.Trimesh(vertices=unique_vertices, faces=clean_faces, process=True) |
|
if reserve_material: |
|
uv = mesh.visual.uv |
|
material = mesh.visual.material |
|
clean_uv = uv[index_map] |
|
clean_visual = trimesh.visual.TextureVisuals(uv=clean_uv, material=material) |
|
clean_mesh.visual = clean_visual |
|
return clean_mesh |
|
|
|
RGB = [ |
|
(82, 170, 220), |
|
(215, 91, 78), |
|
(45, 136, 117), |
|
(247, 172, 83), |
|
(124, 121, 121), |
|
(127, 171, 209), |
|
(243, 152, 101), |
|
(145, 204, 192), |
|
(150, 59, 121), |
|
(181, 206, 78), |
|
(189, 119, 149), |
|
(199, 193, 222), |
|
(200, 151, 54), |
|
(236, 110, 102), |
|
(238, 182, 212), |
|
] |
|
|
|
|
|
def get_colored_mesh_composition( |
|
meshes: Union[List[trimesh.Trimesh], trimesh.Scene], |
|
is_random: bool = True, |
|
is_sorted: bool = False, |
|
RGB: List[Tuple] = RGB |
|
): |
|
if isinstance(meshes, trimesh.Scene): |
|
meshes = meshes.dump() |
|
if is_sorted: |
|
volumes = [] |
|
for mesh in meshes: |
|
try: |
|
volume = mesh.volume |
|
except: |
|
volume = 0.0 |
|
volumes.append(volume) |
|
|
|
meshes = [x for _, x in sorted(zip(volumes, meshes), key=lambda pair: pair[0], reverse=True)] |
|
colored_scene = trimesh.Scene() |
|
for idx, mesh in enumerate(meshes): |
|
if is_random: |
|
color = (np.random.rand(3) * 256).astype(int) |
|
else: |
|
color = np.array(RGB[idx % len(RGB)]) |
|
mesh.visual = trimesh.visual.ColorVisuals( |
|
mesh=mesh, |
|
vertex_colors=color, |
|
) |
|
colored_scene.add_geometry(mesh) |
|
return colored_scene |
|
|
|
def mesh_to_surface( |
|
mesh: trimesh.Trimesh, |
|
num_pc: int = 204800, |
|
clip_to_num_vertices: bool = False, |
|
return_dict: bool = False, |
|
): |
|
|
|
|
|
if clip_to_num_vertices: |
|
num_pc = min(num_pc, mesh.vertices.shape[0]) |
|
points, face_indices = mesh.sample(num_pc, return_index=True) |
|
normals = mesh.face_normals[face_indices] |
|
if return_dict: |
|
return { |
|
"surface_points": points, |
|
"surface_normals": normals, |
|
} |
|
return points, normals |
|
|
|
def scene_to_parts( |
|
mesh: trimesh.Scene, |
|
normalize: bool = True, |
|
scale: float = 2.0, |
|
num_part_pc: int = 204800, |
|
clip_to_num_part_vertices: bool = False, |
|
return_type: Literal["mesh", "point"] = "mesh", |
|
) -> Union[List[trimesh.Geometry], List[Dict[str, np.ndarray]]]: |
|
if not isinstance(mesh, trimesh.Scene): |
|
raise ValueError("mesh must be a trimesh.Scene object") |
|
if normalize: |
|
mesh = normalize_mesh(mesh, scale=scale) |
|
parts: List[trimesh.Geometry] = mesh.dump() |
|
if return_type == "point": |
|
datas: List[Dict[str, np.ndarray]] = [] |
|
for geom in parts: |
|
data = mesh_to_surface( |
|
geom, |
|
num_pc=num_part_pc, |
|
clip_to_num_vertices=clip_to_num_part_vertices, |
|
return_dict=True, |
|
) |
|
datas.append(data) |
|
return datas |
|
elif return_type == "mesh": |
|
return parts |
|
else: |
|
raise ValueError("return_type must be 'mesh' or 'point'") |
|
|
|
def get_center(mesh: trimesh.Trimesh, method: Literal['mass', 'bbox']): |
|
if method == 'mass': |
|
return mesh.center_mass |
|
elif method =='bbox': |
|
return mesh.bounding_box.centroid |
|
else: |
|
raise ValueError('type must be mass or bbox') |
|
|
|
def get_direction(vector: np.ndarray): |
|
return vector / np.linalg.norm(vector) |
|
|
|
def move_mesh_by_center(mesh: trimesh.Trimesh, scale: float, method: Literal['mass', 'bbox'] = 'mass'): |
|
offset = scale - 1 |
|
center = get_center(mesh, method) |
|
direction = get_direction(center) |
|
translation = direction * offset |
|
mesh = mesh.copy() |
|
mesh.apply_translation(translation) |
|
return mesh |
|
|
|
def move_meshes_by_center(meshes: Union[List[trimesh.Trimesh], trimesh.Scene], scale: float): |
|
if isinstance(meshes, trimesh.Scene): |
|
meshes = meshes.dump() |
|
moved_meshes = [] |
|
for mesh in meshes: |
|
moved_mesh = move_mesh_by_center(mesh, scale) |
|
moved_meshes.append(moved_mesh) |
|
moved_meshes = trimesh.Scene(moved_meshes) |
|
return moved_meshes |
|
|
|
def get_series_splited_meshes(meshes: List[trimesh.Trimesh], scale: float, num_steps: int) -> List[trimesh.Scene]: |
|
series_meshes = [] |
|
for i in range(num_steps): |
|
temp_scale = 1 + (scale - 1) * i / (num_steps - 1) |
|
temp_meshes = move_meshes_by_center(meshes, temp_scale) |
|
series_meshes.append(temp_meshes) |
|
return series_meshes |
|
|
|
def load_surface(data, num_pc=204800): |
|
|
|
surface = data["surface_points"] |
|
normal = data["surface_normals"] |
|
|
|
rng = np.random.default_rng() |
|
ind = rng.choice(surface.shape[0], num_pc, replace=False) |
|
surface = torch.FloatTensor(surface[ind]) |
|
normal = torch.FloatTensor(normal[ind]) |
|
surface = torch.cat([surface, normal], dim=-1) |
|
|
|
return surface |
|
|
|
def load_surfaces(surfaces, num_pc=204800): |
|
surfaces = [load_surface(surface, num_pc) for surface in surfaces] |
|
surfaces = torch.stack(surfaces, dim=0) |
|
return surfaces |