dylanebert HF staff commited on
Commit
93f5bda
1 Parent(s): 02afac0

initial commit

Browse files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ data_test/catstatue.ply filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ venv/
Dockerfile ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from nvidia/cuda:12.1.1-devel-ubuntu22.04
2
+
3
+ # Set the environment variable
4
+ ENV DEBIAN_FRONTEND=noninteractive
5
+
6
+ # Install the required packages
7
+ RUN apt-get update && apt-get install -y \
8
+ software-properties-common
9
+
10
+ # Add the deadsnakes PPA
11
+ RUN add-apt-repository ppa:deadsnakes/ppa
12
+
13
+ # Install Python 3.10
14
+ RUN apt-get update && apt-get install -y \
15
+ python3.10 \
16
+ python3.10-dev \
17
+ python3.10-distutils \
18
+ python3.10-venv \
19
+ python3-pip
20
+
21
+ # Install other dependencies
22
+ RUN apt-get install -y \
23
+ git \
24
+ gcc \
25
+ g++ \
26
+ libgl1 \
27
+ libglib2.0.0 \
28
+ ffmpeg \
29
+ cmake \
30
+ libgtk2.0.0
31
+
32
+ # Working directory
33
+ WORKDIR /app
34
+
35
+ COPY requirements.txt .
36
+
37
+ # Install the required Python packages
38
+ RUN python3.10 -m pip install -r requirements.txt --extra-index-url https://download.pytorch.org/whl/cu121
39
+
40
+ # Copy all files to the working directory
41
+ COPY . .
42
+
43
+ EXPOSE 7860
44
+
45
+ # Run the gradio app
46
+ CMD ["python3.10", "app.py"]
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2024 3D Topia
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
app.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import subprocess
3
+
4
+
5
+ def run(input_ply):
6
+ subprocess.run(
7
+ "python3.10 convert.py big --force-cuda-rast --test_path " + input_ply,
8
+ shell=True,
9
+ )
10
+ return input_ply.replace(".ply", ".glb")
11
+
12
+
13
+ def main():
14
+ demo = gr.Interface(
15
+ fn=run,
16
+ inputs=gr.Model3D(label="Input Splat"),
17
+ outputs=gr.Model3D(label="Output GLB"),
18
+ examples=
19
+ )
20
+
21
+ demo.launch(server_name="0.0.0.0", server_port=7860)
22
+
23
+
24
+ if __name__ == "__main__":
25
+ main()
convert.py ADDED
@@ -0,0 +1,462 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import os
3
+ import tyro
4
+ import tqdm
5
+ import numpy as np
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+
10
+ from core.options import AllConfigs, Options
11
+ from core.gs import GaussianRenderer
12
+
13
+ import mcubes
14
+ import nerfacc
15
+ import nvdiffrast.torch as dr
16
+
17
+ import kiui
18
+ from kiui.mesh import Mesh
19
+ from kiui.mesh_utils import clean_mesh, decimate_mesh
20
+ from kiui.mesh_utils import laplacian_smooth_loss, normal_consistency
21
+ from kiui.op import uv_padding, safe_normalize, inverse_sigmoid
22
+ from kiui.cam import orbit_camera, get_perspective
23
+ from kiui.nn import MLP, trunc_exp
24
+ from kiui.gridencoder import GridEncoder
25
+
26
+ def get_rays(pose, h, w, fovy, opengl=True):
27
+
28
+ x, y = torch.meshgrid(
29
+ torch.arange(w, device=pose.device),
30
+ torch.arange(h, device=pose.device),
31
+ indexing="xy",
32
+ )
33
+ x = x.flatten()
34
+ y = y.flatten()
35
+
36
+ cx = w * 0.5
37
+ cy = h * 0.5
38
+ focal = h * 0.5 / np.tan(0.5 * np.deg2rad(fovy))
39
+
40
+ camera_dirs = F.pad(
41
+ torch.stack(
42
+ [
43
+ (x - cx + 0.5) / focal,
44
+ (y - cy + 0.5) / focal * (-1.0 if opengl else 1.0),
45
+ ],
46
+ dim=-1,
47
+ ),
48
+ (0, 1),
49
+ value=(-1.0 if opengl else 1.0),
50
+ ) # [hw, 3]
51
+
52
+ rays_d = camera_dirs @ pose[:3, :3].transpose(0, 1) # [hw, 3]
53
+ rays_o = pose[:3, 3].unsqueeze(0).expand_as(rays_d) # [hw, 3]
54
+
55
+ rays_d = safe_normalize(rays_d)
56
+
57
+ return rays_o, rays_d
58
+
59
+ # Triple renderer of gaussians, gaussian, and diso mesh.
60
+ # gaussian --> nerf --> mesh
61
+ class Converter(nn.Module):
62
+ def __init__(self, opt: Options):
63
+ super().__init__()
64
+
65
+ self.opt = opt
66
+ self.device = torch.device("cuda")
67
+
68
+ # gs renderer
69
+ self.tan_half_fov = np.tan(0.5 * np.deg2rad(opt.fovy))
70
+ self.proj_matrix = torch.zeros(4, 4, dtype=torch.float32, device=self.device)
71
+ self.proj_matrix[0, 0] = 1 / self.tan_half_fov
72
+ self.proj_matrix[1, 1] = 1 / self.tan_half_fov
73
+ self.proj_matrix[2, 2] = (opt.zfar + opt.znear) / (opt.zfar - opt.znear)
74
+ self.proj_matrix[3, 2] = - (opt.zfar * opt.znear) / (opt.zfar - opt.znear)
75
+ self.proj_matrix[2, 3] = 1
76
+
77
+ self.gs_renderer = GaussianRenderer(opt)
78
+
79
+ self.gaussians = self.gs_renderer.load_ply(opt.test_path).to(self.device)
80
+
81
+ # nerf renderer
82
+ if not self.opt.force_cuda_rast:
83
+ self.glctx = dr.RasterizeGLContext()
84
+ else:
85
+ self.glctx = dr.RasterizeCudaContext()
86
+
87
+ self.step = 0
88
+ self.render_step_size = 5e-3
89
+ self.aabb = torch.tensor([-1.0, -1.0, -1.0, 1.0, 1.0, 1.0], device=self.device)
90
+ self.estimator = nerfacc.OccGridEstimator(roi_aabb=self.aabb, resolution=64, levels=1)
91
+
92
+ self.encoder_density = GridEncoder(num_levels=12) # VMEncoder(output_dim=16, mode='sum')
93
+ self.encoder = GridEncoder(num_levels=12)
94
+ self.mlp_density = MLP(self.encoder_density.output_dim, 1, 32, 2, bias=False)
95
+ self.mlp = MLP(self.encoder.output_dim, 3, 32, 2, bias=False)
96
+
97
+ # mesh renderer
98
+ self.proj = torch.from_numpy(get_perspective(self.opt.fovy)).float().to(self.device)
99
+ self.v = self.f = None
100
+ self.vt = self.ft = None
101
+ self.deform = None
102
+ self.albedo = None
103
+
104
+
105
+ @torch.no_grad()
106
+ def render_gs(self, pose):
107
+
108
+ cam_poses = torch.from_numpy(pose).unsqueeze(0).to(self.device)
109
+ cam_poses[:, :3, 1:3] *= -1 # invert up & forward direction
110
+
111
+ # cameras needed by gaussian rasterizer
112
+ cam_view = torch.inverse(cam_poses).transpose(1, 2) # [V, 4, 4]
113
+ cam_view_proj = cam_view @ self.proj_matrix # [V, 4, 4]
114
+ cam_pos = - cam_poses[:, :3, 3] # [V, 3]
115
+
116
+ out = self.gs_renderer.render(self.gaussians.unsqueeze(0), cam_view.unsqueeze(0), cam_view_proj.unsqueeze(0), cam_pos.unsqueeze(0))
117
+ image = out['image'].squeeze(1).squeeze(0) # [C, H, W]
118
+ alpha = out['alpha'].squeeze(2).squeeze(1).squeeze(0) # [H, W]
119
+
120
+ return image, alpha
121
+
122
+ def get_density(self, xs):
123
+ # xs: [..., 3]
124
+ prefix = xs.shape[:-1]
125
+ xs = xs.view(-1, 3)
126
+ feats = self.encoder_density(xs)
127
+ density = trunc_exp(self.mlp_density(feats))
128
+ density = density.view(*prefix, 1)
129
+ return density
130
+
131
+ def render_nerf(self, pose):
132
+
133
+ pose = torch.from_numpy(pose.astype(np.float32)).to(self.device)
134
+
135
+ # get rays
136
+ resolution = self.opt.output_size
137
+ rays_o, rays_d = get_rays(pose, resolution, resolution, self.opt.fovy)
138
+
139
+ # update occ grid
140
+ if self.training:
141
+ def occ_eval_fn(xs):
142
+ sigmas = self.get_density(xs)
143
+ return self.render_step_size * sigmas
144
+
145
+ self.estimator.update_every_n_steps(self.step, occ_eval_fn=occ_eval_fn, occ_thre=0.01, n=8)
146
+ self.step += 1
147
+
148
+ # render
149
+ def sigma_fn(t_starts, t_ends, ray_indices):
150
+ t_origins = rays_o[ray_indices]
151
+ t_dirs = rays_d[ray_indices]
152
+ xs = t_origins + t_dirs * (t_starts + t_ends)[:, None] / 2.0
153
+ sigmas = self.get_density(xs)
154
+ return sigmas.squeeze(-1)
155
+
156
+ with torch.no_grad():
157
+ ray_indices, t_starts, t_ends = self.estimator.sampling(
158
+ rays_o,
159
+ rays_d,
160
+ sigma_fn=sigma_fn,
161
+ near_plane=0.01,
162
+ far_plane=100,
163
+ render_step_size=self.render_step_size,
164
+ stratified=self.training,
165
+ cone_angle=0,
166
+ )
167
+
168
+ t_origins = rays_o[ray_indices]
169
+ t_dirs = rays_d[ray_indices]
170
+ xs = t_origins + t_dirs * (t_starts + t_ends)[:, None] / 2.0
171
+ sigmas = self.get_density(xs).squeeze(-1)
172
+ rgbs = torch.sigmoid(self.mlp(self.encoder(xs)))
173
+
174
+ n_rays=rays_o.shape[0]
175
+ weights, trans, alphas = nerfacc.render_weight_from_density(t_starts, t_ends, sigmas, ray_indices=ray_indices, n_rays=n_rays)
176
+ color = nerfacc.accumulate_along_rays(weights, values=rgbs, ray_indices=ray_indices, n_rays=n_rays)
177
+ alpha = nerfacc.accumulate_along_rays(weights, values=None, ray_indices=ray_indices, n_rays=n_rays)
178
+
179
+ color = color + 1 * (1.0 - alpha)
180
+
181
+ color = color.view(resolution, resolution, 3).clamp(0, 1).permute(2, 0, 1).contiguous()
182
+ alpha = alpha.view(resolution, resolution).clamp(0, 1).contiguous()
183
+
184
+ return color, alpha
185
+
186
+ def fit_nerf(self, iters=512, resolution=128):
187
+
188
+ self.opt.output_size = resolution
189
+
190
+ optimizer = torch.optim.Adam([
191
+ {'params': self.encoder_density.parameters(), 'lr': 1e-2},
192
+ {'params': self.encoder.parameters(), 'lr': 1e-2},
193
+ {'params': self.mlp_density.parameters(), 'lr': 1e-3},
194
+ {'params': self.mlp.parameters(), 'lr': 1e-3},
195
+ ])
196
+
197
+ print(f"[INFO] fitting nerf...")
198
+ pbar = tqdm.trange(iters)
199
+ for i in pbar:
200
+
201
+ ver = np.random.randint(-45, 45)
202
+ hor = np.random.randint(-180, 180)
203
+ rad = np.random.uniform(1.5, 3.0)
204
+
205
+ pose = orbit_camera(ver, hor, rad)
206
+
207
+ image_gt, alpha_gt = self.render_gs(pose)
208
+ image_pred, alpha_pred = self.render_nerf(pose)
209
+
210
+ # if i % 200 == 0:
211
+ # kiui.vis.plot_image(image_gt, alpha_gt, image_pred, alpha_pred)
212
+
213
+ loss_mse = F.mse_loss(image_pred, image_gt) + 0.1 * F.mse_loss(alpha_pred, alpha_gt)
214
+ loss = loss_mse #+ 0.1 * self.encoder_density.tv_loss() #+ 0.0001 * self.encoder_density.density_loss()
215
+
216
+ loss.backward()
217
+ self.encoder_density.grad_total_variation(1e-8)
218
+
219
+ optimizer.step()
220
+ optimizer.zero_grad()
221
+
222
+ pbar.set_description(f"MSE = {loss_mse.item():.6f}")
223
+
224
+ print(f"[INFO] finished fitting nerf!")
225
+
226
+ def render_mesh(self, pose):
227
+
228
+ h = w = self.opt.output_size
229
+
230
+ v = self.v + self.deform
231
+ f = self.f
232
+
233
+ pose = torch.from_numpy(pose.astype(np.float32)).to(v.device)
234
+
235
+ # get v_clip and render rgb
236
+ v_cam = torch.matmul(F.pad(v, pad=(0, 1), mode='constant', value=1.0), torch.inverse(pose).T).float().unsqueeze(0)
237
+ v_clip = v_cam @ self.proj.T
238
+
239
+ rast, rast_db = dr.rasterize(self.glctx, v_clip, f, (h, w))
240
+
241
+ alpha = torch.clamp(rast[..., -1:], 0, 1).contiguous() # [1, H, W, 1]
242
+ alpha = dr.antialias(alpha, rast, v_clip, f).clamp(0, 1).squeeze(-1).squeeze(0) # [H, W] important to enable gradients!
243
+
244
+ if self.albedo is None:
245
+ xyzs, _ = dr.interpolate(v.unsqueeze(0), rast, f) # [1, H, W, 3]
246
+ xyzs = xyzs.view(-1, 3)
247
+ mask = (alpha > 0).view(-1)
248
+ image = torch.zeros_like(xyzs, dtype=torch.float32)
249
+ if mask.any():
250
+ masked_albedo = torch.sigmoid(self.mlp(self.encoder(xyzs[mask].detach(), bound=1)))
251
+ image[mask] = masked_albedo.float()
252
+ else:
253
+ texc, texc_db = dr.interpolate(self.vt.unsqueeze(0), rast, self.ft, rast_db=rast_db, diff_attrs='all')
254
+ image = torch.sigmoid(dr.texture(self.albedo.unsqueeze(0), texc, uv_da=texc_db)) # [1, H, W, 3]
255
+
256
+ image = image.view(1, h, w, 3)
257
+ # image = dr.antialias(image, rast, v_clip, f).clamp(0, 1)
258
+ image = image.squeeze(0).permute(2, 0, 1).contiguous() # [3, H, W]
259
+ image = alpha * image + (1 - alpha)
260
+
261
+ return image, alpha
262
+
263
+ def fit_mesh(self, iters=2048, resolution=512, decimate_target=5e4):
264
+
265
+ self.opt.output_size = resolution
266
+
267
+ # init mesh from nerf
268
+ grid_size = 256
269
+ sigmas = np.zeros([grid_size, grid_size, grid_size], dtype=np.float32)
270
+
271
+ S = 128
272
+ density_thresh = 10
273
+
274
+ X = torch.linspace(-1, 1, grid_size).split(S)
275
+ Y = torch.linspace(-1, 1, grid_size).split(S)
276
+ Z = torch.linspace(-1, 1, grid_size).split(S)
277
+
278
+ for xi, xs in enumerate(X):
279
+ for yi, ys in enumerate(Y):
280
+ for zi, zs in enumerate(Z):
281
+ xx, yy, zz = torch.meshgrid(xs, ys, zs, indexing='ij')
282
+ pts = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1) # [S, 3]
283
+ val = self.get_density(pts.to(self.device))
284
+ sigmas[xi * S: xi * S + len(xs), yi * S: yi * S + len(ys), zi * S: zi * S + len(zs)] = val.reshape(len(xs), len(ys), len(zs)).detach().cpu().numpy() # [S, 1] --> [x, y, z]
285
+
286
+ print(f'[INFO] marching cubes thresh: {density_thresh} ({sigmas.min()} ~ {sigmas.max()})')
287
+
288
+ vertices, triangles = mcubes.marching_cubes(sigmas, density_thresh)
289
+ vertices = vertices / (grid_size - 1.0) * 2 - 1
290
+
291
+ # clean
292
+ vertices = vertices.astype(np.float32)
293
+ triangles = triangles.astype(np.int32)
294
+ vertices, triangles = clean_mesh(vertices, triangles, remesh=True, remesh_size=0.01)
295
+ if triangles.shape[0] > decimate_target:
296
+ vertices, triangles = decimate_mesh(vertices, triangles, decimate_target, optimalplacement=False)
297
+
298
+ self.v = torch.from_numpy(vertices).contiguous().float().to(self.device)
299
+ self.f = torch.from_numpy(triangles).contiguous().int().to(self.device)
300
+ self.deform = nn.Parameter(torch.zeros_like(self.v)).to(self.device)
301
+
302
+ # fit mesh from gs
303
+ lr_factor = 1
304
+ optimizer = torch.optim.Adam([
305
+ {'params': self.encoder.parameters(), 'lr': 1e-3 * lr_factor},
306
+ {'params': self.mlp.parameters(), 'lr': 1e-3 * lr_factor},
307
+ {'params': self.deform, 'lr': 1e-4},
308
+ ])
309
+
310
+ print(f"[INFO] fitting mesh...")
311
+ pbar = tqdm.trange(iters)
312
+ for i in pbar:
313
+
314
+ ver = np.random.randint(-10, 10)
315
+ hor = np.random.randint(-180, 180)
316
+ rad = self.opt.cam_radius # np.random.uniform(1, 2)
317
+
318
+ pose = orbit_camera(ver, hor, rad)
319
+
320
+ image_gt, alpha_gt = self.render_gs(pose)
321
+ image_pred, alpha_pred = self.render_mesh(pose)
322
+
323
+ loss_mse = F.mse_loss(image_pred, image_gt) + 0.1 * F.mse_loss(alpha_pred, alpha_gt)
324
+ # loss_lap = laplacian_smooth_loss(self.v + self.deform, self.f)
325
+ loss_normal = normal_consistency(self.v + self.deform, self.f)
326
+ loss_offsets = (self.deform ** 2).sum(-1).mean()
327
+ loss = loss_mse + 0.001 * loss_normal + 0.1 * loss_offsets
328
+
329
+ loss.backward()
330
+
331
+ optimizer.step()
332
+ optimizer.zero_grad()
333
+
334
+ # remesh periodically
335
+ if i > 0 and i % 512 == 0:
336
+ vertices = (self.v + self.deform).detach().cpu().numpy()
337
+ triangles = self.f.detach().cpu().numpy()
338
+ vertices, triangles = clean_mesh(vertices, triangles, remesh=True, remesh_size=0.01)
339
+ if triangles.shape[0] > decimate_target:
340
+ vertices, triangles = decimate_mesh(vertices, triangles, decimate_target, optimalplacement=False)
341
+ self.v = torch.from_numpy(vertices).contiguous().float().to(self.device)
342
+ self.f = torch.from_numpy(triangles).contiguous().int().to(self.device)
343
+ self.deform = nn.Parameter(torch.zeros_like(self.v)).to(self.device)
344
+ lr_factor *= 0.5
345
+ optimizer = torch.optim.Adam([
346
+ {'params': self.encoder.parameters(), 'lr': 1e-3 * lr_factor},
347
+ {'params': self.mlp.parameters(), 'lr': 1e-3 * lr_factor},
348
+ {'params': self.deform, 'lr': 1e-4},
349
+ ])
350
+
351
+ pbar.set_description(f"MSE = {loss_mse.item():.6f}")
352
+
353
+ # last clean
354
+ vertices = (self.v + self.deform).detach().cpu().numpy()
355
+ triangles = self.f.detach().cpu().numpy()
356
+ vertices, triangles = clean_mesh(vertices, triangles, remesh=False)
357
+ self.v = torch.from_numpy(vertices).contiguous().float().to(self.device)
358
+ self.f = torch.from_numpy(triangles).contiguous().int().to(self.device)
359
+ self.deform = nn.Parameter(torch.zeros_like(self.v).to(self.device))
360
+
361
+ print(f"[INFO] finished fitting mesh!")
362
+
363
+ # uv mesh refine
364
+ def fit_mesh_uv(self, iters=512, resolution=512, texture_resolution=1024, padding=2):
365
+
366
+ self.opt.output_size = resolution
367
+
368
+ # unwrap uv
369
+ print(f"[INFO] uv unwrapping...")
370
+ mesh = Mesh(v=self.v, f=self.f, albedo=None, device=self.device)
371
+ mesh.auto_normal()
372
+ mesh.auto_uv()
373
+
374
+ self.vt = mesh.vt
375
+ self.ft = mesh.ft
376
+
377
+ # render uv maps
378
+ h = w = texture_resolution
379
+ uv = mesh.vt * 2.0 - 1.0 # uvs to range [-1, 1]
380
+ uv = torch.cat((uv, torch.zeros_like(uv[..., :1]), torch.ones_like(uv[..., :1])), dim=-1) # [N, 4]
381
+
382
+ rast, _ = dr.rasterize(self.glctx, uv.unsqueeze(0), mesh.ft, (h, w)) # [1, h, w, 4]
383
+ xyzs, _ = dr.interpolate(mesh.v.unsqueeze(0), rast, mesh.f) # [1, h, w, 3]
384
+ mask, _ = dr.interpolate(torch.ones_like(mesh.v[:, :1]).unsqueeze(0), rast, mesh.f) # [1, h, w, 1]
385
+
386
+ # masked query
387
+ xyzs = xyzs.view(-1, 3)
388
+ mask = (mask > 0).view(-1)
389
+
390
+ albedo = torch.zeros(h * w, 3, device=self.device, dtype=torch.float32)
391
+
392
+ if mask.any():
393
+ print(f"[INFO] querying texture...")
394
+
395
+ xyzs = xyzs[mask] # [M, 3]
396
+
397
+ # batched inference to avoid OOM
398
+ batch = []
399
+ head = 0
400
+ while head < xyzs.shape[0]:
401
+ tail = min(head + 640000, xyzs.shape[0])
402
+ batch.append(torch.sigmoid(self.mlp(self.encoder(xyzs[head:tail]))).float())
403
+ head += 640000
404
+
405
+ albedo[mask] = torch.cat(batch, dim=0)
406
+
407
+ albedo = albedo.view(h, w, -1)
408
+ mask = mask.view(h, w)
409
+ albedo = uv_padding(albedo, mask, padding)
410
+
411
+ # optimize texture
412
+ self.albedo = nn.Parameter(inverse_sigmoid(albedo)).to(self.device)
413
+
414
+ optimizer = torch.optim.Adam([
415
+ {'params': self.albedo, 'lr': 1e-3},
416
+ ])
417
+
418
+ print(f"[INFO] fitting mesh texture...")
419
+ pbar = tqdm.trange(iters)
420
+ for i in pbar:
421
+
422
+ # shrink to front view as we care more about it...
423
+ ver = np.random.randint(-5, 5)
424
+ hor = np.random.randint(-15, 15)
425
+ rad = self.opt.cam_radius # np.random.uniform(1, 2)
426
+
427
+ pose = orbit_camera(ver, hor, rad)
428
+
429
+ image_gt, alpha_gt = self.render_gs(pose)
430
+ image_pred, alpha_pred = self.render_mesh(pose)
431
+
432
+ loss_mse = F.mse_loss(image_pred, image_gt)
433
+ loss = loss_mse
434
+
435
+ loss.backward()
436
+
437
+ optimizer.step()
438
+ optimizer.zero_grad()
439
+
440
+ pbar.set_description(f"MSE = {loss_mse.item():.6f}")
441
+
442
+ print(f"[INFO] finished fitting mesh texture!")
443
+
444
+
445
+ @torch.no_grad()
446
+ def export_mesh(self, path):
447
+
448
+ mesh = Mesh(v=self.v, f=self.f, vt=self.vt, ft=self.ft, albedo=torch.sigmoid(self.albedo), device=self.device)
449
+ mesh.auto_normal()
450
+ mesh.write(path)
451
+
452
+
453
+ opt = tyro.cli(AllConfigs)
454
+
455
+ # load a saved ply and convert to mesh
456
+ assert opt.test_path.endswith('.ply'), '--test_path must be a .ply file saved by infer.py'
457
+
458
+ converter = Converter(opt).cuda()
459
+ converter.fit_nerf()
460
+ converter.fit_mesh()
461
+ converter.fit_mesh_uv()
462
+ converter.export_mesh(opt.test_path.replace('.ply', '.glb'))
core/__init__.py ADDED
File without changes
core/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (125 Bytes). View file
 
