jorgejungle commited on
Commit
2a98cab
1 Parent(s): 29a0098

Update convert.py

Browse files
Files changed (1) hide show
  1. convert.py +26 -23
convert.py CHANGED
@@ -183,42 +183,40 @@ class Converter(nn.Module):
183
 
184
  return color, alpha
185
 
186
- def fit_nerf(self, iters=512, resolution=128):
187
-
188
  self.opt.output_size = resolution
189
-
190
  optimizer = torch.optim.Adam([
191
  {'params': self.encoder_density.parameters(), 'lr': 1e-2},
192
  {'params': self.encoder.parameters(), 'lr': 1e-2},
193
  {'params': self.mlp_density.parameters(), 'lr': 1e-3},
194
  {'params': self.mlp.parameters(), 'lr': 1e-3},
195
  ])
196
-
197
  print(f"[INFO] fitting nerf...")
198
  pbar = tqdm.trange(iters)
199
  for i in pbar:
200
-
201
- ver = np.random.randint(-45, 45)
202
- hor = np.random.randint(-180, 180)
203
- rad = np.random.uniform(1.5, 3.0)
 
 
204
 
205
- pose = orbit_camera(ver, hor, rad)
206
 
207
- image_gt, alpha_gt = self.render_gs(pose)
208
- image_pred, alpha_pred = self.render_nerf(pose)
209
-
210
- # if i % 200 == 0:
211
- # kiui.vis.plot_image(image_gt, alpha_gt, image_pred, alpha_pred)
212
 
213
  loss_mse = F.mse_loss(image_pred, image_gt) + 0.1 * F.mse_loss(alpha_pred, alpha_gt)
214
- loss = loss_mse #+ 0.1 * self.encoder_density.tv_loss() #+ 0.0001 * self.encoder_density.density_loss()
215
-
216
  loss.backward()
217
  self.encoder_density.grad_total_variation(1e-8)
218
 
219
  optimizer.step()
220
  optimizer.zero_grad()
221
-
222
  pbar.set_description(f"MSE = {loss_mse.item():.6f}")
223
 
224
  print(f"[INFO] finished fitting nerf!")
@@ -266,6 +264,7 @@ class Converter(nn.Module):
266
 
267
  # init mesh from nerf
268
  grid_size = 256
 
269
  sigmas = np.zeros([grid_size, grid_size, grid_size], dtype=np.float32)
270
 
271
  S = 128
@@ -275,14 +274,18 @@ class Converter(nn.Module):
275
  Y = torch.linspace(-1, 1, grid_size).split(S)
276
  Z = torch.linspace(-1, 1, grid_size).split(S)
277
 
278
- for xi, xs in enumerate(X):
279
- for yi, ys in enumerate(Y):
280
- for zi, zs in enumerate(Z):
 
 
 
 
281
  xx, yy, zz = torch.meshgrid(xs, ys, zs, indexing='ij')
282
- pts = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1) # [S, 3]
283
  val = self.get_density(pts.to(self.device))
284
- sigmas[xi * S: xi * S + len(xs), yi * S: yi * S + len(ys), zi * S: zi * S + len(zs)] = val.reshape(len(xs), len(ys), len(zs)).detach().cpu().numpy() # [S, 1] --> [x, y, z]
285
-
286
  print(f'[INFO] marching cubes thresh: {density_thresh} ({sigmas.min()} ~ {sigmas.max()})')
287
 
288
  vertices, triangles = mcubes.marching_cubes(sigmas, density_thresh)
 
183
 
184
  return color, alpha
185
 
186
+ def fit_nerf(self, iters=512, resolution=128, batch_size=32):
 
187
  self.opt.output_size = resolution
188
+
189
  optimizer = torch.optim.Adam([
190
  {'params': self.encoder_density.parameters(), 'lr': 1e-2},
191
  {'params': self.encoder.parameters(), 'lr': 1e-2},
192
  {'params': self.mlp_density.parameters(), 'lr': 1e-3},
193
  {'params': self.mlp.parameters(), 'lr': 1e-3},
194
  ])
195
+
196
  print(f"[INFO] fitting nerf...")
197
  pbar = tqdm.trange(iters)
198
  for i in pbar:
199
+ poses = []
200
+ for _ in range(batch_size):
201
+ ver = np.random.randint(-45, 45)
202
+ hor = np.random.randint(-180, 180)
203
+ rad = np.random.uniform(1.5, 3.0)
204
+ poses.append(orbit_camera(ver, hor, rad))
205
 
206
+ poses = np.stack(poses)
207
 
208
+ image_gt, alpha_gt = self.render_gs(poses)
209
+ image_pred, alpha_pred = self.render_nerf(poses)
 
 
 
210
 
211
  loss_mse = F.mse_loss(image_pred, image_gt) + 0.1 * F.mse_loss(alpha_pred, alpha_gt)
212
+ loss = loss_mse
213
+
214
  loss.backward()
215
  self.encoder_density.grad_total_variation(1e-8)
216
 
217
  optimizer.step()
218
  optimizer.zero_grad()
219
+
220
  pbar.set_description(f"MSE = {loss_mse.item():.6f}")
221
 
222
  print(f"[INFO] finished fitting nerf!")
 
264
 
265
  # init mesh from nerf
266
  grid_size = 256
267
+ chunk_size = 64
268
  sigmas = np.zeros([grid_size, grid_size, grid_size], dtype=np.float32)
269
 
270
  S = 128
 
274
  Y = torch.linspace(-1, 1, grid_size).split(S)
275
  Z = torch.linspace(-1, 1, grid_size).split(S)
276
 
277
+ for xi in range(0, grid_size, chunk_size):
278
+ for yi in range(0, grid_size, chunk_size):
279
+ for zi in range(0, grid_size, chunk_size):
280
+ xs = torch.linspace(-1, 1, chunk_size)
281
+ ys = torch.linspace(-1, 1, chunk_size)
282
+ zs = torch.linspace(-1, 1, chunk_size)
283
+
284
  xx, yy, zz = torch.meshgrid(xs, ys, zs, indexing='ij')
285
+ pts = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1)
286
  val = self.get_density(pts.to(self.device))
287
+ sigmas[xi:xi+chunk_size, yi:yi+chunk_size, zi:zi+chunk_size] = val.reshape(chunk_size, chunk_size, chunk_size).detach().cpu().numpy()
288
+
289
  print(f'[INFO] marching cubes thresh: {density_thresh} ({sigmas.min()} ~ {sigmas.max()})')
290
 
291
  vertices, triangles = mcubes.marching_cubes(sigmas, density_thresh)