Commit f379155 (1 parent: 206b602)

add glm conversion code
Files changed (5):
  1. Dockerfile +4 -0
  2. app.py +15 -2
  3. convert.py +462 -0
  4. packages.txt +10 -0
  5. requirements.txt +5 -0
Dockerfile CHANGED
@@ -31,6 +31,10 @@ ENV HOME=/home/user \
     GRADIO_THEME=huggingface \
     SYSTEM=spaces
 
+RUN apt-get update && \
+    xargs -r -a /code/packages.txt apt-get install -y && \
+    rm -rf /var/lib/apt/lists/*
+
 RUN pip3 install --no-cache-dir --upgrade -r /code/requirements.txt
 
 # Set the working directory to the user's home directory
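
Note: the new layer installs the OpenGL/EGL system libraries listed in packages.txt, which nvdiffrast's RasterizeGLContext (used by the new convert.py) links against at runtime. `xargs -r` makes the install a no-op if the file is ever empty, and clearing /var/lib/apt/lists/* keeps the image small.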
app.py CHANGED
@@ -23,8 +23,9 @@ import kiui
 from kiui.op import recenter
 from kiui.cam import orbit_camera
 
-from core.options import AllConfigs, Options
+from core.options import AllConfigs, Options, config_defaults
 from core.models import LGM
+from convert import Converter
 from mvdream.pipeline_mvdream import MVDreamPipeline
 
 import spaces
@@ -33,6 +34,7 @@ IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
 IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
 GRADIO_VIDEO_PATH = 'gradio_output.mp4'
 GRADIO_PLY_PATH = 'gradio_output.ply'
+GRADIO_GLB_PATH = 'gradio_output.glb'
 
 # opt = tyro.cli(AllConfigs)
 opt = Options(
@@ -104,6 +106,7 @@ def process(input_image, prompt, prompt_neg='', input_elevation=0, input_num_ste
     os.makedirs(opt.workspace, exist_ok=True)
     output_video_path = os.path.join(opt.workspace, GRADIO_VIDEO_PATH)
     output_ply_path = os.path.join(opt.workspace, GRADIO_PLY_PATH)
+    output_glb_path = os.path.join(opt.workspace, GRADIO_GLB_PATH)
 
     # text-conditioned
     if input_image is None:
@@ -190,7 +193,17 @@ def process(input_image, prompt, prompt_neg='', input_elevation=0, input_num_ste
     images = np.concatenate(images, axis=0)
     imageio.mimwrite(output_video_path, images, fps=30)
 
-    return mv_image_grid, output_video_path, output_ply_path
+
+    # load a saved ply and convert to mesh
+    opt.test_path = output_ply_path
+
+    converter = Converter(opt).cuda()
+    converter.fit_nerf()
+    converter.fit_mesh()
+    converter.fit_mesh_uv()
+    converter.export_mesh(opt.test_path.replace('.ply', '.glb'))
+
+    return mv_image_grid, output_video_path, output_glb_path
 
 # gradio UI
 
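With this change, process() returns the path of a textured GLB instead of the raw gaussian ply. A minimal sketch of the new contract, assuming a hypothetical direct call that bypasses the Gradio UI (requires app.py's models loaded and a CUDA device):

    # hypothetical driver: input_image is an RGBA numpy array, as the UI would supply
    mv_image_grid, video_path, glb_path = process(input_image, prompt='a cute otter')
    assert glb_path.endswith('.glb')  # was gradio_output.ply before this commit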
convert.py ADDED
@@ -0,0 +1,462 @@
import os
import tyro
import tqdm
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from core.options import AllConfigs, Options
from core.gs import GaussianRenderer

import mcubes
import nerfacc
import nvdiffrast.torch as dr

import kiui
from kiui.mesh import Mesh
from kiui.mesh_utils import clean_mesh, decimate_mesh
from kiui.mesh_utils import laplacian_smooth_loss, normal_consistency
from kiui.op import uv_padding, safe_normalize, inverse_sigmoid
from kiui.cam import orbit_camera, get_perspective
from kiui.nn import MLP, trunc_exp
from kiui.gridencoder import GridEncoder

def get_rays(pose, h, w, fovy, opengl=True):
    # pose: [4, 4] cam2world matrix; returns rays_o, rays_d: [h*w, 3]

    x, y = torch.meshgrid(
        torch.arange(w, device=pose.device),
        torch.arange(h, device=pose.device),
        indexing="xy",
    )
    x = x.flatten()
    y = y.flatten()

    cx = w * 0.5
    cy = h * 0.5
    focal = h * 0.5 / np.tan(0.5 * np.deg2rad(fovy))

    camera_dirs = F.pad(
        torch.stack(
            [
                (x - cx + 0.5) / focal,
                (y - cy + 0.5) / focal * (-1.0 if opengl else 1.0),
            ],
            dim=-1,
        ),
        (0, 1),
        value=(-1.0 if opengl else 1.0),
    )  # [hw, 3]

    rays_d = camera_dirs @ pose[:3, :3].transpose(0, 1)  # [hw, 3]
    rays_o = pose[:3, 3].unsqueeze(0).expand_as(rays_d)  # [hw, 3]

    rays_d = safe_normalize(rays_d)

    return rays_o, rays_d

# Triple renderer of gaussians, nerf, and mesh.
# gaussian --> nerf --> mesh
class Converter(nn.Module):
    def __init__(self, opt: Options):
        super().__init__()

        self.opt = opt
        self.device = torch.device("cuda")

        # gs renderer
        self.tan_half_fov = np.tan(0.5 * np.deg2rad(opt.fovy))
        self.proj_matrix = torch.zeros(4, 4, dtype=torch.float32, device=self.device)
        self.proj_matrix[0, 0] = 1 / self.tan_half_fov
        self.proj_matrix[1, 1] = 1 / self.tan_half_fov
        self.proj_matrix[2, 2] = (opt.zfar + opt.znear) / (opt.zfar - opt.znear)
        self.proj_matrix[3, 2] = - (opt.zfar * opt.znear) / (opt.zfar - opt.znear)
        self.proj_matrix[2, 3] = 1

        self.gs_renderer = GaussianRenderer(opt)

        self.gaussians = self.gs_renderer.load_ply(opt.test_path).to(self.device)

        # nerf renderer
        if not self.opt.force_cuda_rast:
            self.glctx = dr.RasterizeGLContext()
        else:
            self.glctx = dr.RasterizeCudaContext()

        self.step = 0
        self.render_step_size = 5e-3
        self.aabb = torch.tensor([-1.0, -1.0, -1.0, 1.0, 1.0, 1.0], device=self.device)
        self.estimator = nerfacc.OccGridEstimator(roi_aabb=self.aabb, resolution=64, levels=1)

        self.encoder_density = GridEncoder(num_levels=12)  # VMEncoder(output_dim=16, mode='sum')
        self.encoder = GridEncoder(num_levels=12)
        self.mlp_density = MLP(self.encoder_density.output_dim, 1, 32, 2, bias=False)
        self.mlp = MLP(self.encoder.output_dim, 3, 32, 2, bias=False)

        # mesh renderer
        self.proj = torch.from_numpy(get_perspective(self.opt.fovy)).float().to(self.device)
        self.v = self.f = None
        self.vt = self.ft = None
        self.deform = None
        self.albedo = None

    @torch.no_grad()
    def render_gs(self, pose):

        cam_poses = torch.from_numpy(pose).unsqueeze(0).to(self.device)
        cam_poses[:, :3, 1:3] *= -1  # invert up & forward direction

        # cameras needed by gaussian rasterizer
        cam_view = torch.inverse(cam_poses).transpose(1, 2)  # [V, 4, 4]
        cam_view_proj = cam_view @ self.proj_matrix  # [V, 4, 4]
        cam_pos = - cam_poses[:, :3, 3]  # [V, 3]

        out = self.gs_renderer.render(self.gaussians.unsqueeze(0), cam_view.unsqueeze(0), cam_view_proj.unsqueeze(0), cam_pos.unsqueeze(0))
        image = out['image'].squeeze(1).squeeze(0)  # [C, H, W]
        alpha = out['alpha'].squeeze(2).squeeze(1).squeeze(0)  # [H, W]

        return image, alpha

    def get_density(self, xs):
        # xs: [..., 3]
        prefix = xs.shape[:-1]
        xs = xs.view(-1, 3)
        feats = self.encoder_density(xs)
        density = trunc_exp(self.mlp_density(feats))
        density = density.view(*prefix, 1)
        return density

    def render_nerf(self, pose):

        pose = torch.from_numpy(pose.astype(np.float32)).to(self.device)

        # get rays
        resolution = self.opt.output_size
        rays_o, rays_d = get_rays(pose, resolution, resolution, self.opt.fovy)

        # update occ grid
        if self.training:
            def occ_eval_fn(xs):
                sigmas = self.get_density(xs)
                return self.render_step_size * sigmas

            self.estimator.update_every_n_steps(self.step, occ_eval_fn=occ_eval_fn, occ_thre=0.01, n=8)
            self.step += 1

        # render
        def sigma_fn(t_starts, t_ends, ray_indices):
            t_origins = rays_o[ray_indices]
            t_dirs = rays_d[ray_indices]
            xs = t_origins + t_dirs * (t_starts + t_ends)[:, None] / 2.0
            sigmas = self.get_density(xs)
            return sigmas.squeeze(-1)

        with torch.no_grad():
            ray_indices, t_starts, t_ends = self.estimator.sampling(
                rays_o,
                rays_d,
                sigma_fn=sigma_fn,
                near_plane=0.01,
                far_plane=100,
                render_step_size=self.render_step_size,
                stratified=self.training,
                cone_angle=0,
            )

        t_origins = rays_o[ray_indices]
        t_dirs = rays_d[ray_indices]
        xs = t_origins + t_dirs * (t_starts + t_ends)[:, None] / 2.0
        sigmas = self.get_density(xs).squeeze(-1)
        rgbs = torch.sigmoid(self.mlp(self.encoder(xs)))

        n_rays = rays_o.shape[0]
        weights, trans, alphas = nerfacc.render_weight_from_density(t_starts, t_ends, sigmas, ray_indices=ray_indices, n_rays=n_rays)
        color = nerfacc.accumulate_along_rays(weights, values=rgbs, ray_indices=ray_indices, n_rays=n_rays)
        alpha = nerfacc.accumulate_along_rays(weights, values=None, ray_indices=ray_indices, n_rays=n_rays)

        color = color + 1 * (1.0 - alpha)  # composite onto white background

        color = color.view(resolution, resolution, 3).clamp(0, 1).permute(2, 0, 1).contiguous()
        alpha = alpha.view(resolution, resolution).clamp(0, 1).contiguous()

        return color, alpha

    def fit_nerf(self, iters=512, resolution=128):

        self.opt.output_size = resolution

        optimizer = torch.optim.Adam([
            {'params': self.encoder_density.parameters(), 'lr': 1e-2},
            {'params': self.encoder.parameters(), 'lr': 1e-2},
            {'params': self.mlp_density.parameters(), 'lr': 1e-3},
            {'params': self.mlp.parameters(), 'lr': 1e-3},
        ])

        print(f"[INFO] fitting nerf...")
        pbar = tqdm.trange(iters)
        for i in pbar:

            ver = np.random.randint(-45, 45)
            hor = np.random.randint(-180, 180)
            rad = np.random.uniform(1.5, 3.0)

            pose = orbit_camera(ver, hor, rad)

            image_gt, alpha_gt = self.render_gs(pose)
            image_pred, alpha_pred = self.render_nerf(pose)

            # if i % 200 == 0:
            #     kiui.vis.plot_image(image_gt, alpha_gt, image_pred, alpha_pred)

            loss_mse = F.mse_loss(image_pred, image_gt) + 0.1 * F.mse_loss(alpha_pred, alpha_gt)
            loss = loss_mse  # + 0.1 * self.encoder_density.tv_loss() # + 0.0001 * self.encoder_density.density_loss()

            loss.backward()
            self.encoder_density.grad_total_variation(1e-8)

            optimizer.step()
            optimizer.zero_grad()

            pbar.set_description(f"MSE = {loss_mse.item():.6f}")

        print(f"[INFO] finished fitting nerf!")

    def render_mesh(self, pose):

        h = w = self.opt.output_size

        v = self.v + self.deform
        f = self.f

        pose = torch.from_numpy(pose.astype(np.float32)).to(v.device)

        # get v_clip and render rgb
        v_cam = torch.matmul(F.pad(v, pad=(0, 1), mode='constant', value=1.0), torch.inverse(pose).T).float().unsqueeze(0)
        v_clip = v_cam @ self.proj.T

        rast, rast_db = dr.rasterize(self.glctx, v_clip, f, (h, w))

        alpha = torch.clamp(rast[..., -1:], 0, 1).contiguous()  # [1, H, W, 1]
        alpha = dr.antialias(alpha, rast, v_clip, f).clamp(0, 1).squeeze(-1).squeeze(0)  # [H, W] important to enable gradients!

        if self.albedo is None:
            xyzs, _ = dr.interpolate(v.unsqueeze(0), rast, f)  # [1, H, W, 3]
            xyzs = xyzs.view(-1, 3)
            mask = (alpha > 0).view(-1)
            image = torch.zeros_like(xyzs, dtype=torch.float32)
            if mask.any():
                masked_albedo = torch.sigmoid(self.mlp(self.encoder(xyzs[mask].detach(), bound=1)))
                image[mask] = masked_albedo.float()
        else:
            texc, texc_db = dr.interpolate(self.vt.unsqueeze(0), rast, self.ft, rast_db=rast_db, diff_attrs='all')
            image = torch.sigmoid(dr.texture(self.albedo.unsqueeze(0), texc, uv_da=texc_db))  # [1, H, W, 3]

        image = image.view(1, h, w, 3)
        # image = dr.antialias(image, rast, v_clip, f).clamp(0, 1)
        image = image.squeeze(0).permute(2, 0, 1).contiguous()  # [3, H, W]
        image = alpha * image + (1 - alpha)

        return image, alpha

    def fit_mesh(self, iters=2048, resolution=512, decimate_target=5e4):

        self.opt.output_size = resolution

        # init mesh from nerf
        grid_size = 256
        sigmas = np.zeros([grid_size, grid_size, grid_size], dtype=np.float32)

        S = 128
        density_thresh = 10

        X = torch.linspace(-1, 1, grid_size).split(S)
        Y = torch.linspace(-1, 1, grid_size).split(S)
        Z = torch.linspace(-1, 1, grid_size).split(S)

        for xi, xs in enumerate(X):
            for yi, ys in enumerate(Y):
                for zi, zs in enumerate(Z):
                    xx, yy, zz = torch.meshgrid(xs, ys, zs, indexing='ij')
                    pts = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1)  # [S, 3]
                    val = self.get_density(pts.to(self.device))
                    sigmas[xi * S: xi * S + len(xs), yi * S: yi * S + len(ys), zi * S: zi * S + len(zs)] = val.reshape(len(xs), len(ys), len(zs)).detach().cpu().numpy()  # [S, 1] --> [x, y, z]

        print(f'[INFO] marching cubes thresh: {density_thresh} ({sigmas.min()} ~ {sigmas.max()})')

        vertices, triangles = mcubes.marching_cubes(sigmas, density_thresh)
        vertices = vertices / (grid_size - 1.0) * 2 - 1

        # clean
        vertices = vertices.astype(np.float32)
        triangles = triangles.astype(np.int32)
        vertices, triangles = clean_mesh(vertices, triangles, remesh=True, remesh_size=0.01)
        if triangles.shape[0] > decimate_target:
            vertices, triangles = decimate_mesh(vertices, triangles, decimate_target, optimalplacement=False)

        self.v = torch.from_numpy(vertices).contiguous().float().to(self.device)
        self.f = torch.from_numpy(triangles).contiguous().int().to(self.device)
        self.deform = nn.Parameter(torch.zeros_like(self.v)).to(self.device)

        # fit mesh from gs
        lr_factor = 1
        optimizer = torch.optim.Adam([
            {'params': self.encoder.parameters(), 'lr': 1e-3 * lr_factor},
            {'params': self.mlp.parameters(), 'lr': 1e-3 * lr_factor},
            {'params': self.deform, 'lr': 1e-4},
        ])

        print(f"[INFO] fitting mesh...")
        pbar = tqdm.trange(iters)
        for i in pbar:

            ver = np.random.randint(-10, 10)
            hor = np.random.randint(-180, 180)
            rad = self.opt.cam_radius  # np.random.uniform(1, 2)

            pose = orbit_camera(ver, hor, rad)

            image_gt, alpha_gt = self.render_gs(pose)
            image_pred, alpha_pred = self.render_mesh(pose)

            loss_mse = F.mse_loss(image_pred, image_gt) + 0.1 * F.mse_loss(alpha_pred, alpha_gt)
            # loss_lap = laplacian_smooth_loss(self.v + self.deform, self.f)
            loss_normal = normal_consistency(self.v + self.deform, self.f)
            loss_offsets = (self.deform ** 2).sum(-1).mean()
            loss = loss_mse + 0.001 * loss_normal + 0.1 * loss_offsets

            loss.backward()

            optimizer.step()
            optimizer.zero_grad()

            # remesh periodically
            if i > 0 and i % 512 == 0:
                vertices = (self.v + self.deform).detach().cpu().numpy()
                triangles = self.f.detach().cpu().numpy()
                vertices, triangles = clean_mesh(vertices, triangles, remesh=True, remesh_size=0.01)
                if triangles.shape[0] > decimate_target:
                    vertices, triangles = decimate_mesh(vertices, triangles, decimate_target, optimalplacement=False)
                self.v = torch.from_numpy(vertices).contiguous().float().to(self.device)
                self.f = torch.from_numpy(triangles).contiguous().int().to(self.device)
                self.deform = nn.Parameter(torch.zeros_like(self.v)).to(self.device)
                lr_factor *= 0.5
                optimizer = torch.optim.Adam([
                    {'params': self.encoder.parameters(), 'lr': 1e-3 * lr_factor},
                    {'params': self.mlp.parameters(), 'lr': 1e-3 * lr_factor},
                    {'params': self.deform, 'lr': 1e-4},
                ])

            pbar.set_description(f"MSE = {loss_mse.item():.6f}")

        # last clean
        vertices = (self.v + self.deform).detach().cpu().numpy()
        triangles = self.f.detach().cpu().numpy()
        vertices, triangles = clean_mesh(vertices, triangles, remesh=False)
        self.v = torch.from_numpy(vertices).contiguous().float().to(self.device)
        self.f = torch.from_numpy(triangles).contiguous().int().to(self.device)
        self.deform = nn.Parameter(torch.zeros_like(self.v).to(self.device))

        print(f"[INFO] finished fitting mesh!")

    # uv mesh refine
    def fit_mesh_uv(self, iters=512, resolution=512, texture_resolution=1024, padding=2):

        self.opt.output_size = resolution

        # unwrap uv
        print(f"[INFO] uv unwrapping...")
        mesh = Mesh(v=self.v, f=self.f, albedo=None, device=self.device)
        mesh.auto_normal()
        mesh.auto_uv()

        self.vt = mesh.vt
        self.ft = mesh.ft

        # render uv maps
        h = w = texture_resolution
        uv = mesh.vt * 2.0 - 1.0  # uvs to range [-1, 1]
        uv = torch.cat((uv, torch.zeros_like(uv[..., :1]), torch.ones_like(uv[..., :1])), dim=-1)  # [N, 4]

        rast, _ = dr.rasterize(self.glctx, uv.unsqueeze(0), mesh.ft, (h, w))  # [1, h, w, 4]
        xyzs, _ = dr.interpolate(mesh.v.unsqueeze(0), rast, mesh.f)  # [1, h, w, 3]
        mask, _ = dr.interpolate(torch.ones_like(mesh.v[:, :1]).unsqueeze(0), rast, mesh.f)  # [1, h, w, 1]

        # masked query
        xyzs = xyzs.view(-1, 3)
        mask = (mask > 0).view(-1)

        albedo = torch.zeros(h * w, 3, device=self.device, dtype=torch.float32)

        if mask.any():
            print(f"[INFO] querying texture...")

            xyzs = xyzs[mask]  # [M, 3]

            # batched inference to avoid OOM
            batch = []
            head = 0
            while head < xyzs.shape[0]:
                tail = min(head + 640000, xyzs.shape[0])
                batch.append(torch.sigmoid(self.mlp(self.encoder(xyzs[head:tail]))).float())
                head += 640000

            albedo[mask] = torch.cat(batch, dim=0)

        albedo = albedo.view(h, w, -1)
        mask = mask.view(h, w)
        albedo = uv_padding(albedo, mask, padding)

        # optimize texture
        self.albedo = nn.Parameter(inverse_sigmoid(albedo)).to(self.device)

        optimizer = torch.optim.Adam([
            {'params': self.albedo, 'lr': 1e-3},
        ])

        print(f"[INFO] fitting mesh texture...")
        pbar = tqdm.trange(iters)
        for i in pbar:

            # shrink to front view as we care more about it...
            ver = np.random.randint(-5, 5)
            hor = np.random.randint(-15, 15)
            rad = self.opt.cam_radius  # np.random.uniform(1, 2)

            pose = orbit_camera(ver, hor, rad)

            image_gt, alpha_gt = self.render_gs(pose)
            image_pred, alpha_pred = self.render_mesh(pose)

            loss_mse = F.mse_loss(image_pred, image_gt)
            loss = loss_mse

            loss.backward()

            optimizer.step()
            optimizer.zero_grad()

            pbar.set_description(f"MSE = {loss_mse.item():.6f}")

        print(f"[INFO] finished fitting mesh texture!")

    @torch.no_grad()
    def export_mesh(self, path):

        mesh = Mesh(v=self.v, f=self.f, vt=self.vt, ft=self.ft, albedo=torch.sigmoid(self.albedo), device=self.device)
        mesh.auto_normal()
        mesh.write(path)


# opt = tyro.cli(AllConfigs)

# # load a saved ply and convert to mesh
# assert opt.test_path.endswith('.ply'), '--test_path must be a .ply file saved by infer.py'

# converter = Converter(opt).cuda()
# converter.fit_nerf()
# converter.fit_mesh()
# converter.fit_mesh_uv()
# converter.export_mesh(opt.test_path.replace('.ply', '.glb'))
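
End to end, Converter distills the loaded gaussians in three stages: fit_nerf regresses a grid-encoded density/color field against gaussian renders, fit_mesh extracts geometry with marching cubes and refines the vertices differentiably through nvdiffrast, and fit_mesh_uv unwraps UVs and bakes the color field into a texture. A hedged standalone driver, mirroring the commented-out entry point above (assumes an Options config from core.options with test_path set):

    # sketch: run the converter on a ply saved by infer.py
    opt = tyro.cli(AllConfigs)
    assert opt.test_path.endswith('.ply'), '--test_path must be a .ply file saved by infer.py'

    converter = Converter(opt).cuda()
    converter.fit_nerf()     # stage 1: gaussians --> nerf
    converter.fit_mesh()     # stage 2: nerf --> refined mesh
    converter.fit_mesh_uv()  # stage 3: uv unwrap + texture bake
    converter.export_mesh(opt.test_path.replace('.ply', '.glb'))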
packages.txt ADDED
@@ -0,0 +1,10 @@
libglvnd0
libgl1
libglx0
libegl1
libgles2
libglvnd-dev
libgl1-mesa-dev
libegl1-mesa-dev
libgles2-mesa-dev
libegl-mesa0
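
These are the GL/EGL runtime and development packages that let nvdiffrast create a RasterizeGLContext inside the headless Space container; when opt.force_cuda_rast is set, convert.py falls back to RasterizeCudaContext instead.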
requirements.txt CHANGED
@@ -28,3 +28,8 @@ kiui >= 0.2.3
 xatlas
 roma
 plyfile
+PyMCubes
+nerfacc
+pymeshlab
+ninja
+git+https://github.com/NVlabs/nvdiffrast.git
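
PyMCubes provides the mcubes module used for marching cubes, and pymeshlab backs kiui's clean_mesh/decimate_mesh helpers. nvdiffrast is not distributed on PyPI, hence the git URL; ninja is listed because nvdiffrast JIT-compiles its torch extension the first time a rasterization context is created. A quick post-install smoke test (hypothetical):

    python3 -c "import nvdiffrast.torch as dr; print('nvdiffrast import ok')"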