Spaces:
Runtime error
Runtime error
jorgejungle
commited on
Commit
•
2a98cab
1
Parent(s):
29a0098
Update convert.py
Browse files- convert.py +26 -23
convert.py
CHANGED
@@ -183,42 +183,40 @@ class Converter(nn.Module):
|
|
183 |
|
184 |
return color, alpha
|
185 |
|
186 |
-
|
187 |
-
|
188 |
self.opt.output_size = resolution
|
189 |
-
|
190 |
optimizer = torch.optim.Adam([
|
191 |
{'params': self.encoder_density.parameters(), 'lr': 1e-2},
|
192 |
{'params': self.encoder.parameters(), 'lr': 1e-2},
|
193 |
{'params': self.mlp_density.parameters(), 'lr': 1e-3},
|
194 |
{'params': self.mlp.parameters(), 'lr': 1e-3},
|
195 |
])
|
196 |
-
|
197 |
print(f"[INFO] fitting nerf...")
|
198 |
pbar = tqdm.trange(iters)
|
199 |
for i in pbar:
|
200 |
-
|
201 |
-
|
202 |
-
|
203 |
-
|
|
|
|
|
204 |
|
205 |
-
|
206 |
|
207 |
-
image_gt, alpha_gt = self.render_gs(
|
208 |
-
image_pred, alpha_pred = self.render_nerf(
|
209 |
-
|
210 |
-
# if i % 200 == 0:
|
211 |
-
# kiui.vis.plot_image(image_gt, alpha_gt, image_pred, alpha_pred)
|
212 |
|
213 |
loss_mse = F.mse_loss(image_pred, image_gt) + 0.1 * F.mse_loss(alpha_pred, alpha_gt)
|
214 |
-
loss = loss_mse
|
215 |
-
|
216 |
loss.backward()
|
217 |
self.encoder_density.grad_total_variation(1e-8)
|
218 |
|
219 |
optimizer.step()
|
220 |
optimizer.zero_grad()
|
221 |
-
|
222 |
pbar.set_description(f"MSE = {loss_mse.item():.6f}")
|
223 |
|
224 |
print(f"[INFO] finished fitting nerf!")
|
@@ -266,6 +264,7 @@ class Converter(nn.Module):
|
|
266 |
|
267 |
# init mesh from nerf
|
268 |
grid_size = 256
|
|
|
269 |
sigmas = np.zeros([grid_size, grid_size, grid_size], dtype=np.float32)
|
270 |
|
271 |
S = 128
|
@@ -275,14 +274,18 @@ class Converter(nn.Module):
|
|
275 |
Y = torch.linspace(-1, 1, grid_size).split(S)
|
276 |
Z = torch.linspace(-1, 1, grid_size).split(S)
|
277 |
|
278 |
-
for xi
|
279 |
-
for yi
|
280 |
-
for zi
|
|
|
|
|
|
|
|
|
281 |
xx, yy, zz = torch.meshgrid(xs, ys, zs, indexing='ij')
|
282 |
-
pts = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1)
|
283 |
val = self.get_density(pts.to(self.device))
|
284 |
-
sigmas[xi
|
285 |
-
|
286 |
print(f'[INFO] marching cubes thresh: {density_thresh} ({sigmas.min()} ~ {sigmas.max()})')
|
287 |
|
288 |
vertices, triangles = mcubes.marching_cubes(sigmas, density_thresh)
|
|
|
183 |
|
184 |
return color, alpha
|
185 |
|
186 |
+
def fit_nerf(self, iters=512, resolution=128, batch_size=32):
|
|
|
187 |
self.opt.output_size = resolution
|
188 |
+
|
189 |
optimizer = torch.optim.Adam([
|
190 |
{'params': self.encoder_density.parameters(), 'lr': 1e-2},
|
191 |
{'params': self.encoder.parameters(), 'lr': 1e-2},
|
192 |
{'params': self.mlp_density.parameters(), 'lr': 1e-3},
|
193 |
{'params': self.mlp.parameters(), 'lr': 1e-3},
|
194 |
])
|
195 |
+
|
196 |
print(f"[INFO] fitting nerf...")
|
197 |
pbar = tqdm.trange(iters)
|
198 |
for i in pbar:
|
199 |
+
poses = []
|
200 |
+
for _ in range(batch_size):
|
201 |
+
ver = np.random.randint(-45, 45)
|
202 |
+
hor = np.random.randint(-180, 180)
|
203 |
+
rad = np.random.uniform(1.5, 3.0)
|
204 |
+
poses.append(orbit_camera(ver, hor, rad))
|
205 |
|
206 |
+
poses = np.stack(poses)
|
207 |
|
208 |
+
image_gt, alpha_gt = self.render_gs(poses)
|
209 |
+
image_pred, alpha_pred = self.render_nerf(poses)
|
|
|
|
|
|
|
210 |
|
211 |
loss_mse = F.mse_loss(image_pred, image_gt) + 0.1 * F.mse_loss(alpha_pred, alpha_gt)
|
212 |
+
loss = loss_mse
|
213 |
+
|
214 |
loss.backward()
|
215 |
self.encoder_density.grad_total_variation(1e-8)
|
216 |
|
217 |
optimizer.step()
|
218 |
optimizer.zero_grad()
|
219 |
+
|
220 |
pbar.set_description(f"MSE = {loss_mse.item():.6f}")
|
221 |
|
222 |
print(f"[INFO] finished fitting nerf!")
|
|
|
264 |
|
265 |
# init mesh from nerf
|
266 |
grid_size = 256
|
267 |
+
chunk_size = 64
|
268 |
sigmas = np.zeros([grid_size, grid_size, grid_size], dtype=np.float32)
|
269 |
|
270 |
S = 128
|
|
|
274 |
Y = torch.linspace(-1, 1, grid_size).split(S)
|
275 |
Z = torch.linspace(-1, 1, grid_size).split(S)
|
276 |
|
277 |
+
for xi in range(0, grid_size, chunk_size):
|
278 |
+
for yi in range(0, grid_size, chunk_size):
|
279 |
+
for zi in range(0, grid_size, chunk_size):
|
280 |
+
xs = torch.linspace(-1, 1, chunk_size)
|
281 |
+
ys = torch.linspace(-1, 1, chunk_size)
|
282 |
+
zs = torch.linspace(-1, 1, chunk_size)
|
283 |
+
|
284 |
xx, yy, zz = torch.meshgrid(xs, ys, zs, indexing='ij')
|
285 |
+
pts = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1)
|
286 |
val = self.get_density(pts.to(self.device))
|
287 |
+
sigmas[xi:xi+chunk_size, yi:yi+chunk_size, zi:zi+chunk_size] = val.reshape(chunk_size, chunk_size, chunk_size).detach().cpu().numpy()
|
288 |
+
|
289 |
print(f'[INFO] marching cubes thresh: {density_thresh} ({sigmas.min()} ~ {sigmas.max()})')
|
290 |
|
291 |
vertices, triangles = mcubes.marching_cubes(sigmas, density_thresh)
|