jiaweir commited on
Commit
cdc7dcc
β€’
1 Parent(s): c122ae9
Files changed (6) hide show
  1. app.py +8 -4
  2. configs/4d_demo.yaml +1 -1
  3. gs_renderer_4d.py +42 -0
  4. lgm/core/models.py +41 -0
  5. lgm/infer_demo.py +15 -34
  6. main_4d_demo.py +27 -32
app.py CHANGED
@@ -224,7 +224,7 @@ def optimize_stage_2(image_block: Image.Image, seed_slider: int):
224
  process_dg4d(os.path.join("configs", "4d_demo.yaml"), os.path.join("tmp_data", f"{img_hash}_rgba.png"), guidance_zero123)
225
  # os.rename(os.path.join('logs', f'{img_hash}_rgba_frames'), os.path.join('logs', f'{img_hash}_{seed_slider:03d}_rgba_frames'))
226
  image_dir = os.path.join('logs', f'{img_hash}_rgba_frames')
227
- # return 'vis_data/tmp_rgba.mp4', [os.path.join(image_dir, file) for file in os.listdir(image_dir) if file.endswith('.ply')]
228
  return [image_dir+f'/{t:03d}.ply' for t in range(28)]
229
 
230
 
@@ -256,7 +256,7 @@ if __name__ == "__main__":
256
 
257
  # Image-to-3D
258
  with gr.Row(variant='panel'):
259
- with gr.Column(scale=4):
260
  image_block = gr.Image(type='pil', image_mode='RGBA', height=290, label='Input image')
261
 
262
  # elevation_slider = gr.Slider(-90, 90, value=0, step=1, label='Estimated elevation angle')
@@ -282,8 +282,12 @@ if __name__ == "__main__":
282
  img_guide_text = gr.Markdown(_IMG_USER_GUIDE, visible=True)
283
 
284
  with gr.Column(scale=5):
285
- dirving_video = gr.Video(label="video",height=290)
286
- obj3d = gr.Video(label="3D Model",height=290)
 
 
 
 
287
  obj4d = Model4DGS(label="4D Model", height=500, fps=14)
288
 
289
 
 
224
  process_dg4d(os.path.join("configs", "4d_demo.yaml"), os.path.join("tmp_data", f"{img_hash}_rgba.png"), guidance_zero123)
225
  # os.rename(os.path.join('logs', f'{img_hash}_rgba_frames'), os.path.join('logs', f'{img_hash}_{seed_slider:03d}_rgba_frames'))
226
  image_dir = os.path.join('logs', f'{img_hash}_rgba_frames')
227
+ # return os.path.join('vis_data', f'{img_hash}_rgba.mp4'), [image_dir+f'/{t:03d}.ply' for t in range(28)]
228
  return [image_dir+f'/{t:03d}.ply' for t in range(28)]
229
 
230
 
 
256
 
257
  # Image-to-3D
258
  with gr.Row(variant='panel'):
259
+ with gr.Column(scale=5):
260
  image_block = gr.Image(type='pil', image_mode='RGBA', height=290, label='Input image')
261
 
262
  # elevation_slider = gr.Slider(-90, 90, value=0, step=1, label='Estimated elevation angle')
 
282
  img_guide_text = gr.Markdown(_IMG_USER_GUIDE, visible=True)
283
 
284
  with gr.Column(scale=5):
285
+ with gr.Row():
286
+ with gr.Column(scale=5):
287
+ dirving_video = gr.Video(label="video",height=290)
288
+ with gr.Column(scale=5):
289
+ obj3d = gr.Video(label="3D Model",height=290)
290
+ # video4d = gr.Video(label="4D video",height=290)
291
  obj4d = Model4DGS(label="4D Model", height=500, fps=14)
292
 
293
 
configs/4d_demo.yaml CHANGED
@@ -30,7 +30,7 @@ lambda_svd: 0
30
  # training batch size per iter
31
  batch_size: 7
32
  # training iterations for stage 1
33
- iters: 300
34
  # training iterations for stage 2
35
  iters_refine: 50
36
  # training camera radius
 
30
  # training batch size per iter
31
  batch_size: 7
32
  # training iterations for stage 1
33
+ iters: 400
34
  # training iterations for stage 2
35
  iters_refine: 50
36
  # training camera radius
