Spaces:
Running
on
A10G
Running
on
A10G
Commit
•
9299d52
1
Parent(s):
93f5bda
- app.py +43 -7
- convert.py +197 -109
app.py
CHANGED
@@ -11,14 +11,50 @@ def run(input_ply):
|
|
11 |
|
12 |
|
13 |
def main():
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
20 |
|
21 |
-
demo.launch(server_name="0.0.0.0", server_port=7860)
|
22 |
|
23 |
|
24 |
if __name__ == "__main__":
|
|
|
11 |
|
12 |
|
13 |
def main():
|
14 |
+
_TITLE = """LGM Mini"""
|
15 |
+
|
16 |
+
_DESCRIPTION = """Converts Gaussian Splat (.ply) to Mesh (.glb) using [LGM](https://github.com/3DTopia/LGM)."""
|
17 |
+
|
18 |
+
css = """
|
19 |
+
#duplicate-button {
|
20 |
+
margin: auto;
|
21 |
+
color: white;
|
22 |
+
background: #1565c0;
|
23 |
+
border-radius: 100vh;
|
24 |
+
}
|
25 |
+
"""
|
26 |
+
|
27 |
+
demo = gr.Blocks(title=_TITLE, css=css)
|
28 |
+
with demo:
|
29 |
+
gr.DuplicateButton(
|
30 |
+
value="Duplicate Space for private use", elem_id="duplicate-button"
|
31 |
+
)
|
32 |
+
|
33 |
+
with gr.Row():
|
34 |
+
with gr.Column():
|
35 |
+
gr.Markdown("# " + _TITLE)
|
36 |
+
gr.Markdown(_DESCRIPTION)
|
37 |
+
|
38 |
+
with gr.Row(variant="panel"):
|
39 |
+
with gr.Column():
|
40 |
+
input_ply = gr.Model3D(label="Input Splat")
|
41 |
+
button_gen = gr.Button("Generate")
|
42 |
+
|
43 |
+
with gr.Column():
|
44 |
+
output_glb = gr.Model3D(label="Output GLB")
|
45 |
+
|
46 |
+
button_gen.click(run, inputs=[input_ply], outputs=[output_glb])
|
47 |
+
|
48 |
+
gr.Examples(
|
49 |
+
["data_test/catstatue.ply"],
|
50 |
+
inputs=[input_ply],
|
51 |
+
outputs=[output_glb],
|
52 |
+
fn=lambda x: run(x),
|
53 |
+
cache_examples=True,
|
54 |
+
label="Examples",
|
55 |
+
)
|
56 |
|
57 |
+
demo.queue().launch(server_name="0.0.0.0", server_port=7860)
|
58 |
|
59 |
|
60 |
if __name__ == "__main__":
|
convert.py
CHANGED
@@ -1,4 +1,3 @@
|
|
1 |
-
|
2 |
import os
|
3 |
import tyro
|
4 |
import tqdm
|
@@ -23,8 +22,9 @@ from kiui.cam import orbit_camera, get_perspective
|
|
23 |
from kiui.nn import MLP, trunc_exp
|
24 |
from kiui.gridencoder import GridEncoder
|
25 |
|
|
|
26 |
def get_rays(pose, h, w, fovy, opengl=True):
|
27 |
-
|
28 |
x, y = torch.meshgrid(
|
29 |
torch.arange(w, device=pose.device),
|
30 |
torch.arange(h, device=pose.device),
|
@@ -50,12 +50,13 @@ def get_rays(pose, h, w, fovy, opengl=True):
|
|
50 |
) # [hw, 3]
|
51 |
|
52 |
rays_d = camera_dirs @ pose[:3, :3].transpose(0, 1) # [hw, 3]
|
53 |
-
rays_o = pose[:3, 3].unsqueeze(0).expand_as(rays_d)
|
54 |
|
55 |
rays_d = safe_normalize(rays_d)
|
56 |
|
57 |
return rays_o, rays_d
|
58 |
|
|
|
59 |
# Triple renderer of gaussians, gaussian, and diso mesh.
|
60 |
# gaussian --> nerf --> mesh
|
61 |
class Converter(nn.Module):
|
@@ -71,7 +72,7 @@ class Converter(nn.Module):
|
|
71 |
self.proj_matrix[0, 0] = 1 / self.tan_half_fov
|
72 |
self.proj_matrix[1, 1] = 1 / self.tan_half_fov
|
73 |
self.proj_matrix[2, 2] = (opt.zfar + opt.znear) / (opt.zfar - opt.znear)
|
74 |
-
self.proj_matrix[3, 2] = -
|
75 |
self.proj_matrix[2, 3] = 1
|
76 |
|
77 |
self.gs_renderer = GaussianRenderer(opt)
|
@@ -83,39 +84,49 @@ class Converter(nn.Module):
|
|
83 |
self.glctx = dr.RasterizeGLContext()
|
84 |
else:
|
85 |
self.glctx = dr.RasterizeCudaContext()
|
86 |
-
|
87 |
self.step = 0
|
88 |
self.render_step_size = 5e-3
|
89 |
self.aabb = torch.tensor([-1.0, -1.0, -1.0, 1.0, 1.0, 1.0], device=self.device)
|
90 |
-
self.estimator = nerfacc.OccGridEstimator(
|
|
|
|
|
91 |
|
92 |
-
self.encoder_density = GridEncoder(
|
|
|
|
|
93 |
self.encoder = GridEncoder(num_levels=12)
|
94 |
self.mlp_density = MLP(self.encoder_density.output_dim, 1, 32, 2, bias=False)
|
95 |
self.mlp = MLP(self.encoder.output_dim, 3, 32, 2, bias=False)
|
96 |
|
97 |
# mesh renderer
|
98 |
-
self.proj =
|
|
|
|
|
99 |
self.v = self.f = None
|
100 |
self.vt = self.ft = None
|
101 |
self.deform = None
|
102 |
self.albedo = None
|
103 |
-
|
104 |
-
|
105 |
@torch.no_grad()
|
106 |
def render_gs(self, pose):
|
107 |
-
|
108 |
cam_poses = torch.from_numpy(pose).unsqueeze(0).to(self.device)
|
109 |
-
cam_poses[:, :3, 1:3] *= -1
|
110 |
-
|
111 |
# cameras needed by gaussian rasterizer
|
112 |
-
cam_view = torch.inverse(cam_poses).transpose(1, 2)
|
113 |
-
cam_view_proj = cam_view @ self.proj_matrix
|
114 |
-
cam_pos = -
|
115 |
-
|
116 |
-
out = self.gs_renderer.render(
|
117 |
-
|
118 |
-
|
|
|
|
|
|
|
|
|
|
|
119 |
|
120 |
return image, alpha
|
121 |
|
@@ -127,22 +138,25 @@ class Converter(nn.Module):
|
|
127 |
density = trunc_exp(self.mlp_density(feats))
|
128 |
density = density.view(*prefix, 1)
|
129 |
return density
|
130 |
-
|
131 |
def render_nerf(self, pose):
|
132 |
-
|
133 |
pose = torch.from_numpy(pose.astype(np.float32)).to(self.device)
|
134 |
-
|
135 |
# get rays
|
136 |
resolution = self.opt.output_size
|
137 |
rays_o, rays_d = get_rays(pose, resolution, resolution, self.opt.fovy)
|
138 |
-
|
139 |
# update occ grid
|
140 |
if self.training:
|
|
|
141 |
def occ_eval_fn(xs):
|
142 |
sigmas = self.get_density(xs)
|
143 |
return self.render_step_size * sigmas
|
144 |
-
|
145 |
-
self.estimator.update_every_n_steps(
|
|
|
|
|
146 |
self.step += 1
|
147 |
|
148 |
# render
|
@@ -171,28 +185,41 @@ class Converter(nn.Module):
|
|
171 |
sigmas = self.get_density(xs).squeeze(-1)
|
172 |
rgbs = torch.sigmoid(self.mlp(self.encoder(xs)))
|
173 |
|
174 |
-
n_rays=rays_o.shape[0]
|
175 |
-
weights, trans, alphas = nerfacc.render_weight_from_density(
|
176 |
-
|
177 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
178 |
|
179 |
color = color + 1 * (1.0 - alpha)
|
180 |
|
181 |
-
color =
|
|
|
|
|
|
|
|
|
|
|
182 |
alpha = alpha.view(resolution, resolution).clamp(0, 1).contiguous()
|
183 |
-
|
184 |
return color, alpha
|
185 |
|
186 |
def fit_nerf(self, iters=512, resolution=128):
|
187 |
|
188 |
self.opt.output_size = resolution
|
189 |
|
190 |
-
optimizer = torch.optim.Adam(
|
191 |
-
|
192 |
-
|
193 |
-
|
194 |
-
|
195 |
-
|
|
|
|
|
196 |
|
197 |
print(f"[INFO] fitting nerf...")
|
198 |
pbar = tqdm.trange(iters)
|
@@ -201,28 +228,30 @@ class Converter(nn.Module):
|
|
201 |
ver = np.random.randint(-45, 45)
|
202 |
hor = np.random.randint(-180, 180)
|
203 |
rad = np.random.uniform(1.5, 3.0)
|
204 |
-
|
205 |
pose = orbit_camera(ver, hor, rad)
|
206 |
-
|
207 |
image_gt, alpha_gt = self.render_gs(pose)
|
208 |
image_pred, alpha_pred = self.render_nerf(pose)
|
209 |
|
210 |
# if i % 200 == 0:
|
211 |
# kiui.vis.plot_image(image_gt, alpha_gt, image_pred, alpha_pred)
|
212 |
-
|
213 |
-
loss_mse = F.mse_loss(image_pred, image_gt) + 0.1 * F.mse_loss(
|
214 |
-
|
|
|
|
|
215 |
|
216 |
loss.backward()
|
217 |
self.encoder_density.grad_total_variation(1e-8)
|
218 |
-
|
219 |
optimizer.step()
|
220 |
optimizer.zero_grad()
|
221 |
|
222 |
pbar.set_description(f"MSE = {loss_mse.item():.6f}")
|
223 |
-
|
224 |
print(f"[INFO] finished fitting nerf!")
|
225 |
-
|
226 |
def render_mesh(self, pose):
|
227 |
|
228 |
h = w = self.opt.output_size
|
@@ -233,29 +262,43 @@ class Converter(nn.Module):
|
|
233 |
pose = torch.from_numpy(pose.astype(np.float32)).to(v.device)
|
234 |
|
235 |
# get v_clip and render rgb
|
236 |
-
v_cam =
|
|
|
|
|
|
|
|
|
|
|
|
|
237 |
v_clip = v_cam @ self.proj.T
|
238 |
|
239 |
rast, rast_db = dr.rasterize(self.glctx, v_clip, f, (h, w))
|
240 |
|
241 |
-
alpha = torch.clamp(rast[..., -1:], 0, 1).contiguous()
|
242 |
-
alpha =
|
243 |
-
|
|
|
|
|
244 |
if self.albedo is None:
|
245 |
-
xyzs, _ = dr.interpolate(v.unsqueeze(0), rast, f)
|
246 |
xyzs = xyzs.view(-1, 3)
|
247 |
mask = (alpha > 0).view(-1)
|
248 |
image = torch.zeros_like(xyzs, dtype=torch.float32)
|
249 |
if mask.any():
|
250 |
-
masked_albedo = torch.sigmoid(
|
|
|
|
|
251 |
image[mask] = masked_albedo.float()
|
252 |
else:
|
253 |
-
texc, texc_db = dr.interpolate(
|
254 |
-
|
|
|
|
|
|
|
|
|
255 |
|
256 |
image = image.view(1, h, w, 3)
|
257 |
# image = dr.antialias(image, rast, v_clip, f).clamp(0, 1)
|
258 |
-
image = image.squeeze(0).permute(2, 0, 1).contiguous()
|
259 |
image = alpha * image + (1 - alpha)
|
260 |
|
261 |
return image, alpha
|
@@ -278,34 +321,51 @@ class Converter(nn.Module):
|
|
278 |
for xi, xs in enumerate(X):
|
279 |
for yi, ys in enumerate(Y):
|
280 |
for zi, zs in enumerate(Z):
|
281 |
-
xx, yy, zz = torch.meshgrid(xs, ys, zs, indexing=
|
282 |
-
pts = torch.cat(
|
|
|
|
|
|
|
283 |
val = self.get_density(pts.to(self.device))
|
284 |
-
sigmas[
|
285 |
-
|
286 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
287 |
|
288 |
vertices, triangles = mcubes.marching_cubes(sigmas, density_thresh)
|
289 |
vertices = vertices / (grid_size - 1.0) * 2 - 1
|
290 |
-
|
291 |
# clean
|
292 |
vertices = vertices.astype(np.float32)
|
293 |
triangles = triangles.astype(np.int32)
|
294 |
-
vertices, triangles = clean_mesh(
|
|
|
|
|
295 |
if triangles.shape[0] > decimate_target:
|
296 |
-
vertices, triangles = decimate_mesh(
|
297 |
-
|
|
|
|
|
298 |
self.v = torch.from_numpy(vertices).contiguous().float().to(self.device)
|
299 |
self.f = torch.from_numpy(triangles).contiguous().int().to(self.device)
|
300 |
self.deform = nn.Parameter(torch.zeros_like(self.v)).to(self.device)
|
301 |
|
302 |
# fit mesh from gs
|
303 |
lr_factor = 1
|
304 |
-
optimizer = torch.optim.Adam(
|
305 |
-
|
306 |
-
|
307 |
-
|
308 |
-
|
|
|
|
|
309 |
|
310 |
print(f"[INFO] fitting mesh...")
|
311 |
pbar = tqdm.trange(iters)
|
@@ -313,17 +373,19 @@ class Converter(nn.Module):
|
|
313 |
|
314 |
ver = np.random.randint(-10, 10)
|
315 |
hor = np.random.randint(-180, 180)
|
316 |
-
rad = self.opt.cam_radius
|
317 |
|
318 |
pose = orbit_camera(ver, hor, rad)
|
319 |
-
|
320 |
image_gt, alpha_gt = self.render_gs(pose)
|
321 |
image_pred, alpha_pred = self.render_mesh(pose)
|
322 |
|
323 |
-
loss_mse = F.mse_loss(image_pred, image_gt) + 0.1 * F.mse_loss(
|
|
|
|
|
324 |
# loss_lap = laplacian_smooth_loss(self.v + self.deform, self.f)
|
325 |
loss_normal = normal_consistency(self.v + self.deform, self.f)
|
326 |
-
loss_offsets = (self.deform
|
327 |
loss = loss_mse + 0.001 * loss_normal + 0.1 * loss_offsets
|
328 |
|
329 |
loss.backward()
|
@@ -335,21 +397,27 @@ class Converter(nn.Module):
|
|
335 |
if i > 0 and i % 512 == 0:
|
336 |
vertices = (self.v + self.deform).detach().cpu().numpy()
|
337 |
triangles = self.f.detach().cpu().numpy()
|
338 |
-
vertices, triangles = clean_mesh(
|
|
|
|
|
339 |
if triangles.shape[0] > decimate_target:
|
340 |
-
vertices, triangles = decimate_mesh(
|
|
|
|
|
341 |
self.v = torch.from_numpy(vertices).contiguous().float().to(self.device)
|
342 |
self.f = torch.from_numpy(triangles).contiguous().int().to(self.device)
|
343 |
self.deform = nn.Parameter(torch.zeros_like(self.v)).to(self.device)
|
344 |
lr_factor *= 0.5
|
345 |
-
optimizer = torch.optim.Adam(
|
346 |
-
|
347 |
-
|
348 |
-
|
349 |
-
|
|
|
|
|
350 |
|
351 |
pbar.set_description(f"MSE = {loss_mse.item():.6f}")
|
352 |
-
|
353 |
# last clean
|
354 |
vertices = (self.v + self.deform).detach().cpu().numpy()
|
355 |
triangles = self.f.detach().cpu().numpy()
|
@@ -357,11 +425,13 @@ class Converter(nn.Module):
|
|
357 |
self.v = torch.from_numpy(vertices).contiguous().float().to(self.device)
|
358 |
self.f = torch.from_numpy(triangles).contiguous().int().to(self.device)
|
359 |
self.deform = nn.Parameter(torch.zeros_like(self.v).to(self.device))
|
360 |
-
|
361 |
print(f"[INFO] finished fitting mesh!")
|
362 |
-
|
363 |
# uv mesh refine
|
364 |
-
def fit_mesh_uv(
|
|
|
|
|
365 |
|
366 |
self.opt.output_size = resolution
|
367 |
|
@@ -376,44 +446,54 @@ class Converter(nn.Module):
|
|
376 |
|
377 |
# render uv maps
|
378 |
h = w = texture_resolution
|
379 |
-
uv = mesh.vt * 2.0 - 1.0
|
380 |
-
uv = torch.cat(
|
381 |
-
|
382 |
-
|
383 |
-
|
384 |
-
|
385 |
-
|
386 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
387 |
xyzs = xyzs.view(-1, 3)
|
388 |
mask = (mask > 0).view(-1)
|
389 |
-
|
390 |
albedo = torch.zeros(h * w, 3, device=self.device, dtype=torch.float32)
|
391 |
|
392 |
if mask.any():
|
393 |
print(f"[INFO] querying texture...")
|
394 |
|
395 |
-
xyzs = xyzs[mask]
|
396 |
|
397 |
# batched inference to avoid OOM
|
398 |
batch = []
|
399 |
head = 0
|
400 |
while head < xyzs.shape[0]:
|
401 |
tail = min(head + 640000, xyzs.shape[0])
|
402 |
-
batch.append(
|
|
|
|
|
403 |
head += 640000
|
404 |
|
405 |
albedo[mask] = torch.cat(batch, dim=0)
|
406 |
-
|
407 |
albedo = albedo.view(h, w, -1)
|
408 |
mask = mask.view(h, w)
|
409 |
albedo = uv_padding(albedo, mask, padding)
|
410 |
|
411 |
# optimize texture
|
412 |
self.albedo = nn.Parameter(inverse_sigmoid(albedo)).to(self.device)
|
413 |
-
|
414 |
-
optimizer = torch.optim.Adam(
|
415 |
-
|
416 |
-
|
|
|
|
|
417 |
|
418 |
print(f"[INFO] fitting mesh texture...")
|
419 |
pbar = tqdm.trange(iters)
|
@@ -422,10 +502,10 @@ class Converter(nn.Module):
|
|
422 |
# shrink to front view as we care more about it...
|
423 |
ver = np.random.randint(-5, 5)
|
424 |
hor = np.random.randint(-15, 15)
|
425 |
-
rad = self.opt.cam_radius
|
426 |
-
|
427 |
pose = orbit_camera(ver, hor, rad)
|
428 |
-
|
429 |
image_gt, alpha_gt = self.render_gs(pose)
|
430 |
image_pred, alpha_pred = self.render_mesh(pose)
|
431 |
|
@@ -438,14 +518,20 @@ class Converter(nn.Module):
|
|
438 |
optimizer.zero_grad()
|
439 |
|
440 |
pbar.set_description(f"MSE = {loss_mse.item():.6f}")
|
441 |
-
|
442 |
-
print(f"[INFO] finished fitting mesh texture!")
|
443 |
|
|
|
444 |
|
445 |
@torch.no_grad()
|
446 |
def export_mesh(self, path):
|
447 |
-
|
448 |
-
mesh = Mesh(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
449 |
mesh.auto_normal()
|
450 |
mesh.write(path)
|
451 |
|
@@ -453,10 +539,12 @@ class Converter(nn.Module):
|
|
453 |
opt = tyro.cli(AllConfigs)
|
454 |
|
455 |
# load a saved ply and convert to mesh
|
456 |
-
assert opt.test_path.endswith(
|
|
|
|
|
457 |
|
458 |
converter = Converter(opt).cuda()
|
459 |
converter.fit_nerf()
|
460 |
converter.fit_mesh()
|
461 |
converter.fit_mesh_uv()
|
462 |
-
converter.export_mesh(opt.test_path.replace(
|
|
|
|
|
1 |
import os
|
2 |
import tyro
|
3 |
import tqdm
|
|
|
22 |
from kiui.nn import MLP, trunc_exp
|
23 |
from kiui.gridencoder import GridEncoder
|
24 |
|
25 |
+
|
26 |
def get_rays(pose, h, w, fovy, opengl=True):
|
27 |
+
|
28 |
x, y = torch.meshgrid(
|
29 |
torch.arange(w, device=pose.device),
|
30 |
torch.arange(h, device=pose.device),
|
|
|
50 |
) # [hw, 3]
|
51 |
|
52 |
rays_d = camera_dirs @ pose[:3, :3].transpose(0, 1) # [hw, 3]
|
53 |
+
rays_o = pose[:3, 3].unsqueeze(0).expand_as(rays_d) # [hw, 3]
|
54 |
|
55 |
rays_d = safe_normalize(rays_d)
|
56 |
|
57 |
return rays_o, rays_d
|
58 |
|
59 |
+
|
60 |
# Triple renderer of gaussians, gaussian, and diso mesh.
|
61 |
# gaussian --> nerf --> mesh
|
62 |
class Converter(nn.Module):
|
|
|
72 |
self.proj_matrix[0, 0] = 1 / self.tan_half_fov
|
73 |
self.proj_matrix[1, 1] = 1 / self.tan_half_fov
|
74 |
self.proj_matrix[2, 2] = (opt.zfar + opt.znear) / (opt.zfar - opt.znear)
|
75 |
+
self.proj_matrix[3, 2] = -(opt.zfar * opt.znear) / (opt.zfar - opt.znear)
|
76 |
self.proj_matrix[2, 3] = 1
|
77 |
|
78 |
self.gs_renderer = GaussianRenderer(opt)
|
|
|
84 |
self.glctx = dr.RasterizeGLContext()
|
85 |
else:
|
86 |
self.glctx = dr.RasterizeCudaContext()
|
87 |
+
|
88 |
self.step = 0
|
89 |
self.render_step_size = 5e-3
|
90 |
self.aabb = torch.tensor([-1.0, -1.0, -1.0, 1.0, 1.0, 1.0], device=self.device)
|
91 |
+
self.estimator = nerfacc.OccGridEstimator(
|
92 |
+
roi_aabb=self.aabb, resolution=64, levels=1
|
93 |
+
)
|
94 |
|
95 |
+
self.encoder_density = GridEncoder(
|
96 |
+
num_levels=12
|
97 |
+
) # VMEncoder(output_dim=16, mode='sum')
|
98 |
self.encoder = GridEncoder(num_levels=12)
|
99 |
self.mlp_density = MLP(self.encoder_density.output_dim, 1, 32, 2, bias=False)
|
100 |
self.mlp = MLP(self.encoder.output_dim, 3, 32, 2, bias=False)
|
101 |
|
102 |
# mesh renderer
|
103 |
+
self.proj = (
|
104 |
+
torch.from_numpy(get_perspective(self.opt.fovy)).float().to(self.device)
|
105 |
+
)
|
106 |
self.v = self.f = None
|
107 |
self.vt = self.ft = None
|
108 |
self.deform = None
|
109 |
self.albedo = None
|
110 |
+
|
|
|
111 |
@torch.no_grad()
|
112 |
def render_gs(self, pose):
|
113 |
+
|
114 |
cam_poses = torch.from_numpy(pose).unsqueeze(0).to(self.device)
|
115 |
+
cam_poses[:, :3, 1:3] *= -1 # invert up & forward direction
|
116 |
+
|
117 |
# cameras needed by gaussian rasterizer
|
118 |
+
cam_view = torch.inverse(cam_poses).transpose(1, 2) # [V, 4, 4]
|
119 |
+
cam_view_proj = cam_view @ self.proj_matrix # [V, 4, 4]
|
120 |
+
cam_pos = -cam_poses[:, :3, 3] # [V, 3]
|
121 |
+
|
122 |
+
out = self.gs_renderer.render(
|
123 |
+
self.gaussians.unsqueeze(0),
|
124 |
+
cam_view.unsqueeze(0),
|
125 |
+
cam_view_proj.unsqueeze(0),
|
126 |
+
cam_pos.unsqueeze(0),
|
127 |
+
)
|
128 |
+
image = out["image"].squeeze(1).squeeze(0) # [C, H, W]
|
129 |
+
alpha = out["alpha"].squeeze(2).squeeze(1).squeeze(0) # [H, W]
|
130 |
|
131 |
return image, alpha
|
132 |
|
|
|
138 |
density = trunc_exp(self.mlp_density(feats))
|
139 |
density = density.view(*prefix, 1)
|
140 |
return density
|
141 |
+
|
142 |
def render_nerf(self, pose):
|
143 |
+
|
144 |
pose = torch.from_numpy(pose.astype(np.float32)).to(self.device)
|
145 |
+
|
146 |
# get rays
|
147 |
resolution = self.opt.output_size
|
148 |
rays_o, rays_d = get_rays(pose, resolution, resolution, self.opt.fovy)
|
149 |
+
|
150 |
# update occ grid
|
151 |
if self.training:
|
152 |
+
|
153 |
def occ_eval_fn(xs):
|
154 |
sigmas = self.get_density(xs)
|
155 |
return self.render_step_size * sigmas
|
156 |
+
|
157 |
+
self.estimator.update_every_n_steps(
|
158 |
+
self.step, occ_eval_fn=occ_eval_fn, occ_thre=0.01, n=8
|
159 |
+
)
|
160 |
self.step += 1
|
161 |
|
162 |
# render
|
|
|
185 |
sigmas = self.get_density(xs).squeeze(-1)
|
186 |
rgbs = torch.sigmoid(self.mlp(self.encoder(xs)))
|
187 |
|
188 |
+
n_rays = rays_o.shape[0]
|
189 |
+
weights, trans, alphas = nerfacc.render_weight_from_density(
|
190 |
+
t_starts, t_ends, sigmas, ray_indices=ray_indices, n_rays=n_rays
|
191 |
+
)
|
192 |
+
color = nerfacc.accumulate_along_rays(
|
193 |
+
weights, values=rgbs, ray_indices=ray_indices, n_rays=n_rays
|
194 |
+
)
|
195 |
+
alpha = nerfacc.accumulate_along_rays(
|
196 |
+
weights, values=None, ray_indices=ray_indices, n_rays=n_rays
|
197 |
+
)
|
198 |
|
199 |
color = color + 1 * (1.0 - alpha)
|
200 |
|
201 |
+
color = (
|
202 |
+
color.view(resolution, resolution, 3)
|
203 |
+
.clamp(0, 1)
|
204 |
+
.permute(2, 0, 1)
|
205 |
+
.contiguous()
|
206 |
+
)
|
207 |
alpha = alpha.view(resolution, resolution).clamp(0, 1).contiguous()
|
208 |
+
|
209 |
return color, alpha
|
210 |
|
211 |
def fit_nerf(self, iters=512, resolution=128):
|
212 |
|
213 |
self.opt.output_size = resolution
|
214 |
|
215 |
+
optimizer = torch.optim.Adam(
|
216 |
+
[
|
217 |
+
{"params": self.encoder_density.parameters(), "lr": 1e-2},
|
218 |
+
{"params": self.encoder.parameters(), "lr": 1e-2},
|
219 |
+
{"params": self.mlp_density.parameters(), "lr": 1e-3},
|
220 |
+
{"params": self.mlp.parameters(), "lr": 1e-3},
|
221 |
+
]
|
222 |
+
)
|
223 |
|
224 |
print(f"[INFO] fitting nerf...")
|
225 |
pbar = tqdm.trange(iters)
|
|
|
228 |
ver = np.random.randint(-45, 45)
|
229 |
hor = np.random.randint(-180, 180)
|
230 |
rad = np.random.uniform(1.5, 3.0)
|
231 |
+
|
232 |
pose = orbit_camera(ver, hor, rad)
|
233 |
+
|
234 |
image_gt, alpha_gt = self.render_gs(pose)
|
235 |
image_pred, alpha_pred = self.render_nerf(pose)
|
236 |
|
237 |
# if i % 200 == 0:
|
238 |
# kiui.vis.plot_image(image_gt, alpha_gt, image_pred, alpha_pred)
|
239 |
+
|
240 |
+
loss_mse = F.mse_loss(image_pred, image_gt) + 0.1 * F.mse_loss(
|
241 |
+
alpha_pred, alpha_gt
|
242 |
+
)
|
243 |
+
loss = loss_mse # + 0.1 * self.encoder_density.tv_loss() #+ 0.0001 * self.encoder_density.density_loss()
|
244 |
|
245 |
loss.backward()
|
246 |
self.encoder_density.grad_total_variation(1e-8)
|
247 |
+
|
248 |
optimizer.step()
|
249 |
optimizer.zero_grad()
|
250 |
|
251 |
pbar.set_description(f"MSE = {loss_mse.item():.6f}")
|
252 |
+
|
253 |
print(f"[INFO] finished fitting nerf!")
|
254 |
+
|
255 |
def render_mesh(self, pose):
|
256 |
|
257 |
h = w = self.opt.output_size
|
|
|
262 |
pose = torch.from_numpy(pose.astype(np.float32)).to(v.device)
|
263 |
|
264 |
# get v_clip and render rgb
|
265 |
+
v_cam = (
|
266 |
+
torch.matmul(
|
267 |
+
F.pad(v, pad=(0, 1), mode="constant", value=1.0), torch.inverse(pose).T
|
268 |
+
)
|
269 |
+
.float()
|
270 |
+
.unsqueeze(0)
|
271 |
+
)
|
272 |
v_clip = v_cam @ self.proj.T
|
273 |
|
274 |
rast, rast_db = dr.rasterize(self.glctx, v_clip, f, (h, w))
|
275 |
|
276 |
+
alpha = torch.clamp(rast[..., -1:], 0, 1).contiguous() # [1, H, W, 1]
|
277 |
+
alpha = (
|
278 |
+
dr.antialias(alpha, rast, v_clip, f).clamp(0, 1).squeeze(-1).squeeze(0)
|
279 |
+
) # [H, W] important to enable gradients!
|
280 |
+
|
281 |
if self.albedo is None:
|
282 |
+
xyzs, _ = dr.interpolate(v.unsqueeze(0), rast, f) # [1, H, W, 3]
|
283 |
xyzs = xyzs.view(-1, 3)
|
284 |
mask = (alpha > 0).view(-1)
|
285 |
image = torch.zeros_like(xyzs, dtype=torch.float32)
|
286 |
if mask.any():
|
287 |
+
masked_albedo = torch.sigmoid(
|
288 |
+
self.mlp(self.encoder(xyzs[mask].detach(), bound=1))
|
289 |
+
)
|
290 |
image[mask] = masked_albedo.float()
|
291 |
else:
|
292 |
+
texc, texc_db = dr.interpolate(
|
293 |
+
self.vt.unsqueeze(0), rast, self.ft, rast_db=rast_db, diff_attrs="all"
|
294 |
+
)
|
295 |
+
image = torch.sigmoid(
|
296 |
+
dr.texture(self.albedo.unsqueeze(0), texc, uv_da=texc_db)
|
297 |
+
) # [1, H, W, 3]
|
298 |
|
299 |
image = image.view(1, h, w, 3)
|
300 |
# image = dr.antialias(image, rast, v_clip, f).clamp(0, 1)
|
301 |
+
image = image.squeeze(0).permute(2, 0, 1).contiguous() # [3, H, W]
|
302 |
image = alpha * image + (1 - alpha)
|
303 |
|
304 |
return image, alpha
|
|
|
321 |
for xi, xs in enumerate(X):
|
322 |
for yi, ys in enumerate(Y):
|
323 |
for zi, zs in enumerate(Z):
|
324 |
+
xx, yy, zz = torch.meshgrid(xs, ys, zs, indexing="ij")
|
325 |
+
pts = torch.cat(
|
326 |
+
[xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)],
|
327 |
+
dim=-1,
|
328 |
+
) # [S, 3]
|
329 |
val = self.get_density(pts.to(self.device))
|
330 |
+
sigmas[
|
331 |
+
xi * S : xi * S + len(xs),
|
332 |
+
yi * S : yi * S + len(ys),
|
333 |
+
zi * S : zi * S + len(zs),
|
334 |
+
] = (
|
335 |
+
val.reshape(len(xs), len(ys), len(zs)).detach().cpu().numpy()
|
336 |
+
) # [S, 1] --> [x, y, z]
|
337 |
+
|
338 |
+
print(
|
339 |
+
f"[INFO] marching cubes thresh: {density_thresh} ({sigmas.min()} ~ {sigmas.max()})"
|
340 |
+
)
|
341 |
|
342 |
vertices, triangles = mcubes.marching_cubes(sigmas, density_thresh)
|
343 |
vertices = vertices / (grid_size - 1.0) * 2 - 1
|
344 |
+
|
345 |
# clean
|
346 |
vertices = vertices.astype(np.float32)
|
347 |
triangles = triangles.astype(np.int32)
|
348 |
+
vertices, triangles = clean_mesh(
|
349 |
+
vertices, triangles, remesh=True, remesh_size=0.01
|
350 |
+
)
|
351 |
if triangles.shape[0] > decimate_target:
|
352 |
+
vertices, triangles = decimate_mesh(
|
353 |
+
vertices, triangles, decimate_target, optimalplacement=False
|
354 |
+
)
|
355 |
+
|
356 |
self.v = torch.from_numpy(vertices).contiguous().float().to(self.device)
|
357 |
self.f = torch.from_numpy(triangles).contiguous().int().to(self.device)
|
358 |
self.deform = nn.Parameter(torch.zeros_like(self.v)).to(self.device)
|
359 |
|
360 |
# fit mesh from gs
|
361 |
lr_factor = 1
|
362 |
+
optimizer = torch.optim.Adam(
|
363 |
+
[
|
364 |
+
{"params": self.encoder.parameters(), "lr": 1e-3 * lr_factor},
|
365 |
+
{"params": self.mlp.parameters(), "lr": 1e-3 * lr_factor},
|
366 |
+
{"params": self.deform, "lr": 1e-4},
|
367 |
+
]
|
368 |
+
)
|
369 |
|
370 |
print(f"[INFO] fitting mesh...")
|
371 |
pbar = tqdm.trange(iters)
|
|
|
373 |
|
374 |
ver = np.random.randint(-10, 10)
|
375 |
hor = np.random.randint(-180, 180)
|
376 |
+
rad = self.opt.cam_radius # np.random.uniform(1, 2)
|
377 |
|
378 |
pose = orbit_camera(ver, hor, rad)
|
379 |
+
|
380 |
image_gt, alpha_gt = self.render_gs(pose)
|
381 |
image_pred, alpha_pred = self.render_mesh(pose)
|
382 |
|
383 |
+
loss_mse = F.mse_loss(image_pred, image_gt) + 0.1 * F.mse_loss(
|
384 |
+
alpha_pred, alpha_gt
|
385 |
+
)
|
386 |
# loss_lap = laplacian_smooth_loss(self.v + self.deform, self.f)
|
387 |
loss_normal = normal_consistency(self.v + self.deform, self.f)
|
388 |
+
loss_offsets = (self.deform**2).sum(-1).mean()
|
389 |
loss = loss_mse + 0.001 * loss_normal + 0.1 * loss_offsets
|
390 |
|
391 |
loss.backward()
|
|
|
397 |
if i > 0 and i % 512 == 0:
|
398 |
vertices = (self.v + self.deform).detach().cpu().numpy()
|
399 |
triangles = self.f.detach().cpu().numpy()
|
400 |
+
vertices, triangles = clean_mesh(
|
401 |
+
vertices, triangles, remesh=True, remesh_size=0.01
|
402 |
+
)
|
403 |
if triangles.shape[0] > decimate_target:
|
404 |
+
vertices, triangles = decimate_mesh(
|
405 |
+
vertices, triangles, decimate_target, optimalplacement=False
|
406 |
+
)
|
407 |
self.v = torch.from_numpy(vertices).contiguous().float().to(self.device)
|
408 |
self.f = torch.from_numpy(triangles).contiguous().int().to(self.device)
|
409 |
self.deform = nn.Parameter(torch.zeros_like(self.v)).to(self.device)
|
410 |
lr_factor *= 0.5
|
411 |
+
optimizer = torch.optim.Adam(
|
412 |
+
[
|
413 |
+
{"params": self.encoder.parameters(), "lr": 1e-3 * lr_factor},
|
414 |
+
{"params": self.mlp.parameters(), "lr": 1e-3 * lr_factor},
|
415 |
+
{"params": self.deform, "lr": 1e-4},
|
416 |
+
]
|
417 |
+
)
|
418 |
|
419 |
pbar.set_description(f"MSE = {loss_mse.item():.6f}")
|
420 |
+
|
421 |
# last clean
|
422 |
vertices = (self.v + self.deform).detach().cpu().numpy()
|
423 |
triangles = self.f.detach().cpu().numpy()
|
|
|
425 |
self.v = torch.from_numpy(vertices).contiguous().float().to(self.device)
|
426 |
self.f = torch.from_numpy(triangles).contiguous().int().to(self.device)
|
427 |
self.deform = nn.Parameter(torch.zeros_like(self.v).to(self.device))
|
428 |
+
|
429 |
print(f"[INFO] finished fitting mesh!")
|
430 |
+
|
431 |
# uv mesh refine
|
432 |
+
def fit_mesh_uv(
|
433 |
+
self, iters=512, resolution=512, texture_resolution=1024, padding=2
|
434 |
+
):
|
435 |
|
436 |
self.opt.output_size = resolution
|
437 |
|
|
|
446 |
|
447 |
# render uv maps
|
448 |
h = w = texture_resolution
|
449 |
+
uv = mesh.vt * 2.0 - 1.0 # uvs to range [-1, 1]
|
450 |
+
uv = torch.cat(
|
451 |
+
(uv, torch.zeros_like(uv[..., :1]), torch.ones_like(uv[..., :1])), dim=-1
|
452 |
+
) # [N, 4]
|
453 |
+
|
454 |
+
rast, _ = dr.rasterize(
|
455 |
+
self.glctx, uv.unsqueeze(0), mesh.ft, (h, w)
|
456 |
+
) # [1, h, w, 4]
|
457 |
+
xyzs, _ = dr.interpolate(mesh.v.unsqueeze(0), rast, mesh.f) # [1, h, w, 3]
|
458 |
+
mask, _ = dr.interpolate(
|
459 |
+
torch.ones_like(mesh.v[:, :1]).unsqueeze(0), rast, mesh.f
|
460 |
+
) # [1, h, w, 1]
|
461 |
+
|
462 |
+
# masked query
|
463 |
xyzs = xyzs.view(-1, 3)
|
464 |
mask = (mask > 0).view(-1)
|
465 |
+
|
466 |
albedo = torch.zeros(h * w, 3, device=self.device, dtype=torch.float32)
|
467 |
|
468 |
if mask.any():
|
469 |
print(f"[INFO] querying texture...")
|
470 |
|
471 |
+
xyzs = xyzs[mask] # [M, 3]
|
472 |
|
473 |
# batched inference to avoid OOM
|
474 |
batch = []
|
475 |
head = 0
|
476 |
while head < xyzs.shape[0]:
|
477 |
tail = min(head + 640000, xyzs.shape[0])
|
478 |
+
batch.append(
|
479 |
+
torch.sigmoid(self.mlp(self.encoder(xyzs[head:tail]))).float()
|
480 |
+
)
|
481 |
head += 640000
|
482 |
|
483 |
albedo[mask] = torch.cat(batch, dim=0)
|
484 |
+
|
485 |
albedo = albedo.view(h, w, -1)
|
486 |
mask = mask.view(h, w)
|
487 |
albedo = uv_padding(albedo, mask, padding)
|
488 |
|
489 |
# optimize texture
|
490 |
self.albedo = nn.Parameter(inverse_sigmoid(albedo)).to(self.device)
|
491 |
+
|
492 |
+
optimizer = torch.optim.Adam(
|
493 |
+
[
|
494 |
+
{"params": self.albedo, "lr": 1e-3},
|
495 |
+
]
|
496 |
+
)
|
497 |
|
498 |
print(f"[INFO] fitting mesh texture...")
|
499 |
pbar = tqdm.trange(iters)
|
|
|
502 |
# shrink to front view as we care more about it...
|
503 |
ver = np.random.randint(-5, 5)
|
504 |
hor = np.random.randint(-15, 15)
|
505 |
+
rad = self.opt.cam_radius # np.random.uniform(1, 2)
|
506 |
+
|
507 |
pose = orbit_camera(ver, hor, rad)
|
508 |
+
|
509 |
image_gt, alpha_gt = self.render_gs(pose)
|
510 |
image_pred, alpha_pred = self.render_mesh(pose)
|
511 |
|
|
|
518 |
optimizer.zero_grad()
|
519 |
|
520 |
pbar.set_description(f"MSE = {loss_mse.item():.6f}")
|
|
|
|
|
521 |
|
522 |
+
print(f"[INFO] finished fitting mesh texture!")
|
523 |
|
524 |
@torch.no_grad()
|
525 |
def export_mesh(self, path):
|
526 |
+
|
527 |
+
mesh = Mesh(
|
528 |
+
v=self.v,
|
529 |
+
f=self.f,
|
530 |
+
vt=self.vt,
|
531 |
+
ft=self.ft,
|
532 |
+
albedo=torch.sigmoid(self.albedo),
|
533 |
+
device=self.device,
|
534 |
+
)
|
535 |
mesh.auto_normal()
|
536 |
mesh.write(path)
|
537 |
|
|
|
539 |
opt = tyro.cli(AllConfigs)
|
540 |
|
541 |
# load a saved ply and convert to mesh
|
542 |
+
assert opt.test_path.endswith(
|
543 |
+
".ply"
|
544 |
+
), "--test_path must be a .ply file saved by infer.py"
|
545 |
|
546 |
converter = Converter(opt).cuda()
|
547 |
converter.fit_nerf()
|
548 |
converter.fit_mesh()
|
549 |
converter.fit_mesh_uv()
|
550 |
+
converter.export_mesh(opt.test_path.replace(".ply", ".glb"))
|