core/__pycache__/gs.cpython-310.pyc ADDED
Binary file (5.44 kB). View file
 
core/__pycache__/options.cpython-310.pyc ADDED
Binary file (2.48 kB). View file
 
core/attention.py ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # References:
7
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
8
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
9
+
10
+ import os
11
+ import warnings
12
+
13
+ from torch import Tensor
14
+ from torch import nn
15
+
16
+ XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
17
+ try:
18
+ if XFORMERS_ENABLED:
19
+ from xformers.ops import memory_efficient_attention, unbind
20
+
21
+ XFORMERS_AVAILABLE = True
22
+ warnings.warn("xFormers is available (Attention)")
23
+ else:
24
+ warnings.warn("xFormers is disabled (Attention)")
25
+ raise ImportError
26
+ except ImportError:
27
+ XFORMERS_AVAILABLE = False
28
+ warnings.warn("xFormers is not available (Attention)")
29
+
30
+
31
+ class Attention(nn.Module):
32
+ def __init__(
33
+ self,
34
+ dim: int,
35
+ num_heads: int = 8,
36
+ qkv_bias: bool = False,
37
+ proj_bias: bool = True,
38
+ attn_drop: float = 0.0,
39
+ proj_drop: float = 0.0,
40
+ ) -> None:
41
+ super().__init__()
42
+ self.num_heads = num_heads
43
+ head_dim = dim // num_heads
44
+ self.scale = head_dim**-0.5
45
+
46
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
47
+ self.attn_drop = nn.Dropout(attn_drop)
48
+ self.proj = nn.Linear(dim, dim, bias=proj_bias)
49
+ self.proj_drop = nn.Dropout(proj_drop)
50
+
51
+ def forward(self, x: Tensor) -> Tensor:
52
+ B, N, C = x.shape
53
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
54
+
55
+ q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
56
+ attn = q @ k.transpose(-2, -1)
57
+
58
+ attn = attn.softmax(dim=-1)
59
+ attn = self.attn_drop(attn)
60
+
61
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
62
+ x = self.proj(x)
63
+ x = self.proj_drop(x)
64
+ return x
65
+
66
+
67
+ class MemEffAttention(Attention):
68
+ def forward(self, x: Tensor, attn_bias=None) -> Tensor:
69
+ if not XFORMERS_AVAILABLE:
70
+ if attn_bias is not None:
71
+ raise AssertionError("xFormers is required for using nested tensors")
72
+ return super().forward(x)
73
+
74
+ B, N, C = x.shape
75
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
76
+
77
+ q, k, v = unbind(qkv, 2)
78
+
79
+ x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
80
+ x = x.reshape([B, N, C])
81
+
82
+ x = self.proj(x)
83
+ x = self.proj_drop(x)
84
+ return x
85
+
86
+
87
+ class CrossAttention(nn.Module):
88
+ def __init__(
89
+ self,
90
+ dim: int,
91
+ dim_q: int,
92
+ dim_k: int,
93
+ dim_v: int,
94
+ num_heads: int = 8,
95
+ qkv_bias: bool = False,
96
+ proj_bias: bool = True,
97
+ attn_drop: float = 0.0,
98
+ proj_drop: float = 0.0,
99
+ ) -> None:
100
+ super().__init__()
101
+ self.dim = dim
102
+ self.num_heads = num_heads
103
+ head_dim = dim // num_heads
104
+ self.scale = head_dim**-0.5
105
+
106
+ self.to_q = nn.Linear(dim_q, dim, bias=qkv_bias)
107
+ self.to_k = nn.Linear(dim_k, dim, bias=qkv_bias)
108
+ self.to_v = nn.Linear(dim_v, dim, bias=qkv_bias)
109
+ self.attn_drop = nn.Dropout(attn_drop)
110
+ self.proj = nn.Linear(dim, dim, bias=proj_bias)
111
+ self.proj_drop = nn.Dropout(proj_drop)
112
+
113
+ def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
114
+ # q: [B, N, Cq]
115
+ # k: [B, M, Ck]
116
+ # v: [B, M, Cv]
117
+ # return: [B, N, C]
118
+
119
+ B, N, _ = q.shape
120
+ M = k.shape[1]
121
+
122
+ q = self.scale * self.to_q(q).reshape(B, N, self.num_heads, self.dim // self.num_heads).permute(0, 2, 1, 3) # [B, nh, N, C/nh]
123
+ k = self.to_k(k).reshape(B, M, self.num_heads, self.dim // self.num_heads).permute(0, 2, 1, 3) # [B, nh, M, C/nh]
124
+ v = self.to_v(v).reshape(B, M, self.num_heads, self.dim // self.num_heads).permute(0, 2, 1, 3) # [B, nh, M, C/nh]
125
+
126
+ attn = q @ k.transpose(-2, -1) # [B, nh, N, M]
127
+
128
+ attn = attn.softmax(dim=-1) # [B, nh, N, M]
129
+ attn = self.attn_drop(attn)
130
+
131
+ x = (attn @ v).transpose(1, 2).reshape(B, N, -1) # [B, nh, N, M] @ [B, nh, M, C/nh] --> [B, nh, N, C/nh] --> [B, N, nh, C/nh] --> [B, N, C]
132
+ x = self.proj(x)
133
+ x = self.proj_drop(x)
134
+ return x
135
+
136
+
137
+ class MemEffCrossAttention(CrossAttention):
138
+ def forward(self, q: Tensor, k: Tensor, v: Tensor, attn_bias=None) -> Tensor:
139
+ if not XFORMERS_AVAILABLE:
140
+ if attn_bias is not None:
141
+ raise AssertionError("xFormers is required for using nested tensors")
142
+ return super().forward(x)
143
+
144
+ B, N, _ = q.shape
145
+ M = k.shape[1]
146
+
147
+ q = self.scale * self.to_q(q).reshape(B, N, self.num_heads, self.dim // self.num_heads) # [B, N, nh, C/nh]
148
+ k = self.to_k(k).reshape(B, M, self.num_heads, self.dim // self.num_heads) # [B, M, nh, C/nh]
149
+ v = self.to_v(v).reshape(B, M, self.num_heads, self.dim // self.num_heads) # [B, M, nh, C/nh]
150
+
151
+ x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
152
+ x = x.reshape(B, N, -1)
153
+
154
+ x = self.proj(x)
155
+ x = self.proj_drop(x)
156
+ return x
core/gs.py ADDED
@@ -0,0 +1,190 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+ from diff_gaussian_rasterization import (
8
+ GaussianRasterizationSettings,
9
+ GaussianRasterizer,
10
+ )
11
+
12
+ from core.options import Options
13
+
14
+ import kiui
15
+
16
+ class GaussianRenderer:
17
+ def __init__(self, opt: Options):
18
+
19
+ self.opt = opt
20
+ self.bg_color = torch.tensor([1, 1, 1], dtype=torch.float32, device="cuda")
21
+
22
+ # intrinsics
23
+ self.tan_half_fov = np.tan(0.5 * np.deg2rad(self.opt.fovy))
24
+ self.proj_matrix = torch.zeros(4, 4, dtype=torch.float32)
25
+ self.proj_matrix[0, 0] = 1 / self.tan_half_fov
26
+ self.proj_matrix[1, 1] = 1 / self.tan_half_fov
27
+ self.proj_matrix[2, 2] = (opt.zfar + opt.znear) / (opt.zfar - opt.znear)
28
+ self.proj_matrix[3, 2] = - (opt.zfar * opt.znear) / (opt.zfar - opt.znear)
29
+ self.proj_matrix[2, 3] = 1
30
+
31
+ def render(self, gaussians, cam_view, cam_view_proj, cam_pos, bg_color=None, scale_modifier=1):
32
+ # gaussians: [B, N, 14]
33
+ # cam_view, cam_view_proj: [B, V, 4, 4]
34
+ # cam_pos: [B, V, 3]
35
+
36
+ device = gaussians.device
37
+ B, V = cam_view.shape[:2]
38
+
39
+ # loop of loop...
40
+ images = []
41
+ alphas = []
42
+ for b in range(B):
43
+
44
+ # pos, opacity, scale, rotation, shs
45
+ means3D = gaussians[b, :, 0:3].contiguous().float()
46
+ opacity = gaussians[b, :, 3:4].contiguous().float()
47
+ scales = gaussians[b, :, 4:7].contiguous().float()
48
+ rotations = gaussians[b, :, 7:11].contiguous().float()
49
+ rgbs = gaussians[b, :, 11:].contiguous().float() # [N, 3]
50
+
51
+ for v in range(V):
52
+
53
+ # render novel views
54
+ view_matrix = cam_view[b, v].float()
55
+ view_proj_matrix = cam_view_proj[b, v].float()
56
+ campos = cam_pos[b, v].float()
57
+
58
+ raster_settings = GaussianRasterizationSettings(
59
+ image_height=self.opt.output_size,
60
+ image_width=self.opt.output_size,
61
+ tanfovx=self.tan_half_fov,
62
+ tanfovy=self.tan_half_fov,
63
+ bg=self.bg_color if bg_color is None else bg_color,
64
+ scale_modifier=scale_modifier,
65
+ viewmatrix=view_matrix,
66
+ projmatrix=view_proj_matrix,
67
+ sh_degree=0,
68
+ campos=campos,
69
+ prefiltered=False,
70
+ debug=False,
71
+ )
72
+
73
+ rasterizer = GaussianRasterizer(raster_settings=raster_settings)
74
+
75
+ # Rasterize visible Gaussians to image, obtain their radii (on screen).
76
+ rendered_image, radii, rendered_depth, rendered_alpha = rasterizer(
77
+ means3D=means3D,
78
+ means2D=torch.zeros_like(means3D, dtype=torch.float32, device=device),
79
+ shs=None,
80
+ colors_precomp=rgbs,
81
+ opacities=opacity,
82
+ scales=scales,
83
+ rotations=rotations,
84
+ cov3D_precomp=None,
85
+ )
86
+
87
+ rendered_image = rendered_image.clamp(0, 1)
88
+
89
+ images.append(rendered_image)
90
+ alphas.append(rendered_alpha)
91
+
92
+ images = torch.stack(images, dim=0).view(B, V, 3, self.opt.output_size, self.opt.output_size)
93
+ alphas = torch.stack(alphas, dim=0).view(B, V, 1, self.opt.output_size, self.opt.output_size)
94
+
95
+ return {
96
+ "image": images, # [B, V, 3, H, W]
97
+ "alpha": alphas, # [B, V, 1, H, W]
98
+ }
99
+
100
+
101
+ def save_ply(self, gaussians, path, compatible=True):
102
+ # gaussians: [B, N, 14]
103
+ # compatible: save pre-activated gaussians as in the original paper
104
+
105
+ assert gaussians.shape[0] == 1, 'only support batch size 1'
106
+
107
+ from plyfile import PlyData, PlyElement
108
+
109
+ means3D = gaussians[0, :, 0:3].contiguous().float()
110
+ opacity = gaussians[0, :, 3:4].contiguous().float()
111
+ scales = gaussians[0, :, 4:7].contiguous().float()
112
+ rotations = gaussians[0, :, 7:11].contiguous().float()
113
+ shs = gaussians[0, :, 11:].unsqueeze(1).contiguous().float() # [N, 1, 3]
114
+
115
+ # prune by opacity
116
+ mask = opacity.squeeze(-1) >= 0.005
117
+ means3D = means3D[mask]
118
+ opacity = opacity[mask]
119
+ scales = scales[mask]
120
+ rotations = rotations[mask]
121
+ shs = shs[mask]
122
+
123
+ # invert activation to make it compatible with the original ply format
124
+ if compatible:
125
+ opacity = kiui.op.inverse_sigmoid(opacity)
126
+ scales = torch.log(scales + 1e-8)
127
+ shs = (shs - 0.5) / 0.28209479177387814
128
+
129
+ xyzs = means3D.detach().cpu().numpy()
130
+ f_dc = shs.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy()
131
+ opacities = opacity.detach().cpu().numpy()
132
+ scales = scales.detach().cpu().numpy()
133
+ rotations = rotations.detach().cpu().numpy()
134
+
135
+ l = ['x', 'y', 'z']
136
+ # All channels except the 3 DC
137
+ for i in range(f_dc.shape[1]):
138
+ l.append('f_dc_{}'.format(i))
139
+ l.append('opacity')
140
+ for i in range(scales.shape[1]):
141
+ l.append('scale_{}'.format(i))
142
+ for i in range(rotations.shape[1]):
143
+ l.append('rot_{}'.format(i))
144
+
145
+ dtype_full = [(attribute, 'f4') for attribute in l]
146
+
147
+ elements = np.empty(xyzs.shape[0], dtype=dtype_full)
148
+ attributes = np.concatenate((xyzs, f_dc, opacities, scales, rotations), axis=1)
149
+ elements[:] = list(map(tuple, attributes))
150
+ el = PlyElement.describe(elements, 'vertex')
151
+
152
+ PlyData([el]).write(path)
153
+
154
+ def load_ply(self, path, compatible=True):
155
+
156
+ from plyfile import PlyData, PlyElement
157
+
158
+ plydata = PlyData.read(path)
159
+
160
+ xyz = np.stack((np.asarray(plydata.elements[0]["x"]),
161
+ np.asarray(plydata.elements[0]["y"]),
162
+ np.asarray(plydata.elements[0]["z"])), axis=1)
163
+ print("Number of points at loading : ", xyz.shape[0])
164
+
165
+ opacities = np.asarray(plydata.elements[0]["opacity"])[..., np.newaxis]
166
+
167
+ shs = np.zeros((xyz.shape[0], 3))
168
+ shs[:, 0] = np.asarray(plydata.elements[0]["f_dc_0"])
169
+ shs[:, 1] = np.asarray(plydata.elements[0]["f_dc_1"])
170
+ shs[:, 2] = np.asarray(plydata.elements[0]["f_dc_2"])
171
+
172
+ scale_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("scale_")]
173
+ scales = np.zeros((xyz.shape[0], len(scale_names)))
174
+ for idx, attr_name in enumerate(scale_names):
175
+ scales[:, idx] = np.asarray(plydata.elements[0][attr_name])
176
+
177
+ rot_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("rot_")]
178
+ rots = np.zeros((xyz.shape[0], len(rot_names)))
179
+ for idx, attr_name in enumerate(rot_names):
180
+ rots[:, idx] = np.asarray(plydata.elements[0][attr_name])
181
+
182
+ gaussians = np.concatenate([xyz, opacities, scales, rots, shs], axis=1)
183
+ gaussians = torch.from_numpy(gaussians).float() # cpu
184
+
185
+ if compatible:
186
+ gaussians[..., 3:4] = torch.sigmoid(gaussians[..., 3:4])
187
+ gaussians[..., 4:7] = torch.exp(gaussians[..., 4:7])
188
+ gaussians[..., 11:] = 0.28209479177387814 * gaussians[..., 11:] + 0.5
189
+
190
+ return gaussians
core/models.py ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import numpy as np
5
+
6
+ import kiui
7
+ from kiui.lpips import LPIPS
8
+
9
+ from core.unet import UNet
10
+ from core.options import Options
11
+ from core.gs import GaussianRenderer
12
+
13
+
14
+ class LGM(nn.Module):
15
+ def __init__(
16
+ self,
17
+ opt: Options,
18
+ ):
19
+ super().__init__()
20
+
21
+ self.opt = opt
22
+
23
+ # unet
24
+ self.unet = UNet(
25
+ 9, 14,
26
+ down_channels=self.opt.down_channels,
27
+ down_attention=self.opt.down_attention,
28
+ mid_attention=self.opt.mid_attention,
29
+ up_channels=self.opt.up_channels,
30
+ up_attention=self.opt.up_attention,
31
+ )
32
+
33
+ # last conv
34
+ self.conv = nn.Conv2d(14, 14, kernel_size=1) # NOTE: maybe remove it if train again
35
+
36
+ # Gaussian Renderer
37
+ self.gs = GaussianRenderer(opt)
38
+
39
+ # activations...
40
+ self.pos_act = lambda x: x.clamp(-1, 1)
41
+ self.scale_act = lambda x: 0.1 * F.softplus(x)
42
+ self.opacity_act = lambda x: torch.sigmoid(x)
43
+ self.rot_act = F.normalize
44
+ self.rgb_act = lambda x: 0.5 * torch.tanh(x) + 0.5 # NOTE: may use sigmoid if train again
45
+
46
+ # LPIPS loss
47
+ if self.opt.lambda_lpips > 0:
48
+ self.lpips_loss = LPIPS(net='vgg')
49
+ self.lpips_loss.requires_grad_(False)
50
+
51
+
52
+ def state_dict(self, **kwargs):
53
+ # remove lpips_loss
54
+ state_dict = super().state_dict(**kwargs)
55
+ for k in list(state_dict.keys()):
56
+ if 'lpips_loss' in k:
57
+ del state_dict[k]
58
+ return state_dict
59
+
60
+
61
+ def prepare_default_rays(self, device, elevation=0):
62
+
63
+ from kiui.cam import orbit_camera
64
+ from core.utils import get_rays
65
+
66
+ cam_poses = np.stack([
67
+ orbit_camera(elevation, 0, radius=self.opt.cam_radius),
68
+ orbit_camera(elevation, 90, radius=self.opt.cam_radius),
69
+ orbit_camera(elevation, 180, radius=self.opt.cam_radius),
70
+ orbit_camera(elevation, 270, radius=self.opt.cam_radius),
71
+ ], axis=0) # [4, 4, 4]
72
+ cam_poses = torch.from_numpy(cam_poses)
73
+
74
+ rays_embeddings = []
75
+ for i in range(cam_poses.shape[0]):
76
+ rays_o, rays_d = get_rays(cam_poses[i], self.opt.input_size, self.opt.input_size, self.opt.fovy) # [h, w, 3]
77
+ rays_plucker = torch.cat([torch.cross(rays_o, rays_d, dim=-1), rays_d], dim=-1) # [h, w, 6]
78
+ rays_embeddings.append(rays_plucker)
79
+
80
+ ## visualize rays for plotting figure
81
+ # kiui.vis.plot_image(rays_d * 0.5 + 0.5, save=True)
82
+
83
+ rays_embeddings = torch.stack(rays_embeddings, dim=0).permute(0, 3, 1, 2).contiguous().to(device) # [V, 6, h, w]
84
+
85
+ return rays_embeddings
86
+
87
+
88
+ def forward_gaussians(self, images):
89
+ # images: [B, 4, 9, H, W]
90
+ # return: Gaussians: [B, dim_t]
91
+
92
+ B, V, C, H, W = images.shape
93
+ images = images.view(B*V, C, H, W)
94
+
95
+ x = self.unet(images) # [B*4, 14, h, w]
96
+ x = self.conv(x) # [B*4, 14, h, w]
97
+
98
+ x = x.reshape(B, 4, 14, self.opt.splat_size, self.opt.splat_size)
99
+
100
+ ## visualize multi-view gaussian features for plotting figure
101
+ # tmp_alpha = self.opacity_act(x[0, :, 3:4])
102
+ # tmp_img_rgb = self.rgb_act(x[0, :, 11:]) * tmp_alpha + (1 - tmp_alpha)
103
+ # tmp_img_pos = self.pos_act(x[0, :, 0:3]) * 0.5 + 0.5
104
+ # kiui.vis.plot_image(tmp_img_rgb, save=True)
105
+ # kiui.vis.plot_image(tmp_img_pos, save=True)
106
+
107
+ x = x.permute(0, 1, 3, 4, 2).reshape(B, -1, 14)
108
+
109
+ pos = self.pos_act(x[..., 0:3]) # [B, N, 3]
110
+ opacity = self.opacity_act(x[..., 3:4])
111
+ scale = self.scale_act(x[..., 4:7])
112
+ rotation = self.rot_act(x[..., 7:11])
113
+ rgbs = self.rgb_act(x[..., 11:])
114
+
115
+ gaussians = torch.cat([pos, opacity, scale, rotation, rgbs], dim=-1) # [B, N, 14]
116
+
117
+ return gaussians
118
+
119
+
120
+ def forward(self, data, step_ratio=1):
121
+ # data: output of the dataloader
122
+ # return: loss
123
+
124
+ results = {}
125
+ loss = 0
126
+
127
+ images = data['input'] # [B, 4, 9, h, W], input features
128
+
129
+ # use the first view to predict gaussians
130
+ gaussians = self.forward_gaussians(images) # [B, N, 14]
131
+
132
+ results['gaussians'] = gaussians
133
+
134
+ # random bg for training
135
+ if self.training:
136
+ bg_color = torch.rand(3, dtype=torch.float32, device=gaussians.device)
137
+ else:
138
+ bg_color = torch.ones(3, dtype=torch.float32, device=gaussians.device)
139
+
140
+ # use the other views for rendering and supervision
141
+ results = self.gs.render(gaussians, data['cam_view'], data['cam_view_proj'], data['cam_pos'], bg_color=bg_color)
142
+ pred_images = results['image'] # [B, V, C, output_size, output_size]
143
+ pred_alphas = results['alpha'] # [B, V, 1, output_size, output_size]
144
+
145
+ results['images_pred'] = pred_images
146
+ results['alphas_pred'] = pred_alphas
147
+
148
+ gt_images = data['images_output'] # [B, V, 3, output_size, output_size], ground-truth novel views
149
+ gt_masks = data['masks_output'] # [B, V, 1, output_size, output_size], ground-truth masks
150
+
151
+ gt_images = gt_images * gt_masks + bg_color.view(1, 1, 3, 1, 1) * (1 - gt_masks)
152
+
153
+ loss_mse = F.mse_loss(pred_images, gt_images) + F.mse_loss(pred_alphas, gt_masks)
154
+ loss = loss + loss_mse
155
+
156
+ if self.opt.lambda_lpips > 0:
157
+ loss_lpips = self.lpips_loss(
158
+ # gt_images.view(-1, 3, self.opt.output_size, self.opt.output_size) * 2 - 1,
159
+ # pred_images.view(-1, 3, self.opt.output_size, self.opt.output_size) * 2 - 1,
160
+ # downsampled to at most 256 to reduce memory cost
161
+ F.interpolate(gt_images.view(-1, 3, self.opt.output_size, self.opt.output_size) * 2 - 1, (256, 256), mode='bilinear', align_corners=False),
162
+ F.interpolate(pred_images.view(-1, 3, self.opt.output_size, self.opt.output_size) * 2 - 1, (256, 256), mode='bilinear', align_corners=False),
163
+ ).mean()
164
+ results['loss_lpips'] = loss_lpips
165
+ loss = loss + self.opt.lambda_lpips * loss_lpips
166
+
167
+ results['loss'] = loss
168
+
169
+ # metric
170
+ with torch.no_grad():
171
+ psnr = -10 * torch.log10(torch.mean((pred_images.detach() - gt_images) ** 2))
172
+ results['psnr'] = psnr
173
+
174
+ return results
core/options.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tyro
2
+ from dataclasses import dataclass
3
+ from typing import Tuple, Literal, Dict, Optional
4
+
5
+
6
+ @dataclass
7
+ class Options:
8
+ ### model
9
+ # Unet image input size
10
+ input_size: int = 256
11
+ # Unet definition
12
+ down_channels: Tuple[int, ...] = (64, 128, 256, 512, 1024, 1024)
13
+ down_attention: Tuple[bool, ...] = (False, False, False, True, True, True)
14
+ mid_attention: bool = True
15
+ up_channels: Tuple[int, ...] = (1024, 1024, 512, 256)
16
+ up_attention: Tuple[bool, ...] = (True, True, True, False)
17
+ # Unet output size, dependent on the input_size and U-Net structure!
18
+ splat_size: int = 64
19
+ # gaussian render size
20
+ output_size: int = 256
21
+
22
+ ### dataset
23
+ # data mode (only support s3 now)
24
+ data_mode: Literal['s3'] = 's3'
25
+ # fovy of the dataset
26
+ fovy: float = 49.1
27
+ # camera near plane
28
+ znear: float = 0.5
29
+ # camera far plane
30
+ zfar: float = 2.5
31
+ # number of all views (input + output)
32
+ num_views: int = 12
33
+ # number of views
34
+ num_input_views: int = 4
35
+ # camera radius
36
+ cam_radius: float = 1.5 # to better use [-1, 1]^3 space
37
+ # num workers
38
+ num_workers: int = 8
39
+
40
+ ### training
41
+ # workspace
42
+ workspace: str = './workspace'
43
+ # resume
44
+ resume: Optional[str] = None
45
+ # batch size (per-GPU)
46
+ batch_size: int = 8
47
+ # gradient accumulation
48
+ gradient_accumulation_steps: int = 1
49
+ # training epochs
50
+ num_epochs: int = 30
51
+ # lpips loss weight
52
+ lambda_lpips: float = 1.0
53
+ # gradient clip
54
+ gradient_clip: float = 1.0
55
+ # mixed precision
56
+ mixed_precision: str = 'bf16'
57
+ # learning rate
58
+ lr: float = 4e-4
59
+ # augmentation prob for grid distortion
60
+ prob_grid_distortion: float = 0.5
61
+ # augmentation prob for camera jitter
62
+ prob_cam_jitter: float = 0.5
63
+
64
+ ### testing
65
+ # test image path
66
+ test_path: Optional[str] = None
67
+
68
+ ### misc
69
+ # nvdiffrast backend setting
70
+ force_cuda_rast: bool = False
71
+ # render fancy video with gaussian scaling effect
72
+ fancy_video: bool = False
73
+
74
+
75
+ # all the default settings
76
+ config_defaults: Dict[str, Options] = {}
77
+ config_doc: Dict[str, str] = {}
78
+
79
+ config_doc['lrm'] = 'the default settings for LGM'
80
+ config_defaults['lrm'] = Options()
81
+
82
+ config_doc['small'] = 'small model with lower resolution Gaussians'
83
+ config_defaults['small'] = Options(
84
+ input_size=256,
85
+ splat_size=64,
86
+ output_size=256,
87
+ batch_size=8,
88
+ gradient_accumulation_steps=1,
89
+ mixed_precision='bf16',
90
+ )
91
+
92
+ config_doc['big'] = 'big model with higher resolution Gaussians'
93
+ config_defaults['big'] = Options(
94
+ input_size=256,
95
+ up_channels=(1024, 1024, 512, 256, 128), # one more decoder
96
+ up_attention=(True, True, True, False, False),
97
+ splat_size=128,
98
+ output_size=512, # render & supervise Gaussians at a higher resolution.
99
+ batch_size=8,
100
+ num_views=8,
101
+ gradient_accumulation_steps=1,
102
+ mixed_precision='bf16',
103
+ )
104
+
105
+ config_doc['tiny'] = 'tiny model for ablation'
106
+ config_defaults['tiny'] = Options(
107
+ input_size=256,
108
+ down_channels=(32, 64, 128, 256, 512),
109
+ down_attention=(False, False, False, False, True),
110
+ up_channels=(512, 256, 128),
111
+ up_attention=(True, False, False, False),
112
+ splat_size=64,
113
+ output_size=256,
114
+ batch_size=16,
115
+ num_views=8,
116
+ gradient_accumulation_steps=1,
117
+ mixed_precision='bf16',
118
+ )
119
+
120
+ AllConfigs = tyro.extras.subcommand_type_from_defaults(config_defaults, config_doc)
core/provider_objaverse.py ADDED
@@ -0,0 +1,172 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import cv2
3
+ import random
4
+ import numpy as np
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ import torchvision.transforms.functional as TF
10
+ from torch.utils.data import Dataset
11
+
12
+ import kiui
13
+ from core.options import Options
14
+ from core.utils import get_rays, grid_distortion, orbit_camera_jitter
15
+
16
+ IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
17
+ IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
18
+
19
+
20
+ class ObjaverseDataset(Dataset):
21
+
22
+ def _warn(self):
23
+ raise NotImplementedError('this dataset is just an example and cannot be used directly, you should modify it to your own setting! (search keyword TODO)')
24
+
25
+ def __init__(self, opt: Options, training=True):
26
+
27
+ self.opt = opt
28
+ self.training = training
29
+
30
+ # TODO: remove this barrier
31
+ self._warn()
32
+
33
+ # TODO: load the list of objects for training
34
+ self.items = []
35
+ with open('TODO: file containing the list', 'r') as f:
36
+ for line in f.readlines():
37
+ self.items.append(line.strip())
38
+
39
+ # naive split
40
+ if self.training:
41
+ self.items = self.items[:-self.opt.batch_size]
42
+ else:
43
+ self.items = self.items[-self.opt.batch_size:]
44
+
45
+ # default camera intrinsics
46
+ self.tan_half_fov = np.tan(0.5 * np.deg2rad(self.opt.fovy))
47
+ self.proj_matrix = torch.zeros(4, 4, dtype=torch.float32)
48
+ self.proj_matrix[0, 0] = 1 / self.tan_half_fov
49
+ self.proj_matrix[1, 1] = 1 / self.tan_half_fov
50
+ self.proj_matrix[2, 2] = (self.opt.zfar + self.opt.znear) / (self.opt.zfar - self.opt.znear)
51
+ self.proj_matrix[3, 2] = - (self.opt.zfar * self.opt.znear) / (self.opt.zfar - self.opt.znear)
52
+ self.proj_matrix[2, 3] = 1
53
+
54
+
55
+ def __len__(self):
56
+ return len(self.items)
57
+
58
+ def __getitem__(self, idx):
59
+
60
+ uid = self.items[idx]
61
+ results = {}
62
+
63
+ # load num_views images
64
+ images = []
65
+ masks = []
66
+ cam_poses = []
67
+
68
+ vid_cnt = 0
69
+
70
+ # TODO: choose views, based on your rendering settings
71
+ if self.training:
72
+ # input views are in (36, 72), other views are randomly selected
73
+ vids = np.random.permutation(np.arange(36, 73))[:self.opt.num_input_views].tolist() + np.random.permutation(100).tolist()
74
+ else:
75
+ # fixed views
76
+ vids = np.arange(36, 73, 4).tolist() + np.arange(100).tolist()
77
+
78
+ for vid in vids:
79
+
80
+ image_path = os.path.join(uid, 'rgb', f'{vid:03d}.png')
81
+ camera_path = os.path.join(uid, 'pose', f'{vid:03d}.txt')
82
+
83
+ try:
84
+ # TODO: load data (modify self.client here)
85
+ image = np.frombuffer(self.client.get(image_path), np.uint8)
86
+ image = torch.from_numpy(cv2.imdecode(image, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255) # [512, 512, 4] in [0, 1]
87
+ c2w = [float(t) for t in self.client.get(camera_path).decode().strip().split(' ')]
88
+ c2w = torch.tensor(c2w, dtype=torch.float32).reshape(4, 4)
89
+ except Exception as e:
90
+ # print(f'[WARN] dataset {uid} {vid}: {e}')
91
+ continue
92
+
93
+ # TODO: you may have a different camera system
94
+ # blender world + opencv cam --> opengl world & cam
95
+ c2w[1] *= -1
96
+ c2w[[1, 2]] = c2w[[2, 1]]
97
+ c2w[:3, 1:3] *= -1 # invert up and forward direction
98
+
99
+ # scale up radius to fully use the [-1, 1]^3 space!
100
+ c2w[:3, 3] *= self.opt.cam_radius / 1.5 # 1.5 is the default scale
101
+
102
+ image = image.permute(2, 0, 1) # [4, 512, 512]
103
+ mask = image[3:4] # [1, 512, 512]
104
+ image = image[:3] * mask + (1 - mask) # [3, 512, 512], to white bg
105
+ image = image[[2,1,0]].contiguous() # bgr to rgb
106
+
107
+ images.append(image)
108
+ masks.append(mask.squeeze(0))
109
+ cam_poses.append(c2w)
110
+
111
+ vid_cnt += 1
112
+ if vid_cnt == self.opt.num_views:
113
+ break
114
+
115
+ if vid_cnt < self.opt.num_views:
116
+ print(f'[WARN] dataset {uid}: not enough valid views, only {vid_cnt} views found!')
117
+ n = self.opt.num_views - vid_cnt
118
+ images = images + [images[-1]] * n
119
+ masks = masks + [masks[-1]] * n
120
+ cam_poses = cam_poses + [cam_poses[-1]] * n
121
+
122
+ images = torch.stack(images, dim=0) # [V, C, H, W]
123
+ masks = torch.stack(masks, dim=0) # [V, H, W]
124
+ cam_poses = torch.stack(cam_poses, dim=0) # [V, 4, 4]
125
+
126
+ # normalized camera feats as in paper (transform the first pose to a fixed position)
127
+ transform = torch.tensor([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, self.opt.cam_radius], [0, 0, 0, 1]], dtype=torch.float32) @ torch.inverse(cam_poses[0])
128
+ cam_poses = transform.unsqueeze(0) @ cam_poses # [V, 4, 4]
129
+
130
+ images_input = F.interpolate(images[:self.opt.num_input_views].clone(), size=(self.opt.input_size, self.opt.input_size), mode='bilinear', align_corners=False) # [V, C, H, W]
131
+ cam_poses_input = cam_poses[:self.opt.num_input_views].clone()
132
+
133
+ # data augmentation
134
+ if self.training:
135
+ # apply random grid distortion to simulate 3D inconsistency
136
+ if random.random() < self.opt.prob_grid_distortion:
137
+ images_input[1:] = grid_distortion(images_input[1:])
138
+ # apply camera jittering (only to input!)
139
+ if random.random() < self.opt.prob_cam_jitter:
140
+ cam_poses_input[1:] = orbit_camera_jitter(cam_poses_input[1:])
141
+
142
+ images_input = TF.normalize(images_input, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)
143
+
144
+ # resize render ground-truth images, range still in [0, 1]
145
+ results['images_output'] = F.interpolate(images, size=(self.opt.output_size, self.opt.output_size), mode='bilinear', align_corners=False) # [V, C, output_size, output_size]
146
+ results['masks_output'] = F.interpolate(masks.unsqueeze(1), size=(self.opt.output_size, self.opt.output_size), mode='bilinear', align_corners=False) # [V, 1, output_size, output_size]
147
+
148
+ # build rays for input views
149
+ rays_embeddings = []
150
+ for i in range(self.opt.num_input_views):
151
+ rays_o, rays_d = get_rays(cam_poses_input[i], self.opt.input_size, self.opt.input_size, self.opt.fovy) # [h, w, 3]
152
+ rays_plucker = torch.cat([torch.cross(rays_o, rays_d, dim=-1), rays_d], dim=-1) # [h, w, 6]
153
+ rays_embeddings.append(rays_plucker)
154
+
155
+
156
+ rays_embeddings = torch.stack(rays_embeddings, dim=0).permute(0, 3, 1, 2).contiguous() # [V, 6, h, w]
157
+ final_input = torch.cat([images_input, rays_embeddings], dim=1) # [V=4, 9, H, W]
158
+ results['input'] = final_input
159
+
160
+ # opengl to colmap camera for gaussian renderer
161
+ cam_poses[:, :3, 1:3] *= -1 # invert up & forward direction
162
+
163
+ # cameras needed by gaussian rasterizer
164
+ cam_view = torch.inverse(cam_poses).transpose(1, 2) # [V, 4, 4]
165
+ cam_view_proj = cam_view @ self.proj_matrix # [V, 4, 4]
166
+ cam_pos = - cam_poses[:, :3, 3] # [V, 3]
167
+
168
+ results['cam_view'] = cam_view
169
+ results['cam_view_proj'] = cam_view_proj
170
+ results['cam_pos'] = cam_pos
171
+
172
+ return results
core/unet.py ADDED
@@ -0,0 +1,319 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ import numpy as np
6
+ from typing import Tuple, Literal
7
+ from functools import partial
8
+
9
+ from core.attention import MemEffAttention
10
+
11
+ class MVAttention(nn.Module):
12
+ def __init__(
13
+ self,
14
+ dim: int,
15
+ num_heads: int = 8,
16
+ qkv_bias: bool = False,
17
+ proj_bias: bool = True,
18
+ attn_drop: float = 0.0,
19
+ proj_drop: float = 0.0,
20
+ groups: int = 32,
21
+ eps: float = 1e-5,
22
+ residual: bool = True,
23
+ skip_scale: float = 1,
24
+ num_frames: int = 4, # WARN: hardcoded!
25
+ ):
26
+ super().__init__()
27
+
28
+ self.residual = residual
29
+ self.skip_scale = skip_scale
30
+ self.num_frames = num_frames
31
+
32
+ self.norm = nn.GroupNorm(num_groups=groups, num_channels=dim, eps=eps, affine=True)
33
+ self.attn = MemEffAttention(dim, num_heads, qkv_bias, proj_bias, attn_drop, proj_drop)
34
+
35
+ def forward(self, x):
36
+ # x: [B*V, C, H, W]
37
+ BV, C, H, W = x.shape
38
+ B = BV // self.num_frames # assert BV % self.num_frames == 0
39
+
40
+ res = x
41
+ x = self.norm(x)
42
+
43
+ x = x.reshape(B, self.num_frames, C, H, W).permute(0, 1, 3, 4, 2).reshape(B, -1, C)
44
+ x = self.attn(x)
45
+ x = x.reshape(B, self.num_frames, H, W, C).permute(0, 1, 4, 2, 3).reshape(BV, C, H, W)
46
+
47
+ if self.residual:
48
+ x = (x + res) * self.skip_scale
49
+ return x
50
+
51
+ class ResnetBlock(nn.Module):
52
+ def __init__(
53
+ self,
54
+ in_channels: int,
55
+ out_channels: int,
56
+ resample: Literal['default', 'up', 'down'] = 'default',
57
+ groups: int = 32,
58
+ eps: float = 1e-5,
59
+ skip_scale: float = 1, # multiplied to output
60
+ ):
61
+ super().__init__()
62
+
63
+ self.in_channels = in_channels
64
+ self.out_channels = out_channels
65
+ self.skip_scale = skip_scale
66
+
67
+ self.norm1 = nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
68
+ self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
69
+
70
+ self.norm2 = nn.GroupNorm(num_groups=groups, num_channels=out_channels, eps=eps, affine=True)
71
+ self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
72
+
73
+ self.act = F.silu
74
+
75
+ self.resample = None
76
+ if resample == 'up':
77
+ self.resample = partial(F.interpolate, scale_factor=2.0, mode="nearest")
78
+ elif resample == 'down':
79
+ self.resample = nn.AvgPool2d(kernel_size=2, stride=2)
80
+
81
+ self.shortcut = nn.Identity()
82
+ if self.in_channels != self.out_channels:
83
+ self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=True)
84
+
85
+
86
+ def forward(self, x):
87
+ res = x
88
+
89
+ x = self.norm1(x)
90
+ x = self.act(x)
91
+
92
+ if self.resample:
93
+ res = self.resample(res)
94
+ x = self.resample(x)
95
+
96
+ x = self.conv1(x)
97
+ x = self.norm2(x)
98
+ x = self.act(x)
99
+ x = self.conv2(x)
100
+
101
+ x = (x + self.shortcut(res)) * self.skip_scale
102
+
103
+ return x
104
+
105
+ class DownBlock(nn.Module):
106
+ def __init__(
107
+ self,
108
+ in_channels: int,
109
+ out_channels: int,
110
+ num_layers: int = 1,
111
+ downsample: bool = True,
112
+ attention: bool = True,
113
+ attention_heads: int = 16,
114
+ skip_scale: float = 1,
115
+ ):
116
+ super().__init__()
117
+
118
+ nets = []
119
+ attns = []
120
+ for i in range(num_layers):
121
+ in_channels = in_channels if i == 0 else out_channels
122
+ nets.append(ResnetBlock(in_channels, out_channels, skip_scale=skip_scale))
123
+ if attention:
124
+ attns.append(MVAttention(out_channels, attention_heads, skip_scale=skip_scale))
125
+ else:
126
+ attns.append(None)
127
+ self.nets = nn.ModuleList(nets)
128
+ self.attns = nn.ModuleList(attns)
129
+
130
+ self.downsample = None
131
+ if downsample:
132
+ self.downsample = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=2, padding=1)
133
+
134
+ def forward(self, x):
135
+ xs = []
136
+
137
+ for attn, net in zip(self.attns, self.nets):
138
+ x = net(x)
139
+ if attn:
140
+ x = attn(x)
141
+ xs.append(x)
142
+
143
+ if self.downsample:
144
+ x = self.downsample(x)
145
+ xs.append(x)
146
+
147
+ return x, xs
148
+
149
+
150
+ class MidBlock(nn.Module):
151
+ def __init__(
152
+ self,
153
+ in_channels: int,
154
+ num_layers: int = 1,
155
+ attention: bool = True,
156
+ attention_heads: int = 16,
157
+ skip_scale: float = 1,
158
+ ):
159
+ super().__init__()
160
+
161
+ nets = []
162
+ attns = []
163
+ # first layer
164
+ nets.append(ResnetBlock(in_channels, in_channels, skip_scale=skip_scale))
165
+ # more layers
166
+ for i in range(num_layers):
167
+ nets.append(ResnetBlock(in_channels, in_channels, skip_scale=skip_scale))
168
+ if attention:
169
+ attns.append(MVAttention(in_channels, attention_heads, skip_scale=skip_scale))
170
+ else:
171
+ attns.append(None)
172
+ self.nets = nn.ModuleList(nets)
173
+ self.attns = nn.ModuleList(attns)
174
+
175
+ def forward(self, x):
176
+ x = self.nets[0](x)
177
+ for attn, net in zip(self.attns, self.nets[1:]):
178
+ if attn:
179
+ x = attn(x)
180
+ x = net(x)
181
+ return x
182
+
183
+
184
+ class UpBlock(nn.Module):
185
+ def __init__(
186
+ self,
187
+ in_channels: int,
188
+ prev_out_channels: int,
189
+ out_channels: int,
190
+ num_layers: int = 1,
191
+ upsample: bool = True,
192
+ attention: bool = True,
193
+ attention_heads: int = 16,
194
+ skip_scale: float = 1,
195
+ ):
196
+ super().__init__()
197
+
198
+ nets = []
199
+ attns = []
200
+ for i in range(num_layers):
201
+ cin = in_channels if i == 0 else out_channels
202
+ cskip = prev_out_channels if (i == num_layers - 1) else out_channels
203
+
204
+ nets.append(ResnetBlock(cin + cskip, out_channels, skip_scale=skip_scale))
205
+ if attention:
206
+ attns.append(MVAttention(out_channels, attention_heads, skip_scale=skip_scale))
207
+ else:
208
+ attns.append(None)
209
+ self.nets = nn.ModuleList(nets)
210
+ self.attns = nn.ModuleList(attns)
211
+
212
+ self.upsample = None
213
+ if upsample:
214
+ self.upsample = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
215
+
216
+ def forward(self, x, xs):
217
+
218
+ for attn, net in zip(self.attns, self.nets):
219
+ res_x = xs[-1]
220
+ xs = xs[:-1]
221
+ x = torch.cat([x, res_x], dim=1)
222
+ x = net(x)
223
+ if attn:
224
+ x = attn(x)
225
+
226
+ if self.upsample:
227
+ x = F.interpolate(x, scale_factor=2.0, mode='nearest')
228
+ x = self.upsample(x)
229
+
230
+ return x
231
+
232
+
233
+ # it could be asymmetric!
234
+ class UNet(nn.Module):
235
+ def __init__(
236
+ self,
237
+ in_channels: int = 3,
238
+ out_channels: int = 3,
239
+ down_channels: Tuple[int, ...] = (64, 128, 256, 512, 1024),
240
+ down_attention: Tuple[bool, ...] = (False, False, False, True, True),
241
+ mid_attention: bool = True,
242
+ up_channels: Tuple[int, ...] = (1024, 512, 256),
243
+ up_attention: Tuple[bool, ...] = (True, True, False),
244
+ layers_per_block: int = 2,
245
+ skip_scale: float = np.sqrt(0.5),
246
+ ):
247
+ super().__init__()
248
+
249
+ # first
250
+ self.conv_in = nn.Conv2d(in_channels, down_channels[0], kernel_size=3, stride=1, padding=1)
251
+
252
+ # down
253
+ down_blocks = []
254
+ cout = down_channels[0]
255
+ for i in range(len(down_channels)):
256
+ cin = cout
257
+ cout = down_channels[i]
258
+
259
+ down_blocks.append(DownBlock(
260
+ cin, cout,
261
+ num_layers=layers_per_block,
262
+ downsample=(i != len(down_channels) - 1), # not final layer
263
+ attention=down_attention[i],
264
+ skip_scale=skip_scale,
265
+ ))
266
+ self.down_blocks = nn.ModuleList(down_blocks)
267
+
268
+ # mid
269
+ self.mid_block = MidBlock(down_channels[-1], attention=mid_attention, skip_scale=skip_scale)
270
+
271
+ # up
272
+ up_blocks = []
273
+ cout = up_channels[0]
274
+ for i in range(len(up_channels)):
275
+ cin = cout
276
+ cout = up_channels[i]
277
+ cskip = down_channels[max(-2 - i, -len(down_channels))] # for assymetric
278
+
279
+ up_blocks.append(UpBlock(
280
+ cin, cskip, cout,
281
+ num_layers=layers_per_block + 1, # one more layer for up
282
+ upsample=(i != len(up_channels) - 1), # not final layer
283
+ attention=up_attention[i],
284
+ skip_scale=skip_scale,
285
+ ))
286
+ self.up_blocks = nn.ModuleList(up_blocks)
287
+
288
+ # last
289
+ self.norm_out = nn.GroupNorm(num_channels=up_channels[-1], num_groups=32, eps=1e-5)
290
+ self.conv_out = nn.Conv2d(up_channels[-1], out_channels, kernel_size=3, stride=1, padding=1)
291
+
292
+
293
+ def forward(self, x):
294
+ # x: [B, Cin, H, W]
295
+
296
+ # first
297
+ x = self.conv_in(x)
298
+
299
+ # down
300
+ xss = [x]
301
+ for block in self.down_blocks:
302
+ x, xs = block(x)
303
+ xss.extend(xs)
304
+
305
+ # mid
306
+ x = self.mid_block(x)
307
+
308
+ # up
309
+ for block in self.up_blocks:
310
+ xs = xss[-len(block.nets):]
311
+ xss = xss[:-len(block.nets)]
312
+ x = block(x, xs)
313
+
314
+ # last
315
+ x = self.norm_out(x)
316
+ x = F.silu(x)
317
+ x = self.conv_out(x) # [B, Cout, H', W']
318
+
319
+ return x
core/utils.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+ import roma
8
+ from kiui.op import safe_normalize
9
+
10
+ def get_rays(pose, h, w, fovy, opengl=True):
11
+
12
+ x, y = torch.meshgrid(
13
+ torch.arange(w, device=pose.device),
14
+ torch.arange(h, device=pose.device),
15
+ indexing="xy",
16
+ )
17
+ x = x.flatten()
18
+ y = y.flatten()
19
+
20
+ cx = w * 0.5
21
+ cy = h * 0.5
22
+
23
+ focal = h * 0.5 / np.tan(0.5 * np.deg2rad(fovy))
24
+
25
+ camera_dirs = F.pad(
26
+ torch.stack(
27
+ [
28
+ (x - cx + 0.5) / focal,
29
+ (y - cy + 0.5) / focal * (-1.0 if opengl else 1.0),
30
+ ],
31
+ dim=-1,
32
+ ),
33
+ (0, 1),
34
+ value=(-1.0 if opengl else 1.0),
35
+ ) # [hw, 3]
36
+
37
+ rays_d = camera_dirs @ pose[:3, :3].transpose(0, 1) # [hw, 3]
38
+ rays_o = pose[:3, 3].unsqueeze(0).expand_as(rays_d) # [hw, 3]
39
+
40
+ rays_o = rays_o.view(h, w, 3)
41
+ rays_d = safe_normalize(rays_d).view(h, w, 3)
42
+
43
+ return rays_o, rays_d
44
+
45
+ def orbit_camera_jitter(poses, strength=0.1):
46
+ # poses: [B, 4, 4], assume orbit camera in opengl format
47
+ # random orbital rotate
48
+
49
+ B = poses.shape[0]
50
+ rotvec_x = poses[:, :3, 1] * strength * np.pi * (torch.rand(B, 1, device=poses.device) * 2 - 1)
51
+ rotvec_y = poses[:, :3, 0] * strength * np.pi / 2 * (torch.rand(B, 1, device=poses.device) * 2 - 1)
52
+
53
+ rot = roma.rotvec_to_rotmat(rotvec_x) @ roma.rotvec_to_rotmat(rotvec_y)
54
+ R = rot @ poses[:, :3, :3]
55
+ T = rot @ poses[:, :3, 3:]
56
+
57
+ new_poses = poses.clone()
58
+ new_poses[:, :3, :3] = R
59
+ new_poses[:, :3, 3:] = T
60
+
61
+ return new_poses
62
+
63
+ def grid_distortion(images, strength=0.5):
64
+ # images: [B, C, H, W]
65
+ # num_steps: int, grid resolution for distortion
66
+ # strength: float in [0, 1], strength of distortion
67
+
68
+ B, C, H, W = images.shape
69
+
70
+ num_steps = np.random.randint(8, 17)
71
+ grid_steps = torch.linspace(-1, 1, num_steps)
72
+
73
+ # have to loop batch...
74
+ grids = []
75
+ for b in range(B):
76
+ # construct displacement
77
+ x_steps = torch.linspace(0, 1, num_steps) # [num_steps], inclusive
78
+ x_steps = (x_steps + strength * (torch.rand_like(x_steps) - 0.5) / (num_steps - 1)).clamp(0, 1) # perturb
79
+ x_steps = (x_steps * W).long() # [num_steps]
80
+ x_steps[0] = 0
81
+ x_steps[-1] = W
82
+ xs = []
83
+ for i in range(num_steps - 1):
84
+ xs.append(torch.linspace(grid_steps[i], grid_steps[i + 1], x_steps[i + 1] - x_steps[i]))
85
+ xs = torch.cat(xs, dim=0) # [W]
86
+
87
+ y_steps = torch.linspace(0, 1, num_steps) # [num_steps], inclusive
88
+ y_steps = (y_steps + strength * (torch.rand_like(y_steps) - 0.5) / (num_steps - 1)).clamp(0, 1) # perturb
89
+ y_steps = (y_steps * H).long() # [num_steps]
90
+ y_steps[0] = 0
91
+ y_steps[-1] = H
92
+ ys = []
93
+ for i in range(num_steps - 1):
94
+ ys.append(torch.linspace(grid_steps[i], grid_steps[i + 1], y_steps[i + 1] - y_steps[i]))
95
+ ys = torch.cat(ys, dim=0) # [H]
96
+
97
+ # construct grid
98
+ grid_x, grid_y = torch.meshgrid(xs, ys, indexing='xy') # [H, W]
99
+ grid = torch.stack([grid_x, grid_y], dim=-1) # [H, W, 2]
100
+
101
+ grids.append(grid)
102
+
103
+ grids = torch.stack(grids, dim=0).to(images.device) # [B, H, W, 2]
104
+
105
+ # grid sample
106
+ images = F.grid_sample(images, grids, align_corners=False)
107
+
108
+ return images
109
+
data_test/catstatue.ply ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:29d36a368577579338f897b6098b4d96ab7d0bcf0c61ebb8249af56b72b0c7aa
3
+ size 2399737
requirements.txt ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch==2.1.0+cu121
2
+ torchvision==0.16.0+cu121
3
+ torchaudio==2.1.0+cu121
4
+ tyro
5
+ PyMCubes
6
+ nerfacc
7
+ trimesh
8
+ pymeshlab
9
+ wheel
10
+ tqdm
11
+ opencv-python
12
+ ninja
13
+ plyfile
14
+ xatlas
15
+ scikit-learn
16
+ pygltflib
17
+ gradio
18
+ git+https://github.com/ashawkey/kiuikit.git
19
+ https://github.com/camenduru/LGM-replicate/releases/download/replicate/diff_gaussian_rasterization-0.0.0-cp310-cp310-linux_x86_64.whl
20
+ https://github.com/camenduru/wheels/releases/download/colab/nvdiffrast-0.3.1-py3-none-any.whl