gs_renderer_4d.py CHANGED
@@ -150,6 +150,48 @@ class Renderer:
150
  self.opacity_deform_T = opacity_deform_T.reshape([self.T, means3D_deform_T.shape[0]//self.T, -1])
151
  self.scales_deform_T = scales_deform_T.reshape([self.T, means3D_deform_T.shape[0]//self.T, -1])
152
  self.rotations_deform_T = rotations_deform_T.reshape([self.T, means3D_deform_T.shape[0]//self.T, -1])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
153
 
154
 
155
  def render(
 
150
  self.opacity_deform_T = opacity_deform_T.reshape([self.T, means3D_deform_T.shape[0]//self.T, -1])
151
  self.scales_deform_T = scales_deform_T.reshape([self.T, means3D_deform_T.shape[0]//self.T, -1])
152
  self.rotations_deform_T = rotations_deform_T.reshape([self.T, means3D_deform_T.shape[0]//self.T, -1])
153
+
154
+
155
+ def prepare_render_4x(
156
+ self,
157
+ ):
158
+ means3D = self.gaussians.get_xyz
159
+ opacity = self.gaussians._opacity
160
+ scales = self.gaussians._scaling
161
+ rotations = self.gaussians._rotation
162
+
163
+ means3D_T = []
164
+ opacity_T = []
165
+ scales_T = []
166
+ rotations_T = []
167
+ time_T = []
168
+
169
+ for t in range(self.T * 4):
170
+ tt = t / 4.
171
+ time = torch.tensor(tt).to(means3D.device).repeat(means3D.shape[0],1)
172
+ time = ((time.float() / self.T) - 0.5) * 2
173
+
174
+ means3D_T.append(means3D)
175
+ opacity_T.append(opacity)
176
+ scales_T.append(scales)
177
+ rotations_T.append(rotations)
178
+ time_T.append(time)
179
+
180
+ means3D_T = torch.cat(means3D_T)
181
+ opacity_T = torch.cat(opacity_T)
182
+ scales_T = torch.cat(scales_T)
183
+ rotations_T = torch.cat(rotations_T)
184
+ time_T = torch.cat(time_T)
185
+
186
+
187
+ means3D_deform_T, scales_deform_T, rotations_deform_T, opacity_deform_T = self.gaussians._deformation(means3D_T, scales_T,
188
+ rotations_T, opacity_T,
189
+ time_T) # time is not none
190
+ self.means3D_deform_T = means3D_deform_T.reshape([self.T *4, means3D_deform_T.shape[0]//self.T // 4, -1])
191
+ self.opacity_deform_T = opacity_deform_T.reshape([self.T*4, means3D_deform_T.shape[0]//self.T//4, -1])
192
+ self.scales_deform_T = scales_deform_T.reshape([self.T*4, means3D_deform_T.shape[0]//self.T//4, -1])
193
+ self.rotations_deform_T = rotations_deform_T.reshape([self.T*4, means3D_deform_T.shape[0]//self.T//4, -1])
194
+
195
 
196
 
197
  def render(
lgm/core/models.py CHANGED
@@ -116,6 +116,47 @@ class LGM(nn.Module):
116
 
117
  return gaussians
118
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
119
 
120
  def forward(self, data, step_ratio=1):
121
  # data: output of the dataloader
 
116
 
117
  return gaussians
118
 
119
+ def forward_gaussians_downsample(self, images):
120
+ # images: [B, 4, 9, H, W]
121
+ # return: Gaussians: [B, dim_t]
122
+
123
+ B, V, C, H, W = images.shape
124
+ images = images.view(B*V, C, H, W)
125
+
126
+ x = self.unet(images) # [B*4, 14, h, w]
127
+ x = self.conv(x) # [B*4, 14, h, w]
128
+
129
+ x_orig_res = x.clone()
130
+
131
+ x = F.interpolate(x, (self.opt.splat_size // 4, self.opt.splat_size//4), mode='nearest')
132
+ x = x.reshape(B, 4, 14, self.opt.splat_size//4, self.opt.splat_size//4)
133
+
134
+ x = x.permute(0, 1, 3, 4, 2).reshape(B, -1, 14)
135
+
136
+ pos = self.pos_act(x[..., 0:3]) # [B, N, 3]
137
+ opacity = self.opacity_act(x[..., 3:4])
138
+ scale = self.scale_act(x[..., 4:7]) * 4
139
+ rotation = self.rot_act(x[..., 7:11])
140
+ rgbs = self.rgb_act(x[..., 11:])
141
+
142
+ gaussians = torch.cat([pos, opacity, scale, rotation, rgbs], dim=-1) # [B, N, 14]
143
+
144
+
145
+ x = x_orig_res.reshape(B, 4, 14, self.opt.splat_size, self.opt.splat_size)
146
+
147
+ x = x.permute(0, 1, 3, 4, 2).reshape(B, -1, 14)
148
+
149
+ pos = self.pos_act(x[..., 0:3]) # [B, N, 3]
150
+ opacity = self.opacity_act(x[..., 3:4])
151
+ scale = self.scale_act(x[..., 4:7])
152
+ rotation = self.rot_act(x[..., 7:11])
153
+ rgbs = self.rgb_act(x[..., 11:])
154
+
155
+ gaussians_orig_res = torch.cat([pos, opacity, scale, rotation, rgbs], dim=-1) # [B, N, 14]
156
+
157
+
158
+ return gaussians, gaussians_orig_res
159
+
160
 
161
  def forward(self, data, step_ratio=1):
162
  # data: output of the dataloader
lgm/infer_demo.py CHANGED
@@ -151,7 +151,7 @@ def process(opt: Options, path, pipe, model, rays_embeddings, seed):
151
 
152
  with torch.autocast(device_type='cuda', dtype=torch.float16):
153
  # generate gaussians
154
- gaussians = model.forward_gaussians(input_image)
155
 
156
  # save gaussians
157
  model.gs.save_ply(gaussians, os.path.join('logs', name + '_model.ply'))
@@ -160,39 +160,20 @@ def process(opt: Options, path, pipe, model, rays_embeddings, seed):
160
  images = []
161
  elevation = 0
162
 
163
- if opt.fancy_video:
164
-
165
- azimuth = np.arange(0, 720, 4, dtype=np.int32)
166
- for azi in tqdm.tqdm(azimuth):
167
-
168
- cam_poses = torch.from_numpy(orbit_camera(elevation, azi, radius=opt.cam_radius, opengl=True)).unsqueeze(0).to(device)
169
-
170
- cam_poses[:, :3, 1:3] *= -1 # invert up & forward direction
171
-
172
- # cameras needed by gaussian rasterizer
173
- cam_view = torch.inverse(cam_poses).transpose(1, 2) # [V, 4, 4]
174
- cam_view_proj = cam_view @ proj_matrix # [V, 4, 4]
175
- cam_pos = - cam_poses[:, :3, 3] # [V, 3]
176
-
177
- scale = min(azi / 360, 1)
178
-
179
- image = model.gs.render(gaussians, cam_view.unsqueeze(0), cam_view_proj.unsqueeze(0), cam_pos.unsqueeze(0), scale_modifier=scale)['image']
180
- images.append((image.squeeze(1).permute(0,2,3,1).contiguous().float().cpu().numpy() * 255).astype(np.uint8))
181
- else:
182
- azimuth = np.arange(0, 360, 2, dtype=np.int32)
183
- for azi in tqdm.tqdm(azimuth):
184
-
185
- cam_poses = torch.from_numpy(orbit_camera(elevation, azi, radius=opt.cam_radius, opengl=True)).unsqueeze(0).to(device)
186
-
187
- cam_poses[:, :3, 1:3] *= -1 # invert up & forward direction
188
-
189
- # cameras needed by gaussian rasterizer
190
- cam_view = torch.inverse(cam_poses).transpose(1, 2) # [V, 4, 4]
191
- cam_view_proj = cam_view @ proj_matrix # [V, 4, 4]
192
- cam_pos = - cam_poses[:, :3, 3] # [V, 3]
193
-
194
- image = model.gs.render(gaussians, cam_view.unsqueeze(0), cam_view_proj.unsqueeze(0), cam_pos.unsqueeze(0), scale_modifier=1)['image']
195
- images.append((image.squeeze(1).permute(0,2,3,1).contiguous().float().cpu().numpy() * 255).astype(np.uint8))
196
 
197
  images = np.concatenate(images, axis=0)
198
  imageio.mimwrite(os.path.join('vis_data', name + '_static.mp4'), images, fps=30)
 
151
 
152
  with torch.autocast(device_type='cuda', dtype=torch.float16):
153
  # generate gaussians
154
+ gaussians, gaussians_orig_res = model.forward_gaussians_downsample(input_image)
155
 
156
  # save gaussians
157
  model.gs.save_ply(gaussians, os.path.join('logs', name + '_model.ply'))
 
160
  images = []
161
  elevation = 0
162
 
163
+ azimuth = np.arange(0, 360, 2, dtype=np.int32)
164
+ for azi in tqdm.tqdm(azimuth):
165
+
166
+ cam_poses = torch.from_numpy(orbit_camera(elevation, azi, radius=opt.cam_radius, opengl=True)).unsqueeze(0).to(device)
167
+
168
+ cam_poses[:, :3, 1:3] *= -1 # invert up & forward direction
169
+
170
+ # cameras needed by gaussian rasterizer
171
+ cam_view = torch.inverse(cam_poses).transpose(1, 2) # [V, 4, 4]
172
+ cam_view_proj = cam_view @ proj_matrix # [V, 4, 4]
173
+ cam_pos = - cam_poses[:, :3, 3] # [V, 3]
174
+
175
+ image = model.gs.render(gaussians_orig_res, cam_view.unsqueeze(0), cam_view_proj.unsqueeze(0), cam_pos.unsqueeze(0), scale_modifier=1)['image']
176
+ images.append((image.squeeze(1).permute(0,2,3,1).contiguous().float().cpu().numpy() * 255).astype(np.uint8))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
177
 
178
  images = np.concatenate(images, axis=0)
179
  imageio.mimwrite(os.path.join('vis_data', name + '_static.mp4'), images, fps=30)
main_4d_demo.py CHANGED
@@ -540,38 +540,33 @@ class GUI:
540
 
541
  # render eval
542
  image_list =[]
543
- nframes = self.opt.batch_size * 7 + 15 * 7
544
- hor = 180
545
- delta_hor = 45 / 15
546
- delta_time = 1
547
- for i in range(8):
548
- time = 0
549
- for j in range(self.opt.batch_size + 15):
550
- pose = orbit_camera(self.opt.elevation, hor-180, self.opt.radius)
551
- cur_cam = MiniCam(
552
- pose,
553
- 512,
554
- 512,
555
- self.cam.fovy,
556
- self.cam.fovx,
557
- self.cam.near,
558
- self.cam.far,
559
- time=time
560
- )
561
- with torch.no_grad():
562
- outputs = self.renderer.render(cur_cam)
563
-
564
- out = outputs["image"].cpu().detach().numpy().astype(np.float32)
565
- out = np.transpose(out, (1, 2, 0))
566
- out = np.uint8(out*255)
567
- image_list.append(out)
568
-
569
- time = (time + delta_time) % self.opt.batch_size
570
- if j >= self.opt.batch_size:
571
- hor = (hor+delta_hor) % 360
572
-
573
-
574
- imageio.mimwrite(f'vis_data/{self.opt.save_path}.mp4', image_list, fps=7)
575
 
576
  if self.gui:
577
  while True:
 
540
 
541
  # render eval
542
  image_list =[]
543
+ fps = 14
544
+ delta_time = 1 / 30
545
+ self.renderer.prepare_render_4x()
546
+ time = 0
547
+ for hor in range(720):
548
+
549
+ pose = orbit_camera(self.opt.elevation, hor, self.opt.radius)
550
+ cur_cam = MiniCam(
551
+ pose,
552
+ 512,
553
+ 512,
554
+ self.cam.fovy,
555
+ self.cam.fovx,
556
+ self.cam.near,
557
+ self.cam.far,
558
+ time=int(time * fps) % (self.opt.batch_size * 4)
559
+ )
560
+ with torch.no_grad():
561
+ outputs = self.renderer.render(cur_cam)
562
+ out = outputs["image"].cpu().detach().numpy().astype(np.float32)
563
+ out = np.transpose(out, (1, 2, 0))
564
+ out = np.uint8(out*255)
565
+ image_list.append(out)
566
+ time += delta_time
567
+
568
+
569
+ imageio.mimwrite(f'vis_data/{self.opt.save_path}.mp4', image_list, fps=30)
 
 
 
 
 
570
 
571
  if self.gui:
572
  while True: