Spaces:
Running
on
L40S
Running
on
L40S
File size: 7,319 Bytes
2252f3d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 |
import torch
import numpy as np
import pymeshlab as ml
from pytorch3d.renderer import TexturesVertex
from pytorch3d.structures import Meshes
import torch
import torch.nn.functional as F
import trimesh
from pymeshlab import PercentageValue
import open3d as o3d
def tensor2variable(tensor, device):
    """Return a new leaf tensor on ``device`` that tracks gradients.

    Args:
        tensor: Data to copy. Either an existing ``torch.Tensor`` or anything
            ``torch.tensor`` accepts (nested lists, numpy arrays, ...).
            The original note on this helper mentions a shape of [1,23,3,3].
        device: Target device for the returned tensor.

    Returns:
        A copy of ``tensor`` placed on ``device`` with ``requires_grad=True``.
        The input is never modified.
    """
    if isinstance(tensor, torch.Tensor):
        # torch.tensor(existing_tensor) emits a copy-construct UserWarning;
        # clone().detach() is the equivalent that torch recommends instead.
        return tensor.clone().detach().to(device).requires_grad_(True)
    return torch.tensor(tensor, device=device, requires_grad=True)
def rot6d_to_rotmat(x):
    """Convert 6D rotation representation to 3x3 rotation matrix.

    Based on Zhou et al., "On the Continuity of Rotation Representations in
    Neural Networks", CVPR 2019.

    Args:
        x: (B,6) batch of 6-D rotation representations. Any tensor whose
            element count is a multiple of 6 is accepted; it is reshaped to
            (-1, 3, 2), so the two basis vectors are read from the even and
            odd positions of each 6-vector respectively.

    Returns:
        (B,3,3) batch of rotation matrices whose columns are the
        orthonormalised basis vectors b1, b2 and b3 = b1 x b2.
    """
    # reshape (rather than view) also handles non-contiguous inputs.
    x = x.reshape(-1, 3, 2)
    a1 = x[:, :, 0]
    a2 = x[:, :, 1]
    # Gram-Schmidt: normalise a1, then remove its component from a2.
    b1 = F.normalize(a1, dim=-1)
    b2 = F.normalize(a2 - torch.einsum("bi,bi->b", b1, a2).unsqueeze(-1) * b1, dim=-1)
    # dim must be explicit: without it torch.cross picks the first dimension
    # of size 3, which is the batch axis whenever B == 3.
    b3 = torch.cross(b1, b2, dim=-1)
    return torch.stack((b1, b2, b3), dim=-1)
def fix_vert_color_glb(mesh_path):
    """Rewrite a glTF/GLB file in place with an extra white, rough material.

    The file at ``mesh_path`` is loaded, the first primitive of the first
    mesh is pointed at material index 0, a new material is appended to the
    material list, and the result is saved back to the same path.

    Args:
        mesh_path: Path of the file to load and overwrite.
    """
    from pygltflib import GLTF2, Material, PbrMetallicRoughness

    gltf = GLTF2().load(mesh_path)

    # Non-metallic, fully rough, white base colour, no emission, two-sided.
    pbr = PbrMetallicRoughness(
        baseColorFactor=[1.0, 1.0, 1.0, 1.0],
        metallicFactor=0.,
        roughnessFactor=1.0,
    )
    material = Material(
        pbrMetallicRoughness=pbr,
        emissiveFactor=[0.0, 0.0, 0.0],
        doubleSided=True,
    )

    # NOTE(review): index 0 is hard-coded; it only refers to the material
    # appended below if the file had no materials beforehand - confirm.
    first_primitive = gltf.meshes[0].primitives[0]
    first_primitive.material = 0
    gltf.materials.append(material)

    gltf.save(mesh_path)
def srgb_to_linear(c_srgb):
    """Convert sRGB-encoded values to linear RGB, clipped to [0, 1].

    Values at or below 0.04045 use the linear segment (``c / 12.92``); larger
    values use the power segment (``((c + 0.055) / 1.055) ** 2.4``).

    Args:
        c_srgb: Array-like of sRGB values (numpy array, list, or scalar).

    Returns:
        A numpy array of the same shape holding the linear values, clipped
        to the range [0, 1].
    """
    c_srgb = np.asarray(c_srgb)
    # np.where evaluates both branches for every element. Clamping the base
    # keeps the power branch free of "invalid value" warnings for inputs
    # below -0.055; those elements take the linear branch anyway.
    gamma_branch = ((np.maximum(c_srgb, 0.0) + 0.055) / 1.055) ** 2.4
    c_linear = np.where(c_srgb <= 0.04045, c_srgb / 12.92, gamma_branch)
    return c_linear.clip(0, 1.)
def save_py3dmesh_with_trimesh_fast(meshes: Meshes, save_glb_path, apply_sRGB_to_LinearRGB=True):
    """Export a vertex-coloured pytorch3d mesh to disk through trimesh.

    Args:
        meshes: Mesh whose packed vertices, faces and per-vertex texture
            features are exported. The features must have 3 channels and,
            after the optional colour conversion, lie in [0, 1].
        save_glb_path: Output path string. When it ends with ".glb" the
            vertices are rotated 180 degrees about +Y and the saved file is
            post-processed by ``fix_vert_color_glb``.
        apply_sRGB_to_LinearRGB: When true, colours are converted with
            ``srgb_to_linear`` before export.

    Raises:
        AssertionError: If vertex and colour counts differ, colours do not
            have 3 channels, or colour values fall outside [0, 1].
    """
    # convert from pytorch3d meshes to trimesh mesh
    vertices = meshes.verts_packed().cpu().float().numpy()
    triangles = meshes.faces_packed().cpu().long().numpy()
    np_color = meshes.textures.verts_features_packed().cpu().float().numpy()

    is_glb = save_glb_path.endswith(".glb")
    if is_glb:
        # rotate 180 along +Y (negate x and z). Done out of place: for a CPU
        # float32 mesh the numpy array shares memory with the mesh's packed
        # vertex tensor, so an in-place write would flip the caller's mesh.
        vertices = vertices * np.array([-1.0, 1.0, -1.0], dtype=vertices.dtype)

    if apply_sRGB_to_LinearRGB:
        np_color = srgb_to_linear(np_color)

    assert vertices.shape[0] == np_color.shape[0]
    assert np_color.shape[1] == 3
    assert 0 <= np_color.min() and np_color.max() <= 1, f"min={np_color.min()}, max={np_color.max()}"

    mesh = trimesh.Trimesh(vertices=vertices, faces=triangles, vertex_colors=np_color)
    mesh.remove_unreferenced_vertices()

    # save mesh
    mesh.export(save_glb_path)
    if is_glb:
        fix_vert_color_glb(save_glb_path)
    print(f"saving to {save_glb_path}")
def load_mesh_with_trimesh(file_name, file_type=None):
    """Load a mesh file with trimesh and return it as torch tensors.

    A loaded ``trimesh.Scene`` is round-tripped through an in-memory OBJ
    export; if the result is still a scene, its geometries are concatenated
    into one mesh using vertices and faces only.

    Args:
        file_name: Path or file object handed to ``trimesh.load``.
        file_type: Optional format hint forwarded to ``trimesh.load``.

    Returns:
        Tuple ``(vertices, faces, colors)``. Each tensor is the transpose of
        the corresponding trimesh array. ``colors`` holds the first three
        colour channels divided by 255, or a constant 0.5 tensor shaped like
        ``vertices`` when the visual has no ``vertex_colors`` attribute.

    Raises:
        AssertionError: If a loaded scene is empty or the final object is
            not a ``trimesh.Trimesh``.
    """
    from io import BytesIO

    loaded = trimesh.load(file_name, file_type=file_type)
    if isinstance(loaded, trimesh.Scene):
        assert len(loaded.geometry) > 0
        # Round-trip through OBJ to avoid the offset issue.
        with BytesIO() as buffer:
            loaded.export(buffer, file_type="obj")
            buffer.seek(0)
            loaded = trimesh.load(buffer, file_type="obj")
        if isinstance(loaded, trimesh.Scene):
            # Texture information is lost here: only geometry is kept.
            parts = tuple(
                trimesh.Trimesh(vertices=geom.vertices, faces=geom.faces)
                for geom in loaded.geometry.values()
            )
            loaded = trimesh.util.concatenate(parts)
    assert isinstance(loaded, trimesh.Trimesh)

    vertices = torch.from_numpy(loaded.vertices).T
    faces = torch.from_numpy(loaded.faces).T

    visual = loaded.visual
    if visual is not None and hasattr(visual, 'vertex_colors'):
        # presumably 0-255 channel values; alpha is dropped - verify.
        colors = torch.from_numpy(visual.vertex_colors)[..., :3].T / 255.
    else:
        # No vertex colours available: fall back to mid gray.
        colors = torch.ones_like(vertices) * 0.5
    return vertices, faces, colors
