Spaces:
Runtime error
Runtime error
jiaweir
commited on
Commit
β’
cdc7dcc
1
Parent(s):
c122ae9
optimize
Browse files- app.py +8 -4
- configs/4d_demo.yaml +1 -1
- gs_renderer_4d.py +42 -0
- lgm/core/models.py +41 -0
- lgm/infer_demo.py +15 -34
- main_4d_demo.py +27 -32
app.py
CHANGED
@@ -224,7 +224,7 @@ def optimize_stage_2(image_block: Image.Image, seed_slider: int):
|
|
224 |
process_dg4d(os.path.join("configs", "4d_demo.yaml"), os.path.join("tmp_data", f"{img_hash}_rgba.png"), guidance_zero123)
|
225 |
# os.rename(os.path.join('logs', f'{img_hash}_rgba_frames'), os.path.join('logs', f'{img_hash}_{seed_slider:03d}_rgba_frames'))
|
226 |
image_dir = os.path.join('logs', f'{img_hash}_rgba_frames')
|
227 |
-
# return 'vis_data
|
228 |
return [image_dir+f'/{t:03d}.ply' for t in range(28)]
|
229 |
|
230 |
|
@@ -256,7 +256,7 @@ if __name__ == "__main__":
|
|
256 |
|
257 |
# Image-to-3D
|
258 |
with gr.Row(variant='panel'):
|
259 |
-
with gr.Column(scale=
|
260 |
image_block = gr.Image(type='pil', image_mode='RGBA', height=290, label='Input image')
|
261 |
|
262 |
# elevation_slider = gr.Slider(-90, 90, value=0, step=1, label='Estimated elevation angle')
|
@@ -282,8 +282,12 @@ if __name__ == "__main__":
|
|
282 |
img_guide_text = gr.Markdown(_IMG_USER_GUIDE, visible=True)
|
283 |
|
284 |
with gr.Column(scale=5):
|
285 |
-
|
286 |
-
|
|
|
|
|
|
|
|
|
287 |
obj4d = Model4DGS(label="4D Model", height=500, fps=14)
|
288 |
|
289 |
|
|
|
224 |
process_dg4d(os.path.join("configs", "4d_demo.yaml"), os.path.join("tmp_data", f"{img_hash}_rgba.png"), guidance_zero123)
|
225 |
# os.rename(os.path.join('logs', f'{img_hash}_rgba_frames'), os.path.join('logs', f'{img_hash}_{seed_slider:03d}_rgba_frames'))
|
226 |
image_dir = os.path.join('logs', f'{img_hash}_rgba_frames')
|
227 |
+
# return os.path.join('vis_data', f'{img_hash}_rgba.mp4'), [image_dir+f'/{t:03d}.ply' for t in range(28)]
|
228 |
return [image_dir+f'/{t:03d}.ply' for t in range(28)]
|
229 |
|
230 |
|
|
|
256 |
|
257 |
# Image-to-3D
|
258 |
with gr.Row(variant='panel'):
|
259 |
+
with gr.Column(scale=5):
|
260 |
image_block = gr.Image(type='pil', image_mode='RGBA', height=290, label='Input image')
|
261 |
|
262 |
# elevation_slider = gr.Slider(-90, 90, value=0, step=1, label='Estimated elevation angle')
|
|
|
282 |
img_guide_text = gr.Markdown(_IMG_USER_GUIDE, visible=True)
|
283 |
|
284 |
with gr.Column(scale=5):
|
285 |
+
with gr.Row():
|
286 |
+
with gr.Column(scale=5):
|
287 |
+
dirving_video = gr.Video(label="video",height=290)
|
288 |
+
with gr.Column(scale=5):
|
289 |
+
obj3d = gr.Video(label="3D Model",height=290)
|
290 |
+
# video4d = gr.Video(label="4D video",height=290)
|
291 |
obj4d = Model4DGS(label="4D Model", height=500, fps=14)
|
292 |
|
293 |
|
configs/4d_demo.yaml
CHANGED
@@ -30,7 +30,7 @@ lambda_svd: 0
|
|
30 |
# training batch size per iter
|
31 |
batch_size: 7
|
32 |
# training iterations for stage 1
|
33 |
-
iters:
|
34 |
# training iterations for stage 2
|
35 |
iters_refine: 50
|
36 |
# training camera radius
|
|
|
30 |
# training batch size per iter
|
31 |
batch_size: 7
|
32 |
# training iterations for stage 1
|
33 |
+
iters: 400
|
34 |
# training iterations for stage 2
|
35 |
iters_refine: 50
|
36 |
# training camera radius
|
gs_renderer_4d.py
CHANGED
@@ -150,6 +150,48 @@ class Renderer:
|
|
150 |
self.opacity_deform_T = opacity_deform_T.reshape([self.T, means3D_deform_T.shape[0]//self.T, -1])
|
151 |
self.scales_deform_T = scales_deform_T.reshape([self.T, means3D_deform_T.shape[0]//self.T, -1])
|
152 |
self.rotations_deform_T = rotations_deform_T.reshape([self.T, means3D_deform_T.shape[0]//self.T, -1])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
153 |
|
154 |
|
155 |
def render(
|
|
|
150 |
self.opacity_deform_T = opacity_deform_T.reshape([self.T, means3D_deform_T.shape[0]//self.T, -1])
|
151 |
self.scales_deform_T = scales_deform_T.reshape([self.T, means3D_deform_T.shape[0]//self.T, -1])
|
152 |
self.rotations_deform_T = rotations_deform_T.reshape([self.T, means3D_deform_T.shape[0]//self.T, -1])
|
153 |
+
|
154 |
+
|
155 |
+
def prepare_render_4x(
|
156 |
+
self,
|
157 |
+
):
|
158 |
+
means3D = self.gaussians.get_xyz
|
159 |
+
opacity = self.gaussians._opacity
|
160 |
+
scales = self.gaussians._scaling
|
161 |
+
rotations = self.gaussians._rotation
|
162 |
+
|
163 |
+
means3D_T = []
|
164 |
+
opacity_T = []
|
165 |
+
scales_T = []
|
166 |
+
rotations_T = []
|
167 |
+
time_T = []
|
168 |
+
|
169 |
+
for t in range(self.T * 4):
|
170 |
+
tt = t / 4.
|
171 |
+
time = torch.tensor(tt).to(means3D.device).repeat(means3D.shape[0],1)
|
172 |
+
time = ((time.float() / self.T) - 0.5) * 2
|
173 |
+
|
174 |
+
means3D_T.append(means3D)
|
175 |
+
opacity_T.append(opacity)
|
176 |
+
scales_T.append(scales)
|
177 |
+
rotations_T.append(rotations)
|
178 |
+
time_T.append(time)
|
179 |
+
|
180 |
+
means3D_T = torch.cat(means3D_T)
|
181 |
+
opacity_T = torch.cat(opacity_T)
|
182 |
+
scales_T = torch.cat(scales_T)
|
183 |
+
rotations_T = torch.cat(rotations_T)
|
184 |
+
time_T = torch.cat(time_T)
|
185 |
+
|
186 |
+
|
187 |
+
means3D_deform_T, scales_deform_T, rotations_deform_T, opacity_deform_T = self.gaussians._deformation(means3D_T, scales_T,
|
188 |
+
rotations_T, opacity_T,
|
189 |
+
time_T) # time is not none
|
190 |
+
self.means3D_deform_T = means3D_deform_T.reshape([self.T *4, means3D_deform_T.shape[0]//self.T // 4, -1])
|
191 |
+
self.opacity_deform_T = opacity_deform_T.reshape([self.T*4, means3D_deform_T.shape[0]//self.T//4, -1])
|
192 |
+
self.scales_deform_T = scales_deform_T.reshape([self.T*4, means3D_deform_T.shape[0]//self.T//4, -1])
|
193 |
+
self.rotations_deform_T = rotations_deform_T.reshape([self.T*4, means3D_deform_T.shape[0]//self.T//4, -1])
|
194 |
+
|
195 |
|
196 |
|
197 |
def render(
|
lgm/core/models.py
CHANGED
@@ -116,6 +116,47 @@ class LGM(nn.Module):
|
|
116 |
|
117 |
return gaussians
|
118 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
119 |
|
120 |
def forward(self, data, step_ratio=1):
|
121 |
# data: output of the dataloader
|
|
|
116 |
|
117 |
return gaussians
|
118 |
|
119 |
+
def forward_gaussians_downsample(self, images):
|
120 |
+
# images: [B, 4, 9, H, W]
|
121 |
+
# return: Gaussians: [B, dim_t]
|
122 |
+
|
123 |
+
B, V, C, H, W = images.shape
|
124 |
+
images = images.view(B*V, C, H, W)
|
125 |
+
|
126 |
+
x = self.unet(images) # [B*4, 14, h, w]
|
127 |
+
x = self.conv(x) # [B*4, 14, h, w]
|
128 |
+
|
129 |
+
x_orig_res = x.clone()
|
130 |
+
|
131 |
+
x = F.interpolate(x, (self.opt.splat_size // 4, self.opt.splat_size//4), mode='nearest')
|
132 |
+
x = x.reshape(B, 4, 14, self.opt.splat_size//4, self.opt.splat_size//4)
|
133 |
+
|
134 |
+
x = x.permute(0, 1, 3, 4, 2).reshape(B, -1, 14)
|
135 |
+
|
136 |
+
pos = self.pos_act(x[..., 0:3]) # [B, N, 3]
|
137 |
+
opacity = self.opacity_act(x[..., 3:4])
|
138 |
+
scale = self.scale_act(x[..., 4:7]) * 4
|
139 |
+
rotation = self.rot_act(x[..., 7:11])
|
140 |
+
rgbs = self.rgb_act(x[..., 11:])
|
141 |
+
|
142 |
+
gaussians = torch.cat([pos, opacity, scale, rotation, rgbs], dim=-1) # [B, N, 14]
|
143 |
+
|
144 |
+
|
145 |
+
x = x_orig_res.reshape(B, 4, 14, self.opt.splat_size, self.opt.splat_size)
|
146 |
+
|
147 |
+
x = x.permute(0, 1, 3, 4, 2).reshape(B, -1, 14)
|
148 |
+
|
149 |
+
pos = self.pos_act(x[..., 0:3]) # [B, N, 3]
|
150 |
+
opacity = self.opacity_act(x[..., 3:4])
|
151 |
+
scale = self.scale_act(x[..., 4:7])
|
152 |
+
rotation = self.rot_act(x[..., 7:11])
|
153 |
+
rgbs = self.rgb_act(x[..., 11:])
|
154 |
+
|
155 |
+
gaussians_orig_res = torch.cat([pos, opacity, scale, rotation, rgbs], dim=-1) # [B, N, 14]
|
156 |
+
|
157 |
+
|
158 |
+
return gaussians, gaussians_orig_res
|
159 |
+
|
160 |
|
161 |
def forward(self, data, step_ratio=1):
|
162 |
# data: output of the dataloader
|
lgm/infer_demo.py
CHANGED
@@ -151,7 +151,7 @@ def process(opt: Options, path, pipe, model, rays_embeddings, seed):
|
|
151 |
|
152 |
with torch.autocast(device_type='cuda', dtype=torch.float16):
|
153 |
# generate gaussians
|
154 |
-
gaussians = model.
|
155 |
|
156 |
# save gaussians
|
157 |
model.gs.save_ply(gaussians, os.path.join('logs', name + '_model.ply'))
|
@@ -160,39 +160,20 @@ def process(opt: Options, path, pipe, model, rays_embeddings, seed):
|
|
160 |
images = []
|
161 |
elevation = 0
|
162 |
|
163 |
-
|
164 |
-
|
165 |
-
|
166 |
-
|
167 |
-
|
168 |
-
|
169 |
-
|
170 |
-
|
171 |
-
|
172 |
-
|
173 |
-
|
174 |
-
|
175 |
-
|
176 |
-
|
177 |
-
scale = min(azi / 360, 1)
|
178 |
-
|
179 |
-
image = model.gs.render(gaussians, cam_view.unsqueeze(0), cam_view_proj.unsqueeze(0), cam_pos.unsqueeze(0), scale_modifier=scale)['image']
|
180 |
-
images.append((image.squeeze(1).permute(0,2,3,1).contiguous().float().cpu().numpy() * 255).astype(np.uint8))
|
181 |
-
else:
|
182 |
-
azimuth = np.arange(0, 360, 2, dtype=np.int32)
|
183 |
-
for azi in tqdm.tqdm(azimuth):
|
184 |
-
|
185 |
-
cam_poses = torch.from_numpy(orbit_camera(elevation, azi, radius=opt.cam_radius, opengl=True)).unsqueeze(0).to(device)
|
186 |
-
|
187 |
-
cam_poses[:, :3, 1:3] *= -1 # invert up & forward direction
|
188 |
-
|
189 |
-
# cameras needed by gaussian rasterizer
|
190 |
-
cam_view = torch.inverse(cam_poses).transpose(1, 2) # [V, 4, 4]
|
191 |
-
cam_view_proj = cam_view @ proj_matrix # [V, 4, 4]
|
192 |
-
cam_pos = - cam_poses[:, :3, 3] # [V, 3]
|
193 |
-
|
194 |
-
image = model.gs.render(gaussians, cam_view.unsqueeze(0), cam_view_proj.unsqueeze(0), cam_pos.unsqueeze(0), scale_modifier=1)['image']
|
195 |
-
images.append((image.squeeze(1).permute(0,2,3,1).contiguous().float().cpu().numpy() * 255).astype(np.uint8))
|
196 |
|
197 |
images = np.concatenate(images, axis=0)
|
198 |
imageio.mimwrite(os.path.join('vis_data', name + '_static.mp4'), images, fps=30)
|
|
|
151 |
|
152 |
with torch.autocast(device_type='cuda', dtype=torch.float16):
|
153 |
# generate gaussians
|
154 |
+
gaussians, gaussians_orig_res = model.forward_gaussians_downsample(input_image)
|
155 |
|
156 |
# save gaussians
|
157 |
model.gs.save_ply(gaussians, os.path.join('logs', name + '_model.ply'))
|
|
|
160 |
images = []
|
161 |
elevation = 0
|
162 |
|
163 |
+
azimuth = np.arange(0, 360, 2, dtype=np.int32)
|
164 |
+
for azi in tqdm.tqdm(azimuth):
|
165 |
+
|
166 |
+
cam_poses = torch.from_numpy(orbit_camera(elevation, azi, radius=opt.cam_radius, opengl=True)).unsqueeze(0).to(device)
|
167 |
+
|
168 |
+
cam_poses[:, :3, 1:3] *= -1 # invert up & forward direction
|
169 |
+
|
170 |
+
# cameras needed by gaussian rasterizer
|
171 |
+
cam_view = torch.inverse(cam_poses).transpose(1, 2) # [V, 4, 4]
|
172 |
+
cam_view_proj = cam_view @ proj_matrix # [V, 4, 4]
|
173 |
+
cam_pos = - cam_poses[:, :3, 3] # [V, 3]
|
174 |
+
|
175 |
+
image = model.gs.render(gaussians_orig_res, cam_view.unsqueeze(0), cam_view_proj.unsqueeze(0), cam_pos.unsqueeze(0), scale_modifier=1)['image']
|
176 |
+
images.append((image.squeeze(1).permute(0,2,3,1).contiguous().float().cpu().numpy() * 255).astype(np.uint8))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
177 |
|
178 |
images = np.concatenate(images, axis=0)
|
179 |
imageio.mimwrite(os.path.join('vis_data', name + '_static.mp4'), images, fps=30)
|
main_4d_demo.py
CHANGED
@@ -540,38 +540,33 @@ class GUI:
|
|
540 |
|
541 |
# render eval
|
542 |
image_list =[]
|
543 |
-
|
544 |
-
|
545 |
-
|
546 |
-
|
547 |
-
for
|
548 |
-
|
549 |
-
|
550 |
-
|
551 |
-
|
552 |
-
|
553 |
-
|
554 |
-
|
555 |
-
|
556 |
-
|
557 |
-
|
558 |
-
|
559 |
-
|
560 |
-
|
561 |
-
|
562 |
-
|
563 |
-
|
564 |
-
|
565 |
-
|
566 |
-
|
567 |
-
|
568 |
-
|
569 |
-
|
570 |
-
if j >= self.opt.batch_size:
|
571 |
-
hor = (hor+delta_hor) % 360
|
572 |
-
|
573 |
-
|
574 |
-
imageio.mimwrite(f'vis_data/{self.opt.save_path}.mp4', image_list, fps=7)
|
575 |
|
576 |
if self.gui:
|
577 |
while True:
|
|
|
540 |
|
541 |
# render eval
|
542 |
image_list =[]
|
543 |
+
fps = 14
|
544 |
+
delta_time = 1 / 30
|
545 |
+
self.renderer.prepare_render_4x()
|
546 |
+
time = 0
|
547 |
+
for hor in range(720):
|
548 |
+
|
549 |
+
pose = orbit_camera(self.opt.elevation, hor, self.opt.radius)
|
550 |
+
cur_cam = MiniCam(
|
551 |
+
pose,
|
552 |
+
512,
|
553 |
+
512,
|
554 |
+
self.cam.fovy,
|
555 |
+
self.cam.fovx,
|
556 |
+
self.cam.near,
|
557 |
+
self.cam.far,
|
558 |
+
time=int(time * fps) % (self.opt.batch_size * 4)
|
559 |
+
)
|
560 |
+
with torch.no_grad():
|
561 |
+
outputs = self.renderer.render(cur_cam)
|
562 |
+
out = outputs["image"].cpu().detach().numpy().astype(np.float32)
|
563 |
+
out = np.transpose(out, (1, 2, 0))
|
564 |
+
out = np.uint8(out*255)
|
565 |
+
image_list.append(out)
|
566 |
+
time += delta_time
|
567 |
+
|
568 |
+
|
569 |
+
imageio.mimwrite(f'vis_data/{self.opt.save_path}.mp4', image_list, fps=30)
|
|
|
|
|
|
|
|
|
|
|
570 |
|
571 |
if self.gui:
|
572 |
while True:
|