dylanebert (HF staff) committed
Commit 9299d52
Parent: 93f5bda

Files changed (2):
  1. app.py +43 -7
  2. convert.py +197 -109
app.py CHANGED
```diff
@@ -11,14 +11,50 @@ def run(input_ply):
 
 
 def main():
-    demo = gr.Interface(
-        fn=run,
-        inputs=gr.Model3D(label="Input Splat"),
-        outputs=gr.Model3D(label="Output GLB"),
-        examples=
-    )
+    _TITLE = """LGM Mini"""
+
+    _DESCRIPTION = """Converts Gaussian Splat (.ply) to Mesh (.glb) using [LGM](https://github.com/3DTopia/LGM)."""
+
+    css = """
+    #duplicate-button {
+        margin: auto;
+        color: white;
+        background: #1565c0;
+        border-radius: 100vh;
+    }
+    """
+
+    demo = gr.Blocks(title=_TITLE, css=css)
+    with demo:
+        gr.DuplicateButton(
+            value="Duplicate Space for private use", elem_id="duplicate-button"
+        )
+
+        with gr.Row():
+            with gr.Column():
+                gr.Markdown("# " + _TITLE)
+                gr.Markdown(_DESCRIPTION)
+
+        with gr.Row(variant="panel"):
+            with gr.Column():
+                input_ply = gr.Model3D(label="Input Splat")
+                button_gen = gr.Button("Generate")
+
+            with gr.Column():
+                output_glb = gr.Model3D(label="Output GLB")
+
+        button_gen.click(run, inputs=[input_ply], outputs=[output_glb])
+
+        gr.Examples(
+            ["data_test/catstatue.ply"],
+            inputs=[input_ply],
+            outputs=[output_glb],
+            fn=lambda x: run(x),
+            cache_examples=True,
+            label="Examples",
+        )
 
-    demo.launch(server_name="0.0.0.0", server_port=7860)
+    demo.queue().launch(server_name="0.0.0.0", server_port=7860)
 
 
 if __name__ == "__main__":
```
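The hunk header shows the new Blocks UI binds a pre-existing `run(input_ply)` defined earlier in app.py; its body sits outside this diff and is unchanged by the commit. For orientation only, here is a hypothetical sketch of what such a callback has to do: all `gr.Model3D` needs back is a path to a `.glb`. The `convert` wrapper is an assumption, not something shown in the commit.

```python
import shutil
import tempfile


def run(input_ply):
    # Hypothetical sketch; the real body is not part of this diff.
    workdir = tempfile.mkdtemp()
    ply_path = shutil.copy(input_ply, workdir)  # keep the upload somewhere writable
    convert(ply_path)                           # assumed wrapper around convert.py's Converter
    return ply_path.replace(".ply", ".glb")     # gr.Model3D renders the resulting mesh
```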
convert.py CHANGED
```diff
@@ -1,4 +1,3 @@
-
 import os
 import tyro
 import tqdm
@@ -23,8 +22,9 @@ from kiui.cam import orbit_camera, get_perspective
 from kiui.nn import MLP, trunc_exp
 from kiui.gridencoder import GridEncoder
 
+
 def get_rays(pose, h, w, fovy, opengl=True):
-
+
     x, y = torch.meshgrid(
         torch.arange(w, device=pose.device),
         torch.arange(h, device=pose.device),
@@ -50,12 +50,13 @@ def get_rays(pose, h, w, fovy, opengl=True):
     )  # [hw, 3]
 
     rays_d = camera_dirs @ pose[:3, :3].transpose(0, 1)  # [hw, 3]
-    rays_o = pose[:3, 3].unsqueeze(0).expand_as(rays_d) # [hw, 3]
+    rays_o = pose[:3, 3].unsqueeze(0).expand_as(rays_d)  # [hw, 3]
 
     rays_d = safe_normalize(rays_d)
 
     return rays_o, rays_d
 
+
 # Triple renderer of gaussians, gaussian, and diso mesh.
 # gaussian --> nerf --> mesh
 class Converter(nn.Module):
```
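For orientation, `get_rays` consumes a camera-to-world pose and returns flattened per-pixel ray origins and directions. A quick shape check, assuming kiui's `orbit_camera` (already imported at the top of the file), which returns a 4x4 OpenGL camera-to-world matrix; the fovy unit (degrees) is an assumption consistent with how kiui uses it:

```python
import torch
from kiui.cam import orbit_camera

# orbit_camera(elevation, azimuth, radius) -> [4, 4] OpenGL camera-to-world pose.
pose = torch.from_numpy(orbit_camera(-30, 45, 2.0))
rays_o, rays_d = get_rays(pose, 128, 128, fovy=49.1)  # fovy in degrees (assumed)
assert rays_o.shape == (128 * 128, 3)
assert rays_d.shape == (128 * 128, 3)  # unit directions after safe_normalize
```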
```diff
@@ -71,7 +72,7 @@ class Converter(nn.Module):
         self.proj_matrix[0, 0] = 1 / self.tan_half_fov
         self.proj_matrix[1, 1] = 1 / self.tan_half_fov
         self.proj_matrix[2, 2] = (opt.zfar + opt.znear) / (opt.zfar - opt.znear)
-        self.proj_matrix[3, 2] = - (opt.zfar * opt.znear) / (opt.zfar - opt.znear)
+        self.proj_matrix[3, 2] = -(opt.zfar * opt.znear) / (opt.zfar - opt.znear)
         self.proj_matrix[2, 3] = 1
 
         self.gs_renderer = GaussianRenderer(opt)
```
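Note the layout of this matrix: it is built for row-vector math, since poses are right-multiplied by it (`cam_view @ self.proj_matrix` in `render_gs` below). It is therefore the transpose of the familiar column-major OpenGL projection, with `P[2, 3] = 1` playing the role of the usual `P[3, 2]` entry. A small standalone check of that convention (the znear/zfar/fovy values are illustrative, not taken from the config):

```python
import numpy as np

znear, zfar, fovy = 0.5, 2.5, np.deg2rad(49.1)  # illustrative values
tan_half_fov = np.tan(fovy / 2)

P = np.zeros((4, 4), dtype=np.float32)
P[0, 0] = P[1, 1] = 1 / tan_half_fov
P[2, 2] = (zfar + znear) / (zfar - znear)
P[3, 2] = -(zfar * znear) / (zfar - znear)
P[2, 3] = 1

p = np.array([0.3, -0.2, 1.7, 1.0], dtype=np.float32)  # camera-space row vector
clip = p @ P
assert np.isclose(clip[3], p[2])  # w' equals depth: perspective divide by z
```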
```diff
@@ -83,39 +84,49 @@ class Converter(nn.Module):
             self.glctx = dr.RasterizeGLContext()
         else:
            self.glctx = dr.RasterizeCudaContext()
-
+
         self.step = 0
         self.render_step_size = 5e-3
         self.aabb = torch.tensor([-1.0, -1.0, -1.0, 1.0, 1.0, 1.0], device=self.device)
-        self.estimator = nerfacc.OccGridEstimator(roi_aabb=self.aabb, resolution=64, levels=1)
+        self.estimator = nerfacc.OccGridEstimator(
+            roi_aabb=self.aabb, resolution=64, levels=1
+        )
 
-        self.encoder_density = GridEncoder(num_levels=12) # VMEncoder(output_dim=16, mode='sum')
+        self.encoder_density = GridEncoder(
+            num_levels=12
+        )  # VMEncoder(output_dim=16, mode='sum')
         self.encoder = GridEncoder(num_levels=12)
         self.mlp_density = MLP(self.encoder_density.output_dim, 1, 32, 2, bias=False)
         self.mlp = MLP(self.encoder.output_dim, 3, 32, 2, bias=False)
 
         # mesh renderer
-        self.proj = torch.from_numpy(get_perspective(self.opt.fovy)).float().to(self.device)
+        self.proj = (
+            torch.from_numpy(get_perspective(self.opt.fovy)).float().to(self.device)
+        )
         self.v = self.f = None
         self.vt = self.ft = None
         self.deform = None
         self.albedo = None
-
-
+
     @torch.no_grad()
     def render_gs(self, pose):
-
+
         cam_poses = torch.from_numpy(pose).unsqueeze(0).to(self.device)
-        cam_poses[:, :3, 1:3] *= -1 # invert up & forward direction
-
+        cam_poses[:, :3, 1:3] *= -1  # invert up & forward direction
+
         # cameras needed by gaussian rasterizer
-        cam_view = torch.inverse(cam_poses).transpose(1, 2) # [V, 4, 4]
-        cam_view_proj = cam_view @ self.proj_matrix # [V, 4, 4]
-        cam_pos = - cam_poses[:, :3, 3] # [V, 3]
-
-        out = self.gs_renderer.render(self.gaussians.unsqueeze(0), cam_view.unsqueeze(0), cam_view_proj.unsqueeze(0), cam_pos.unsqueeze(0))
-        image = out['image'].squeeze(1).squeeze(0) # [C, H, W]
-        alpha = out['alpha'].squeeze(2).squeeze(1).squeeze(0) # [H, W]
+        cam_view = torch.inverse(cam_poses).transpose(1, 2)  # [V, 4, 4]
+        cam_view_proj = cam_view @ self.proj_matrix  # [V, 4, 4]
+        cam_pos = -cam_poses[:, :3, 3]  # [V, 3]
+
+        out = self.gs_renderer.render(
+            self.gaussians.unsqueeze(0),
+            cam_view.unsqueeze(0),
+            cam_view_proj.unsqueeze(0),
+            cam_pos.unsqueeze(0),
+        )
+        image = out["image"].squeeze(1).squeeze(0)  # [C, H, W]
+        alpha = out["alpha"].squeeze(2).squeeze(1).squeeze(0)  # [H, W]
 
         return image, alpha
 
```
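The two `GridEncoder`/`MLP` pairs set up here are the whole NeRF: one branch for density, one for color. `get_density` (its tail is visible at the top of the next hunk) is just encoder, then a 2-layer MLP, then `trunc_exp`, which keeps sigma positive while clamping the gradient of the exponential. A standalone sketch of that branch, reusing the constructor arguments above (the grid encoder is a CUDA kernel, hence `.cuda()`):

```python
import torch
from kiui.nn import MLP, trunc_exp
from kiui.gridencoder import GridEncoder

encoder_density = GridEncoder(num_levels=12).cuda()
mlp_density = MLP(encoder_density.output_dim, 1, 32, 2, bias=False).cuda()

xs = torch.rand(4096, 3, device="cuda") * 2 - 1       # points in the [-1, 1]^3 aabb
sigmas = trunc_exp(mlp_density(encoder_density(xs)))  # [4096, 1], strictly positive
```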
```diff
@@ -127,22 +138,25 @@ class Converter(nn.Module):
         density = trunc_exp(self.mlp_density(feats))
         density = density.view(*prefix, 1)
         return density
-
+
     def render_nerf(self, pose):
-
+
         pose = torch.from_numpy(pose.astype(np.float32)).to(self.device)
-
+
         # get rays
         resolution = self.opt.output_size
         rays_o, rays_d = get_rays(pose, resolution, resolution, self.opt.fovy)
-
+
         # update occ grid
         if self.training:
+
             def occ_eval_fn(xs):
                 sigmas = self.get_density(xs)
                 return self.render_step_size * sigmas
-
-            self.estimator.update_every_n_steps(self.step, occ_eval_fn=occ_eval_fn, occ_thre=0.01, n=8)
+
+            self.estimator.update_every_n_steps(
+                self.step, occ_eval_fn=occ_eval_fn, occ_thre=0.01, n=8
+            )
         self.step += 1
 
         # render
@@ -171,28 +185,41 @@ class Converter(nn.Module):
         sigmas = self.get_density(xs).squeeze(-1)
         rgbs = torch.sigmoid(self.mlp(self.encoder(xs)))
 
-        n_rays=rays_o.shape[0]
-        weights, trans, alphas = nerfacc.render_weight_from_density(t_starts, t_ends, sigmas, ray_indices=ray_indices, n_rays=n_rays)
-        color = nerfacc.accumulate_along_rays(weights, values=rgbs, ray_indices=ray_indices, n_rays=n_rays)
-        alpha = nerfacc.accumulate_along_rays(weights, values=None, ray_indices=ray_indices, n_rays=n_rays)
+        n_rays = rays_o.shape[0]
+        weights, trans, alphas = nerfacc.render_weight_from_density(
+            t_starts, t_ends, sigmas, ray_indices=ray_indices, n_rays=n_rays
+        )
+        color = nerfacc.accumulate_along_rays(
+            weights, values=rgbs, ray_indices=ray_indices, n_rays=n_rays
+        )
+        alpha = nerfacc.accumulate_along_rays(
+            weights, values=None, ray_indices=ray_indices, n_rays=n_rays
+        )
 
         color = color + 1 * (1.0 - alpha)
 
-        color = color.view(resolution, resolution, 3).clamp(0, 1).permute(2, 0, 1).contiguous()
+        color = (
+            color.view(resolution, resolution, 3)
+            .clamp(0, 1)
+            .permute(2, 0, 1)
+            .contiguous()
+        )
         alpha = alpha.view(resolution, resolution).clamp(0, 1).contiguous()
-
+
         return color, alpha
 
     def fit_nerf(self, iters=512, resolution=128):
 
         self.opt.output_size = resolution
 
-        optimizer = torch.optim.Adam([
-            {'params': self.encoder_density.parameters(), 'lr': 1e-2},
-            {'params': self.encoder.parameters(), 'lr': 1e-2},
-            {'params': self.mlp_density.parameters(), 'lr': 1e-3},
-            {'params': self.mlp.parameters(), 'lr': 1e-3},
-        ])
+        optimizer = torch.optim.Adam(
+            [
+                {"params": self.encoder_density.parameters(), "lr": 1e-2},
+                {"params": self.encoder.parameters(), "lr": 1e-2},
+                {"params": self.mlp_density.parameters(), "lr": 1e-3},
+                {"params": self.mlp.parameters(), "lr": 1e-3},
+            ]
+        )
 
         print(f"[INFO] fitting nerf...")
         pbar = tqdm.trange(iters)
```
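`render_nerf` follows nerfacc's flattened-sample pattern: the occupancy grid proposes samples only in non-empty cells, `render_weight_from_density` turns per-sample sigma into quadrature weights, and `accumulate_along_rays` scatter-adds weighted values back onto their rays (with `values=None` it yields opacity). A condensed sketch of that flow, assuming the module's imports plus `estimator`, `get_density` and the color branch from the class above:

```python
# Sketch of the sampling/compositing flow in render_nerf, nerfacc >= 0.5 style.
def mid_points(t0, t1, idx):
    return rays_o[idx] + rays_d[idx] * ((t0 + t1) / 2.0)[:, None]

ray_indices, t_starts, t_ends = estimator.sampling(
    rays_o, rays_d,
    sigma_fn=lambda t0, t1, idx: get_density(mid_points(t0, t1, idx)).squeeze(-1),
    render_step_size=5e-3,
)

xs = mid_points(t_starts, t_ends, ray_indices)  # [n_samples, 3], flattened over rays
sigmas = get_density(xs).squeeze(-1)
rgbs = torch.sigmoid(mlp(encoder(xs)))          # color branch

n_rays = rays_o.shape[0]
weights, trans, alphas = nerfacc.render_weight_from_density(
    t_starts, t_ends, sigmas, ray_indices=ray_indices, n_rays=n_rays
)
color = nerfacc.accumulate_along_rays(weights, values=rgbs, ray_indices=ray_indices, n_rays=n_rays)
alpha = nerfacc.accumulate_along_rays(weights, values=None, ray_indices=ray_indices, n_rays=n_rays)
color = color + (1.0 - alpha)                   # composite onto a white background
```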
```diff
@@ -201,28 +228,30 @@ class Converter(nn.Module):
             ver = np.random.randint(-45, 45)
             hor = np.random.randint(-180, 180)
             rad = np.random.uniform(1.5, 3.0)
-
+
             pose = orbit_camera(ver, hor, rad)
-
+
             image_gt, alpha_gt = self.render_gs(pose)
             image_pred, alpha_pred = self.render_nerf(pose)
 
             # if i % 200 == 0:
             #     kiui.vis.plot_image(image_gt, alpha_gt, image_pred, alpha_pred)
-
-            loss_mse = F.mse_loss(image_pred, image_gt) + 0.1 * F.mse_loss(alpha_pred, alpha_gt)
-            loss = loss_mse #+ 0.1 * self.encoder_density.tv_loss() #+ 0.0001 * self.encoder_density.density_loss()
+
+            loss_mse = F.mse_loss(image_pred, image_gt) + 0.1 * F.mse_loss(
+                alpha_pred, alpha_gt
+            )
+            loss = loss_mse  # + 0.1 * self.encoder_density.tv_loss() #+ 0.0001 * self.encoder_density.density_loss()
 
             loss.backward()
             self.encoder_density.grad_total_variation(1e-8)
-
+
             optimizer.step()
             optimizer.zero_grad()
 
             pbar.set_description(f"MSE = {loss_mse.item():.6f}")
-
+
         print(f"[INFO] finished fitting nerf!")
-
+
     def render_mesh(self, pose):
 
         h = w = self.opt.output_size
@@ -233,29 +262,43 @@ class Converter(nn.Module):
         pose = torch.from_numpy(pose.astype(np.float32)).to(v.device)
 
         # get v_clip and render rgb
-        v_cam = torch.matmul(F.pad(v, pad=(0, 1), mode='constant', value=1.0), torch.inverse(pose).T).float().unsqueeze(0)
+        v_cam = (
+            torch.matmul(
+                F.pad(v, pad=(0, 1), mode="constant", value=1.0), torch.inverse(pose).T
+            )
+            .float()
+            .unsqueeze(0)
+        )
         v_clip = v_cam @ self.proj.T
 
         rast, rast_db = dr.rasterize(self.glctx, v_clip, f, (h, w))
 
-        alpha = torch.clamp(rast[..., -1:], 0, 1).contiguous() # [1, H, W, 1]
-        alpha = dr.antialias(alpha, rast, v_clip, f).clamp(0, 1).squeeze(-1).squeeze(0) # [H, W] important to enable gradients!
-
+        alpha = torch.clamp(rast[..., -1:], 0, 1).contiguous()  # [1, H, W, 1]
+        alpha = (
+            dr.antialias(alpha, rast, v_clip, f).clamp(0, 1).squeeze(-1).squeeze(0)
+        )  # [H, W] important to enable gradients!
+
         if self.albedo is None:
-            xyzs, _ = dr.interpolate(v.unsqueeze(0), rast, f) # [1, H, W, 3]
+            xyzs, _ = dr.interpolate(v.unsqueeze(0), rast, f)  # [1, H, W, 3]
             xyzs = xyzs.view(-1, 3)
             mask = (alpha > 0).view(-1)
             image = torch.zeros_like(xyzs, dtype=torch.float32)
             if mask.any():
-                masked_albedo = torch.sigmoid(self.mlp(self.encoder(xyzs[mask].detach(), bound=1)))
+                masked_albedo = torch.sigmoid(
+                    self.mlp(self.encoder(xyzs[mask].detach(), bound=1))
+                )
                 image[mask] = masked_albedo.float()
         else:
-            texc, texc_db = dr.interpolate(self.vt.unsqueeze(0), rast, self.ft, rast_db=rast_db, diff_attrs='all')
-            image = torch.sigmoid(dr.texture(self.albedo.unsqueeze(0), texc, uv_da=texc_db)) # [1, H, W, 3]
+            texc, texc_db = dr.interpolate(
+                self.vt.unsqueeze(0), rast, self.ft, rast_db=rast_db, diff_attrs="all"
+            )
+            image = torch.sigmoid(
+                dr.texture(self.albedo.unsqueeze(0), texc, uv_da=texc_db)
+            )  # [1, H, W, 3]
 
         image = image.view(1, h, w, 3)
         # image = dr.antialias(image, rast, v_clip, f).clamp(0, 1)
-        image = image.squeeze(0).permute(2, 0, 1).contiguous() # [3, H, W]
+        image = image.squeeze(0).permute(2, 0, 1).contiguous()  # [3, H, W]
         image = alpha * image + (1 - alpha)
 
         return image, alpha
```
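Two nvdiffrast details in this hunk are worth keeping in mind when reading the reflowed code. The last channel of `rast` is the triangle id with 0 reserved for background, so `torch.clamp(rast[..., -1:], 0, 1)` is a hard silhouette; on its own it carries no gradient at the boundary, which is exactly what `dr.antialias` adds back (hence the "important to enable gradients!" comment). And `dr.rasterize` expects clip-space positions of shape [1, V, 4], which is why vertices are padded to homogeneous coordinates and pushed through the row-vector matrices first. In miniature (a sketch; `v_clip`, `f`, `glctx` come from the surrounding method):

```python
rast, rast_db = dr.rasterize(glctx, v_clip, f, (h, w))  # rast: [1, H, W, 4]
alpha = torch.clamp(rast[..., -1:], 0, 1)     # triangle id -> hard 0/1 coverage
alpha = dr.antialias(alpha, rast, v_clip, f)  # differentiable at silhouette edges
```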
```diff
@@ -278,34 +321,51 @@ class Converter(nn.Module):
         for xi, xs in enumerate(X):
             for yi, ys in enumerate(Y):
                 for zi, zs in enumerate(Z):
-                    xx, yy, zz = torch.meshgrid(xs, ys, zs, indexing='ij')
-                    pts = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1) # [S, 3]
+                    xx, yy, zz = torch.meshgrid(xs, ys, zs, indexing="ij")
+                    pts = torch.cat(
+                        [xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)],
+                        dim=-1,
+                    )  # [S, 3]
                     val = self.get_density(pts.to(self.device))
-                    sigmas[xi * S: xi * S + len(xs), yi * S: yi * S + len(ys), zi * S: zi * S + len(zs)] = val.reshape(len(xs), len(ys), len(zs)).detach().cpu().numpy() # [S, 1] --> [x, y, z]
-
-        print(f'[INFO] marching cubes thresh: {density_thresh} ({sigmas.min()} ~ {sigmas.max()})')
+                    sigmas[
+                        xi * S : xi * S + len(xs),
+                        yi * S : yi * S + len(ys),
+                        zi * S : zi * S + len(zs),
+                    ] = (
+                        val.reshape(len(xs), len(ys), len(zs)).detach().cpu().numpy()
+                    )  # [S, 1] --> [x, y, z]
+
+        print(
+            f"[INFO] marching cubes thresh: {density_thresh} ({sigmas.min()} ~ {sigmas.max()})"
+        )
 
         vertices, triangles = mcubes.marching_cubes(sigmas, density_thresh)
         vertices = vertices / (grid_size - 1.0) * 2 - 1
-
+
         # clean
         vertices = vertices.astype(np.float32)
         triangles = triangles.astype(np.int32)
-        vertices, triangles = clean_mesh(vertices, triangles, remesh=True, remesh_size=0.01)
+        vertices, triangles = clean_mesh(
+            vertices, triangles, remesh=True, remesh_size=0.01
+        )
         if triangles.shape[0] > decimate_target:
-            vertices, triangles = decimate_mesh(vertices, triangles, decimate_target, optimalplacement=False)
-
+            vertices, triangles = decimate_mesh(
+                vertices, triangles, decimate_target, optimalplacement=False
+            )
+
         self.v = torch.from_numpy(vertices).contiguous().float().to(self.device)
         self.f = torch.from_numpy(triangles).contiguous().int().to(self.device)
         self.deform = nn.Parameter(torch.zeros_like(self.v)).to(self.device)
 
         # fit mesh from gs
         lr_factor = 1
-        optimizer = torch.optim.Adam([
-            {'params': self.encoder.parameters(), 'lr': 1e-3 * lr_factor},
-            {'params': self.mlp.parameters(), 'lr': 1e-3 * lr_factor},
-            {'params': self.deform, 'lr': 1e-4},
-        ])
+        optimizer = torch.optim.Adam(
+            [
+                {"params": self.encoder.parameters(), "lr": 1e-3 * lr_factor},
+                {"params": self.mlp.parameters(), "lr": 1e-3 * lr_factor},
+                {"params": self.deform, "lr": 1e-4},
+            ]
+        )
 
         print(f"[INFO] fitting mesh...")
         pbar = tqdm.trange(iters)
@@ -313,17 +373,19 @@ class Converter(nn.Module):
 
             ver = np.random.randint(-10, 10)
             hor = np.random.randint(-180, 180)
-            rad = self.opt.cam_radius # np.random.uniform(1, 2)
+            rad = self.opt.cam_radius  # np.random.uniform(1, 2)
 
             pose = orbit_camera(ver, hor, rad)
-
+
             image_gt, alpha_gt = self.render_gs(pose)
             image_pred, alpha_pred = self.render_mesh(pose)
 
-            loss_mse = F.mse_loss(image_pred, image_gt) + 0.1 * F.mse_loss(alpha_pred, alpha_gt)
+            loss_mse = F.mse_loss(image_pred, image_gt) + 0.1 * F.mse_loss(
+                alpha_pred, alpha_gt
+            )
             # loss_lap = laplacian_smooth_loss(self.v + self.deform, self.f)
             loss_normal = normal_consistency(self.v + self.deform, self.f)
-            loss_offsets = (self.deform ** 2).sum(-1).mean()
+            loss_offsets = (self.deform**2).sum(-1).mean()
             loss = loss_mse + 0.001 * loss_normal + 0.1 * loss_offsets
 
             loss.backward()
@@ -335,21 +397,27 @@ class Converter(nn.Module):
             if i > 0 and i % 512 == 0:
                 vertices = (self.v + self.deform).detach().cpu().numpy()
                 triangles = self.f.detach().cpu().numpy()
-                vertices, triangles = clean_mesh(vertices, triangles, remesh=True, remesh_size=0.01)
+                vertices, triangles = clean_mesh(
+                    vertices, triangles, remesh=True, remesh_size=0.01
+                )
                 if triangles.shape[0] > decimate_target:
-                    vertices, triangles = decimate_mesh(vertices, triangles, decimate_target, optimalplacement=False)
+                    vertices, triangles = decimate_mesh(
+                        vertices, triangles, decimate_target, optimalplacement=False
+                    )
                 self.v = torch.from_numpy(vertices).contiguous().float().to(self.device)
                 self.f = torch.from_numpy(triangles).contiguous().int().to(self.device)
                 self.deform = nn.Parameter(torch.zeros_like(self.v)).to(self.device)
                 lr_factor *= 0.5
-                optimizer = torch.optim.Adam([
-                    {'params': self.encoder.parameters(), 'lr': 1e-3 * lr_factor},
-                    {'params': self.mlp.parameters(), 'lr': 1e-3 * lr_factor},
-                    {'params': self.deform, 'lr': 1e-4},
-                ])
+                optimizer = torch.optim.Adam(
+                    [
+                        {"params": self.encoder.parameters(), "lr": 1e-3 * lr_factor},
+                        {"params": self.mlp.parameters(), "lr": 1e-3 * lr_factor},
+                        {"params": self.deform, "lr": 1e-4},
+                    ]
+                )
 
             pbar.set_description(f"MSE = {loss_mse.item():.6f}")
-
+
         # last clean
         vertices = (self.v + self.deform).detach().cpu().numpy()
         triangles = self.f.detach().cpu().numpy()
@@ -357,11 +425,13 @@ class Converter(nn.Module):
         self.v = torch.from_numpy(vertices).contiguous().float().to(self.device)
         self.f = torch.from_numpy(triangles).contiguous().int().to(self.device)
         self.deform = nn.Parameter(torch.zeros_like(self.v).to(self.device))
-
+
         print(f"[INFO] finished fitting mesh!")
-
+
     # uv mesh refine
-    def fit_mesh_uv(self, iters=512, resolution=512, texture_resolution=1024, padding=2):
+    def fit_mesh_uv(
+        self, iters=512, resolution=512, texture_resolution=1024, padding=2
+    ):
 
         self.opt.output_size = resolution
 
@@ -376,44 +446,54 @@ class Converter(nn.Module):
 
         # render uv maps
         h = w = texture_resolution
-        uv = mesh.vt * 2.0 - 1.0 # uvs to range [-1, 1]
-        uv = torch.cat((uv, torch.zeros_like(uv[..., :1]), torch.ones_like(uv[..., :1])), dim=-1) # [N, 4]
-
-        rast, _ = dr.rasterize(self.glctx, uv.unsqueeze(0), mesh.ft, (h, w)) # [1, h, w, 4]
-        xyzs, _ = dr.interpolate(mesh.v.unsqueeze(0), rast, mesh.f) # [1, h, w, 3]
-        mask, _ = dr.interpolate(torch.ones_like(mesh.v[:, :1]).unsqueeze(0), rast, mesh.f) # [1, h, w, 1]
-
-        # masked query
+        uv = mesh.vt * 2.0 - 1.0  # uvs to range [-1, 1]
+        uv = torch.cat(
+            (uv, torch.zeros_like(uv[..., :1]), torch.ones_like(uv[..., :1])), dim=-1
+        )  # [N, 4]
+
+        rast, _ = dr.rasterize(
+            self.glctx, uv.unsqueeze(0), mesh.ft, (h, w)
+        )  # [1, h, w, 4]
+        xyzs, _ = dr.interpolate(mesh.v.unsqueeze(0), rast, mesh.f)  # [1, h, w, 3]
+        mask, _ = dr.interpolate(
+            torch.ones_like(mesh.v[:, :1]).unsqueeze(0), rast, mesh.f
+        )  # [1, h, w, 1]
+
+        # masked query
         xyzs = xyzs.view(-1, 3)
         mask = (mask > 0).view(-1)
-
+
        albedo = torch.zeros(h * w, 3, device=self.device, dtype=torch.float32)
 
         if mask.any():
             print(f"[INFO] querying texture...")
 
-            xyzs = xyzs[mask] # [M, 3]
+            xyzs = xyzs[mask]  # [M, 3]
 
             # batched inference to avoid OOM
             batch = []
             head = 0
             while head < xyzs.shape[0]:
                 tail = min(head + 640000, xyzs.shape[0])
-                batch.append(torch.sigmoid(self.mlp(self.encoder(xyzs[head:tail]))).float())
+                batch.append(
+                    torch.sigmoid(self.mlp(self.encoder(xyzs[head:tail]))).float()
+                )
                 head += 640000
 
             albedo[mask] = torch.cat(batch, dim=0)
-
+
         albedo = albedo.view(h, w, -1)
         mask = mask.view(h, w)
         albedo = uv_padding(albedo, mask, padding)
 
         # optimize texture
         self.albedo = nn.Parameter(inverse_sigmoid(albedo)).to(self.device)
-
-        optimizer = torch.optim.Adam([
-            {'params': self.albedo, 'lr': 1e-3},
-        ])
+
+        optimizer = torch.optim.Adam(
+            [
+                {"params": self.albedo, "lr": 1e-3},
+            ]
+        )
 
         print(f"[INFO] fitting mesh texture...")
         pbar = tqdm.trange(iters)
@@ -422,10 +502,10 @@ class Converter(nn.Module):
             # shrink to front view as we care more about it...
             ver = np.random.randint(-5, 5)
             hor = np.random.randint(-15, 15)
-            rad = self.opt.cam_radius # np.random.uniform(1, 2)
-
+            rad = self.opt.cam_radius  # np.random.uniform(1, 2)
+
             pose = orbit_camera(ver, hor, rad)
-
+
             image_gt, alpha_gt = self.render_gs(pose)
             image_pred, alpha_pred = self.render_mesh(pose)
 
@@ -438,14 +518,20 @@ class Converter(nn.Module):
             optimizer.zero_grad()
 
             pbar.set_description(f"MSE = {loss_mse.item():.6f}")
-
-        print(f"[INFO] finished fitting mesh texture!")
 
+        print(f"[INFO] finished fitting mesh texture!")
 
     @torch.no_grad()
     def export_mesh(self, path):
-
-        mesh = Mesh(v=self.v, f=self.f, vt=self.vt, ft=self.ft, albedo=torch.sigmoid(self.albedo), device=self.device)
+
+        mesh = Mesh(
+            v=self.v,
+            f=self.f,
+            vt=self.vt,
+            ft=self.ft,
+            albedo=torch.sigmoid(self.albedo),
+            device=self.device,
+        )
         mesh.auto_normal()
         mesh.write(path)
 
@@ -453,10 +539,12 @@
 opt = tyro.cli(AllConfigs)
 
 # load a saved ply and convert to mesh
-assert opt.test_path.endswith('.ply'), '--test_path must be a .ply file saved by infer.py'
+assert opt.test_path.endswith(
+    ".ply"
+), "--test_path must be a .ply file saved by infer.py"
 
 converter = Converter(opt).cuda()
 converter.fit_nerf()
 converter.fit_mesh()
 converter.fit_mesh_uv()
-converter.export_mesh(opt.test_path.replace('.ply', '.glb'))
+converter.export_mesh(opt.test_path.replace(".ply", ".glb"))
```
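The `fit_mesh_uv` changes above are pure reformatting, but the bake they reflow is the key trick of that stage: rasterize the mesh *in UV space*, so every texel learns which surface point it maps to, query the color field there, then use kiui's `uv_padding` to bleed colors past chart borders so bilinear and mipmap lookups near seams don't fetch background. Condensed into a sketch that reuses `mesh`, `glctx`, and the color branch from the class:

```python
h = w = 1024
uv = mesh.vt * 2.0 - 1.0  # uv coordinates remapped to [-1, 1]
uv = torch.cat((uv, torch.zeros_like(uv[..., :1]), torch.ones_like(uv[..., :1])), dim=-1)

rast, _ = dr.rasterize(glctx, uv.unsqueeze(0), mesh.ft, (h, w))  # texel -> triangle
xyzs, _ = dr.interpolate(mesh.v.unsqueeze(0), rast, mesh.f)      # texel -> 3D point
mask = (rast[..., 3:] > 0).view(-1)                              # texels covered by a chart

albedo = torch.zeros(h * w, 3, device=mesh.v.device)
albedo[mask] = torch.sigmoid(mlp(encoder(xyzs.view(-1, 3)[mask])))
albedo = uv_padding(albedo.view(h, w, 3), mask.view(h, w), 2)    # inpaint past seams
```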
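Net effect of the commit: convert.py is a black-style reformat plus the `proj_matrix[3, 2]` sign fix, while app.py swaps the bare `gr.Interface` for a Blocks layout with a cached example. The entry point is unchanged; for reference, the three-stage pipeline it drives is summarized below (the tyro subcommand, e.g. an LGM preset like `big`, depends on `AllConfigs` and is an assumption here):

```python
# e.g. python convert.py big --test_path data_test/catstatue.ply
opt = tyro.cli(AllConfigs)
converter = Converter(opt).cuda()
converter.fit_nerf()     # stage 1: distill the splats into a small NeRF
converter.fit_mesh()     # stage 2: marching cubes + differentiable refinement
converter.fit_mesh_uv()  # stage 3: unwrap UVs and bake/fit the texture
converter.export_mesh(opt.test_path.replace(".ply", ".glb"))
```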