def meshlab_mesh_to_py3dmesh(mesh: ml.Mesh) -> Meshes:
    """Build a single-mesh pytorch3d ``Meshes`` from a pymeshlab mesh.

    Vertices are converted to float tensors, faces to long tensors, and the
    first three channels of the vertex colour matrix become the per-vertex
    texture features.
    """
    vertex_colors = torch.from_numpy(mesh.vertex_color_matrix()[..., :3]).float()
    return Meshes(
        verts=[torch.from_numpy(mesh.vertex_matrix()).float()],
        faces=[torch.from_numpy(mesh.face_matrix()).long()],
        textures=TexturesVertex(verts_features=[vertex_colors]),
    )
def py3dmesh_to_meshlab_mesh(meshes: Meshes) -> ml.Mesh:
    """Build a pymeshlab mesh from the packed data of a pytorch3d ``Meshes``.

    Vertices, vertex normals and colours are passed as float64 arrays and
    faces as int32. The per-vertex texture features get one extra trailing
    channel filled with 1 before being used as the colour matrix.
    """
    def _to_float64(tensor):
        # Go through float32 first, as the data is cast with .float().
        return tensor.cpu().float().numpy().astype(np.float64)

    rgb = meshes.textures.verts_features_packed().cpu().float()
    rgba = F.pad(rgb, [0, 1], value=1)

    return ml.Mesh(
        vertex_matrix=_to_float64(meshes.verts_packed()),
        face_matrix=meshes.faces_packed().cpu().long().numpy().astype(np.int32),
        v_normals_matrix=_to_float64(meshes.verts_normals_packed()),
        v_color_matrix=rgba.numpy().astype(np.float64),
    )
def to_pyml_mesh(vertices, faces):
    """Create a pymeshlab mesh from vertex and face tensors.

    Args:
        vertices: Tensor of vertex positions; passed on as float64.
        faces: Tensor of face indices; passed on as int32.

    Returns:
        An ``ml.Mesh`` holding only geometry (no normals or colours).
    """
    vertex_array = vertices.cpu().float().numpy().astype(np.float64)
    face_array = faces.cpu().long().numpy().astype(np.int32)
    return ml.Mesh(vertex_matrix=vertex_array, face_matrix=face_array)
def to_py3d_mesh(vertices, faces, normals=None):
    """Wrap vertices and faces in a pytorch3d ``Meshes`` coloured by normals.

    Args:
        vertices: Vertex tensor of a single mesh.
        faces: Face index tensor of a single mesh.
        normals: Optional per-vertex normals. When omitted, the packed
            vertex normals computed by pytorch3d are used.

    Returns:
        A ``Meshes`` whose vertex texture is ``normals / 2 + 0.5``.
    """
    from pytorch3d.structures import Meshes
    from pytorch3d.renderer.mesh.textures import TexturesVertex

    py3d_mesh = Meshes(verts=[vertices], faces=[faces], textures=None)
    vertex_normals = py3d_mesh.verts_normals_packed() if normals is None else normals
    # Use the normals as vertex colours: normals / 2 + 0.5.
    normal_colors = vertex_normals / 2 + 0.5
    py3d_mesh.textures = TexturesVertex(verts_features=[normal_colors])
    return py3d_mesh
def from_py3d_mesh(mesh):
    """Unpack the first mesh of a pytorch3d ``Meshes`` object.

    Returns:
        Tuple ``(vertices, faces, features)``: the first entries of
        ``verts_list()`` and ``faces_list()``, plus the packed per-vertex
        texture features.
    """
    vertices = mesh.verts_list()[0]
    faces = mesh.faces_list()[0]
    features = mesh.textures.verts_features_packed()
    return vertices, faces, features
def simple_clean_mesh(pyml_mesh: ml.Mesh, apply_smooth=True, stepsmoothnum=1, apply_sub_divide=False, sub_divide_threshold=0.25):
    """Optionally smooth and subdivide a pymeshlab mesh.

    Args:
        pyml_mesh: Mesh to process.
        apply_smooth: Run Laplacian coordinate smoothing when true.
        stepsmoothnum: Number of smoothing steps passed to the filter.
        apply_sub_divide: When true, repair non-manifold vertices and edges,
            then run two iterations of Loop subdivision (noted as slow,
            about 5s, in the original code).
        sub_divide_threshold: Value wrapped in ``PercentageValue`` and passed
            as the subdivision threshold.

    Returns:
        The processed mesh converted by ``meshlab_mesh_to_py3dmesh``.
    """
    # Collect the filters to run, in order, as (name, keyword arguments).
    pipeline = []
    if apply_smooth:
        pipeline.append((
            "apply_coord_laplacian_smoothing",
            {"stepsmoothnum": stepsmoothnum, "cotangentweight": False},
        ))
    if apply_sub_divide:
        pipeline.append(("meshing_repair_non_manifold_vertices", {}))
        pipeline.append(("meshing_repair_non_manifold_edges", {"method": 'Remove Faces'}))
        pipeline.append((
            "meshing_surface_subdivision_loop",
            {"iterations": 2, "threshold": PercentageValue(sub_divide_threshold)},
        ))

    mesh_set = ml.MeshSet()
    mesh_set.add_mesh(pyml_mesh, "cube_mesh")
    for filter_name, filter_kwargs in pipeline:
        mesh_set.apply_filter(filter_name, **filter_kwargs)
    return meshlab_mesh_to_py3dmesh(mesh_set.current_mesh())
def post_process_mesh(mesh, cluster_to_keep=1000):
    """Post-process a mesh to filter out floaters and disconnected parts.

    Connected triangle clusters are ranked by triangle count. Triangles in
    clusters smaller than the ``cluster_to_keep``-th largest cluster are
    removed, and clusters with fewer than 50 triangles are always removed.

    Args:
        mesh: Mesh providing ``cluster_connected_triangles`` and the removal
            methods used below (presumably an Open3D triangle mesh - confirm
            against callers). It is deep-copied and not modified.
        cluster_to_keep: Rank of the smallest cluster to keep. Values larger
            than the number of clusters are clamped, so every cluster with at
            least 50 triangles is kept.

    Returns:
        The filtered copy of ``mesh``.
    """
    import copy

    print("post processing the mesh to have {} clusters".format(cluster_to_keep))
    mesh_0 = copy.deepcopy(mesh)
    with o3d.utility.VerbosityContextManager(o3d.utility.VerbosityLevel.Debug):
        triangle_clusters, cluster_n_triangles, _ = mesh_0.cluster_connected_triangles()

    triangle_clusters = np.asarray(triangle_clusters)
    cluster_n_triangles = np.asarray(cluster_n_triangles)

    if cluster_n_triangles.size > 0:
        # Clamp the rank: indexing with -cluster_to_keep raised IndexError
        # whenever the mesh had fewer clusters than requested.
        rank = min(cluster_to_keep, cluster_n_triangles.size)
        n_cluster = np.sort(cluster_n_triangles)[-rank]
        n_cluster = max(n_cluster, 50)  # filter meshes smaller than 50
        triangles_to_remove = cluster_n_triangles[triangle_clusters] < n_cluster
        mesh_0.remove_triangles_by_mask(triangles_to_remove)

    mesh_0.remove_unreferenced_vertices()
    mesh_0.remove_degenerate_triangles()
    print("num vertices raw {}".format(len(mesh.vertices)))
    print("num vertices post {}".format(len(mesh_0.vertices)))
    return mesh_0