Yiwen-ntu committed on
Commit 535964c
1 Parent(s): acc6365

Upload 2 files

Files changed (2)
  1. main.py +177 -0
  2. mesh_to_pc.py +58 -0
main.py ADDED
@@ -0,0 +1,177 @@
+ import os, argparse, importlib
+ import torch
+ import time
+ import trimesh
+ import numpy as np
+ from MeshAnything.models.meshanything import MeshAnything
+ import datetime
+ from accelerate import Accelerator
+ from accelerate.utils import set_seed
+ from accelerate.utils import DistributedDataParallelKwargs
+ from safetensors import safe_open
+ from mesh_to_pc import process_mesh_to_pc
+ from huggingface_hub import hf_hub_download
+
+ class Dataset(torch.utils.data.Dataset):
+     def __init__(self, input_type, input_list, mc=False):
+         super().__init__()
+         self.data = []
+         if input_type == 'pc_normal':
+             for input_path in input_list:
+                 # load npy
+                 cur_data = np.load(input_path)
+                 # sample 4096 points
+                 assert cur_data.shape[0] >= 4096, "input pc_normal should have at least 4096 points"
+                 idx = np.random.choice(cur_data.shape[0], 4096, replace=False)
+                 cur_data = cur_data[idx]
+                 self.data.append({'pc_normal': cur_data, 'uid': input_path.split('/')[-1].split('.')[0]})
+
+         elif input_type == 'mesh':
+             mesh_list = []
+             for input_path in input_list:
+                 # load ply/obj
+                 cur_data = trimesh.load(input_path)
+                 mesh_list.append(cur_data)
+             if mc:
+                 print("Running Marching Cubes before point-cloud sampling; this may take several minutes...")
+             pc_list, _ = process_mesh_to_pc(mesh_list, marching_cubes=mc)
+             for input_path, cur_data in zip(input_list, pc_list):
+                 self.data.append({'pc_normal': cur_data, 'uid': input_path.split('/')[-1].split('.')[0]})
+         print(f"dataset total data samples: {len(self.data)}")
+
+     def __len__(self):
+         return len(self.data)
+
+     def __getitem__(self, idx):
+         data_dict = {}
+         data_dict['pc_normal'] = self.data[idx]['pc_normal']
+         # normalize pc coordinates to [-1, 1]
+         pc_coor = data_dict['pc_normal'][:, :3]
+         normals = data_dict['pc_normal'][:, 3:]
+         bounds = np.array([pc_coor.min(axis=0), pc_coor.max(axis=0)])
+         pc_coor = pc_coor - (bounds[0] + bounds[1])[None, :] / 2
+         pc_coor = pc_coor / np.abs(pc_coor).max() * 0.9995
+         assert (np.linalg.norm(normals, axis=-1) > 0.99).all(), "normals should be unit vectors"
+         data_dict['pc_normal'] = np.concatenate([pc_coor, normals], axis=-1, dtype=np.float16)
+         data_dict['uid'] = self.data[idx]['uid']
+
+         return data_dict
+
+ def get_args():
+     parser = argparse.ArgumentParser("MeshAnything", add_help=False)
+
+     parser.add_argument('--llm', default="facebook/opt-350m", type=str)
+     parser.add_argument('--input_dir', default=None, type=str)
+     parser.add_argument('--input_path', default=None, type=str)
+
+     parser.add_argument('--out_dir', default="inference_out", type=str)
+     parser.add_argument('--pretrained_weights', default="MeshAnything_350m.pth", type=str)
+
+     parser.add_argument(
+         '--input_type',
+         choices=['mesh', 'pc_normal'],
+         default='pc_normal',
+         help="Type of the asset to process (default: pc_normal)"
+     )
+
+     parser.add_argument("--codebook_size", default=8192, type=int)
+     parser.add_argument("--codebook_dim", default=1024, type=int)
+
+     parser.add_argument("--n_max_triangles", default=800, type=int)
+
+     parser.add_argument("--batchsize_per_gpu", default=1, type=int)
+     parser.add_argument("--seed", default=0, type=int)
+
+     parser.add_argument("--mc", default=False, action="store_true")
+     parser.add_argument("--sampling", default=False, action="store_true")
+
+     args = parser.parse_args()
+     return args
+
+ def load_model(args):
+     model = MeshAnything(args)
+     print("model loaded!!!")
+
+     ckpt_path = hf_hub_download(
+         repo_id="Yiwen-ntu/MeshAnything",
+         filename="MeshAnything_350m.pth",
+     )
+     tensors = {}
+     with safe_open(ckpt_path, framework="pt", device=0) as f:
+         for k in f.keys():
+             tensors[k] = f.get_tensor(k)
+
+     model.load_state_dict(tensors, strict=True)
+     print("weights loaded!!!")
+     return model
+ if __name__ == "__main__":
+     args = get_args()
+
+     cur_time = datetime.datetime.now().strftime("%d_%H-%M-%S")
+     checkpoint_dir = os.path.join(args.out_dir, cur_time)
+     os.makedirs(checkpoint_dir, exist_ok=True)
+     kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
+     accelerator = Accelerator(
+         mixed_precision="fp16",
+         project_dir=checkpoint_dir,
+         kwargs_handlers=[kwargs]
+     )
+
+     model = load_model(args)
+     # create dataset
+     if args.input_dir is not None:
+         input_list = sorted(os.listdir(args.input_dir))
+         # only ply, obj or npy
+         if args.input_type == 'pc_normal':
+             input_list = [os.path.join(args.input_dir, x) for x in input_list if x.endswith('.npy')]
+         else:
+             input_list = [os.path.join(args.input_dir, x) for x in input_list if x.endswith('.ply') or x.endswith('.obj') or x.endswith('.npy')]
+         set_seed(args.seed)
+         dataset = Dataset(args.input_type, input_list, args.mc)
+     elif args.input_path is not None:
+         set_seed(args.seed)
+         dataset = Dataset(args.input_type, [args.input_path], args.mc)
+     else:
+         raise ValueError("input_dir or input_path must be provided.")
+
+     dataloader = torch.utils.data.DataLoader(
+         dataset,
+         batch_size=args.batchsize_per_gpu,
+         drop_last=False,
+         shuffle=False,
+     )
+
+     if accelerator.state.num_processes > 1:
+         model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
+     dataloader, model = accelerator.prepare(dataloader, model)
+     begin_time = time.time()
+     print("Generation Start!!!")
+     with accelerator.autocast():
+         for curr_iter, batch_data_label in enumerate(dataloader):
+             curr_time = time.time()
+             outputs = model(batch_data_label['pc_normal'], sampling=args.sampling)
+             batch_size = outputs.shape[0]
+             device = outputs.device
+
+             for batch_id in range(batch_size):
+                 recon_mesh = outputs[batch_id]
+                 recon_mesh = recon_mesh[~torch.isnan(recon_mesh[:, 0, 0])]  # nvalid_face x 3 x 3
+                 vertices = recon_mesh.reshape(-1, 3).cpu()
+                 vertices_index = np.arange(len(vertices))  # 0, 1, ..., 3 x face
+                 triangles = vertices_index.reshape(-1, 3)
+
+                 scene_mesh = trimesh.Trimesh(vertices=vertices,
+                                              faces=triangles)
+                 scene_mesh.merge_vertices()
+                 scene_mesh.update_faces(scene_mesh.unique_faces())
+                 scene_mesh.fix_normals()
+                 save_path = os.path.join(checkpoint_dir, f'{batch_data_label["uid"][batch_id]}_gen.obj')
+                 num_faces = len(scene_mesh.faces)
+                 orange_color = np.array([255, 165, 0, 255], dtype=np.uint8)  # RGBA orange
+                 face_colors = np.tile(orange_color, (num_faces, 1))
+
+                 scene_mesh.visual.face_colors = face_colors
+                 scene_mesh.export(save_path)
+                 print(f"{save_path} saved!!")
+     end_time = time.time()
+     print(f"Total time: {end_time - begin_time}")
mesh_to_pc.py ADDED
@@ -0,0 +1,58 @@
+ import mesh2sdf.core
+ import numpy as np
+ import skimage.measure
+ import trimesh
+
+ def normalize_vertices(vertices, scale=0.9):
+     bbmin, bbmax = vertices.min(0), vertices.max(0)
+     center = (bbmin + bbmax) * 0.5
+     scale = 2.0 * scale / (bbmax - bbmin).max()
+     vertices = (vertices - center) * scale
+     return vertices, center, scale
+
+ def export_to_watertight(normalized_mesh, octree_depth: int = 7):
14
+ """
15
+ Convert the non-watertight mesh to watertight.
16
+
17
+ Args:
18
+ input_path (str): normalized path
19
+ octree_depth (int):
20
+
21
+ Returns:
22
+ mesh(trimesh.Trimesh): watertight mesh
23
+
24
+ """
25
+ size = 2 ** octree_depth
26
+ level = 2 / size
27
+
28
+ scaled_vertices, to_orig_center, to_orig_scale = normalize_vertices(normalized_mesh.vertices)
29
+
30
+ sdf = mesh2sdf.core.compute(scaled_vertices, normalized_mesh.faces, size=size)
31
+
32
+ vertices, faces, normals, _ = skimage.measure.marching_cubes(np.abs(sdf), level)
33
+
34
+ # watertight mesh
35
+ vertices = vertices / size * 2 - 1 # -1 to 1
36
+ vertices = vertices / to_orig_scale + to_orig_center
37
+ # vertices = vertices / to_orig_scale + to_orig_center
38
+ mesh = trimesh.Trimesh(vertices, faces, normals=normals)
39
+
40
+ return mesh
41
+
+ def process_mesh_to_pc(mesh_list, marching_cubes=False, sample_num=4096):
+     # mesh_list: list of trimesh.Trimesh
+     pc_normal_list = []
+     return_mesh_list = []
+     for mesh in mesh_list:
+         if marching_cubes:
+             mesh = export_to_watertight(mesh)
+             print("Marching Cubes done!")
+         return_mesh_list.append(mesh)
+         points, face_idx = mesh.sample(sample_num, return_index=True)
+         normals = mesh.face_normals[face_idx]
+
+         pc_normal = np.concatenate([points, normals], axis=-1, dtype=np.float16)
+         pc_normal_list.append(pc_normal)
+         print("mesh processed successfully")
+     return pc_normal_list, return_mesh_list
+
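
Usage sketch — a minimal example of mesh_to_pc.py on its own, not part of the committed files; the input file name is hypothetical:

    import trimesh
    from mesh_to_pc import process_mesh_to_pc

    mesh = trimesh.load("chair.obj", force="mesh")  # hypothetical input mesh
    pc_list, mesh_list = process_mesh_to_pc([mesh], marching_cubes=True)
    print(pc_list[0].shape)  # (4096, 6): xyz coordinates + unit face normals, float16

With marching_cubes=True each mesh is first converted to a watertight surface via export_to_watertight, which is slower but tolerates imperfect input geometry.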