wzhouxiff committed
Commit: c305f12
Parent: 979cf8b

Merge all the run functions into app.py

Files changed (2)
  1. app.py +344 -7
  2. objctrl_2_5d/utils/ui_utils.py +0 -26
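For context, this commit removes the factory-style wrappers (run_depth(d_model_NK) in ui_utils.py and run(pipeline, device) in objctrl_2_5d.py) and defines the handlers as top-level, GPU-decorated functions in app.py, next to the models they use, so the Gradio click events receive the functions directly. Below is a minimal sketch of that wiring pattern, not the commit's actual code: the FakeDepthModel stand-in replaces the ZoeDepth instance d_model_NK that app.py loads at module scope, and the UI is reduced to the depth button only.

import gradio as gr

# Stand-in for the ZoeDepth model that app.py keeps at module scope as d_model_NK.
# In the real app the handler is additionally decorated with @spaces.GPU(duration=50).
class FakeDepthModel:
    def infer_pil(self, image):
        return [[0.0]]

d_model_NK = FakeDepthModel()

def get_depth(image, points):
    # Module-level handler: it reaches d_model_NK directly, so no
    # run_depth(d_model_NK) closure factory is needed any more.
    return d_model_NK.infer_pil(image)

with gr.Blocks() as demo:
    original_image = gr.Image()
    selected_points = gr.State([])
    depth = gr.State()
    depth_button = gr.Button("Get depth")
    # The click event is wired to the function itself, not to a factory call.
    depth_button.click(get_depth, [original_image, selected_points], [depth])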
app.py CHANGED
@@ -8,14 +8,23 @@ from sam2.sam2_image_predictor import SAM2ImagePredictor
from omegaconf import OmegaConf
from PIL import Image
import numpy as np
+ from copy import deepcopy
+ import cv2

- from objctrl_2_5d.utils.ui_utils import process_image, get_camera_pose, get_subject_points, run_depth, get_points, undo_points, mask_image
+ import torch.nn.functional as F
+ import torchvision
+ from einops import rearrange
+ import tempfile

+ from objctrl_2_5d.utils.ui_utils import process_image, get_camera_pose, get_subject_points, get_points, undo_points, mask_image
+ from ZoeDepth.zoedepth.utils.misc import colorize

from cameractrl.inference import get_pipeline
- from objctrl_2_5d.objctrl_2_5d import run
from objctrl_2_5d.utils.examples import examples, sync_points

+ from objctrl_2_5d.utils.objmask_util import RT2Plucker, Unprojected, roll_with_ignore_multidim, dilate_mask_pytorch
+ from objctrl_2_5d.utils.filter_utils import get_freq_filter, freq_mix_3d
+

### Title and Description ###
#### Description ####
@@ -118,7 +127,6 @@ pipeline = get_pipeline(model_id, "unet", model_config['down_block_types'], mode

### run the demo ##
@spaces.GPU(duration=50)
- # def run_segment(segmentor):
def segment(canvas, image, logits):
    if logits is not None:
        logits *= 32.0
@@ -159,8 +167,338 @@ def segment(canvas, image, logits):

    return mask[0], masked_img, masked_img, logits / 32.0

-     # return segment
+ @spaces.GPU(duration=50)
+ def get_depth(image, points):
+
+     depth = d_model_NK.infer_pil(image)
+     colored_depth = colorize(depth, cmap='gray_r') # [h, w, 4] 0-255
+
+     depth_img = deepcopy(colored_depth[:, :, :3])
+     if len(points) > 0:
+         for idx, point in enumerate(points):
+             if idx % 2 == 0:
+                 cv2.circle(depth_img, tuple(point), 10, (255, 0, 0), -1)
+             else:
+                 cv2.circle(depth_img, tuple(point), 10, (0, 0, 255), -1)
+             if idx > 0:
+                 cv2.arrowedLine(depth_img, points[idx-1], points[idx], (255, 255, 255), 4, tipLength=0.5)
+
+     return depth, depth_img, colored_depth[:, :, :3]
+
+
+ @spaces.GPU(duration=50)
+ def run_objctrl_2_5d(condition_image,
+                      mask,
+                      depth,
+                      RTs,
+                      bg_mode,
+                      shared_wapring_latents,
+                      scale_wise_masks,
+                      rescale,
+                      seed,
+                      ds, dt,
+                      num_inference_steps=25):
+
+     DEBUG = False
+
+     if DEBUG:
+         cur_OUTPUT_PATH = 'outputs/tmp'
+         os.makedirs(cur_OUTPUT_PATH, exist_ok=True)
+
+     # num_inference_steps=25
+     min_guidance_scale = 1.0
+     max_guidance_scale = 3.0
+
+     area_ratio = 0.3
+     depth_scale_ = 5.2
+     center_margin = 10
+
+     height, width = 320, 576
+     num_frames = 14
+
+     intrinsics = np.array([[float(width), float(width), float(width) / 2, float(height) / 2]])
+     intrinsics = np.repeat(intrinsics, num_frames, axis=0) # [n_frame, 4]
+     fx = intrinsics[0, 0] / width
+     fy = intrinsics[0, 1] / height
+     cx = intrinsics[0, 2] / width
+     cy = intrinsics[0, 3] / height
+
+     down_scale = 8
+     H, W = height // down_scale, width // down_scale
+     K = np.array([[width / down_scale, 0, W / 2], [0, width / down_scale, H / 2], [0, 0, 1]])
+
+     seed = int(seed)
+
+     center_h_margin, center_w_margin = center_margin, center_margin
+     depth_center = np.mean(depth[height//2-center_h_margin:height//2+center_h_margin, width//2-center_w_margin:width//2+center_w_margin])
+
+     if rescale > 0:
+         depth_rescale = round(depth_scale_ * rescale / depth_center, 2)
+     else:
+         depth_rescale = 1.0
+
+     depth = depth * depth_rescale
+
+     depth_down = F.interpolate(torch.tensor(depth).unsqueeze(0).unsqueeze(0),
+                                (H, W), mode='bilinear', align_corners=False).squeeze().numpy() # [H, W]
+
+     ## latent
+     generator = torch.Generator()
+     generator.manual_seed(seed)
+
+     latents_org = pipeline.prepare_latents(
+         1,
+         14,
+         8,
+         height,
+         width,
+         pipeline.dtype,
+         device,
+         generator,
+         None,
+     )
+     latents_org = latents_org / pipeline.scheduler.init_noise_sigma
+
+     cur_plucker_embedding, _, _ = RT2Plucker(RTs, RTs.shape[0], (height, width), fx, fy, cx, cy) # 6, V, H, W
+     cur_plucker_embedding = cur_plucker_embedding.to(device)
+     cur_plucker_embedding = cur_plucker_embedding[None, ...] # b 6 f h w
+     cur_plucker_embedding = cur_plucker_embedding.permute(0, 2, 1, 3, 4) # b f 6 h w
+     cur_plucker_embedding = cur_plucker_embedding[:, :num_frames, ...]
+     cur_pose_features = pipeline.pose_encoder(cur_plucker_embedding)
+
+     # bg_mode = ["Fixed", "Reverse", "Free"]
+     if bg_mode == "Fixed":
+         fix_RTs = np.repeat(RTs[0][None, ...], num_frames, axis=0) # [n_frame, 4, 3]
+         fix_plucker_embedding, _, _ = RT2Plucker(fix_RTs, num_frames, (height, width), fx, fy, cx, cy) # 6, V, H, W
+         fix_plucker_embedding = fix_plucker_embedding.to(device)
+         fix_plucker_embedding = fix_plucker_embedding[None, ...] # b 6 f h w
+         fix_plucker_embedding = fix_plucker_embedding.permute(0, 2, 1, 3, 4) # b f 6 h w
+         fix_plucker_embedding = fix_plucker_embedding[:, :num_frames, ...]
+         fix_pose_features = pipeline.pose_encoder(fix_plucker_embedding)
+
+     elif bg_mode == "Reverse":
+         bg_plucker_embedding, _, _ = RT2Plucker(RTs[::-1], RTs.shape[0], (height, width), fx, fy, cx, cy) # 6, V, H, W
+         bg_plucker_embedding = bg_plucker_embedding.to(device)
+         bg_plucker_embedding = bg_plucker_embedding[None, ...] # b 6 f h w
+         bg_plucker_embedding = bg_plucker_embedding.permute(0, 2, 1, 3, 4) # b f 6 h w
+         bg_plucker_embedding = bg_plucker_embedding[:, :num_frames, ...]
+         fix_pose_features = pipeline.pose_encoder(bg_plucker_embedding)
+
+     else:
+         fix_pose_features = None
+
+     #### preparing mask
+
+     mask = Image.fromarray(mask)
+     mask = mask.resize((W, H))
+     mask = np.array(mask).astype(np.float32)
+     mask = np.expand_dims(mask, axis=-1)
+
+     # visulize mask
+     if DEBUG:
+         mask_sum_vis = mask[..., 0]
+         mask_sum_vis = (mask_sum_vis * 255.0).astype(np.uint8)
+         mask_sum_vis = Image.fromarray(mask_sum_vis)
+
+         mask_sum_vis.save(f'{cur_OUTPUT_PATH}/org_mask.png')
+
+     try:
+         warped_masks = Unprojected(mask, depth_down, RTs, H=H, W=W, K=K)
+
+         warped_masks.insert(0, mask)
+
+     except:
+         # mask to bbox
+         print(f'!!! Mask is too small to warp; mask to bbox')
+         mask = mask[:, :, 0]
+         coords = cv2.findNonZero(mask)
+         x, y, w, h = cv2.boundingRect(coords)
+         # mask[y:y+h, x:x+w] = 1.0
+
+         center_x, center_y = x + w // 2, y + h // 2
+         center_z = depth_down[center_y, center_x]
+
+         # RTs [n_frame, 3, 4] to [n_frame, 4, 4] , add [0, 0, 0, 1]
+         RTs = np.concatenate([RTs, np.array([[[0, 0, 0, 1]]] * num_frames)], axis=1)
+
+         # RTs: world to camera
+         P0 = np.array([center_x, center_y, 1])
+         Pc0 = np.linalg.inv(K) @ P0 * center_z
+         pw = np.linalg.inv(RTs[0]) @ np.array([Pc0[0], Pc0[1], center_z, 1]) # [4]
+
+         P = [np.array([center_x, center_y])]
+         for i in range(1, num_frames):
+             Pci = RTs[i] @ pw
+             Pi = K @ Pci[:3] / Pci[2]
+             P.append(Pi[:2])
+
+         warped_masks = [mask]
+         for i in range(1, num_frames):
+             shift_x = int(round(P[i][0] - P[0][0]))
+             shift_y = int(round(P[i][1] - P[0][1]))
+
+             cur_mask = roll_with_ignore_multidim(mask, [shift_y, shift_x])
+             warped_masks.append(cur_mask)
+
+
+         warped_masks = [v[..., None] for v in warped_masks]
+
+     warped_masks = np.stack(warped_masks, axis=0) # [f, h, w]
+     warped_masks = np.repeat(warped_masks, 3, axis=-1) # [f, h, w, 3]
+
+     mask_sum = np.sum(warped_masks, axis=0, keepdims=True) # [1, H, W, 3]
+     mask_sum[mask_sum > 1.0] = 1.0
+     mask_sum = mask_sum[0,:,:, 0]
+
+     if DEBUG:
+         ## visulize warp mask
+         warp_masks_vis = torch.tensor(warped_masks)
+         warp_masks_vis = (warp_masks_vis * 255.0).to(torch.uint8)
+         torchvision.io.write_video(f'{cur_OUTPUT_PATH}/warped_masks.mp4', warp_masks_vis, fps=10, video_codec='h264', options={'crf': '10'})
+
+         # visulize mask
+         mask_sum_vis = mask_sum
+         mask_sum_vis = (mask_sum_vis * 255.0).astype(np.uint8)
+         mask_sum_vis = Image.fromarray(mask_sum_vis)
+
+         mask_sum_vis.save(f'{cur_OUTPUT_PATH}/merged_mask.png')
+
+     if scale_wise_masks:
+         min_area = H * W * area_ratio # cal in downscale
+         non_zero_len = mask_sum.sum()
+
+         print(f'non_zero_len: {non_zero_len}, min_area: {min_area}')
+
+         if non_zero_len > min_area:
+             kernel_sizes = [1, 1, 1, 3]
+         elif non_zero_len > min_area * 0.5:
+             kernel_sizes = [3, 1, 1, 5]
+         else:
+             kernel_sizes = [5, 3, 3, 7]
+     else:
+         kernel_sizes = [1, 1, 1, 1]
+
+     mask = torch.from_numpy(mask_sum) # [h, w]
+     mask = mask[None, None, ...] # [1, 1, h, w]
+     mask = F.interpolate(mask, (height, width), mode='bilinear', align_corners=False) # [1, 1, H, W]
+     # mask = mask.repeat(1, num_frames, 1, 1) # [1, f, H, W]
+     mask = mask.to(pipeline.dtype).to(device)
+
+     ##### Mask End ######
+
+     ### Got blending pose features Start ###
+
+     pose_features = []
+     for i in range(0, len(cur_pose_features)):
+         kernel_size = kernel_sizes[i]
+         h, w = cur_pose_features[i].shape[-2:]
+
+         if fix_pose_features is None:
+             pose_features.append(torch.zeros_like(cur_pose_features[i]))
+         else:
+             pose_features.append(fix_pose_features[i])
+
+         cur_mask = F.interpolate(mask, (h, w), mode='bilinear', align_corners=False)
+         cur_mask = dilate_mask_pytorch(cur_mask, kernel_size=kernel_size) # [1, 1, H, W]
+         cur_mask = cur_mask.repeat(1, num_frames, 1, 1) # [1, f, H, W]
+
+         if DEBUG:
+             # visulize mask
+             mask_vis = cur_mask[0, 0].cpu().numpy() * 255.0
+             mask_vis = Image.fromarray(mask_vis.astype(np.uint8))
+             mask_vis.save(f'{cur_OUTPUT_PATH}/mask_k{kernel_size}_scale{i}.png')
+
+         cur_mask = cur_mask[None, ...] # [1, 1, f, H, W]
+         pose_features[-1] = cur_pose_features[i] * cur_mask + pose_features[-1] * (1 - cur_mask)
+
+     ### Got blending pose features End ###
+
+     ##### Warp Noise Start ######
+
+     if shared_wapring_latents:
+         noise = latents_org[0, 0].data.cpu().numpy().copy() #[14, 4, 40, 72]
+         noise = np.transpose(noise, (1, 2, 0)) # [40, 72, 4]
+
+         try:
+             warp_noise = Unprojected(noise, depth_down, RTs, H=H, W=W, K=K)
+             warp_noise.insert(0, noise)
+         except:
+             print(f'!!! Noise is too small to warp; mask to bbox')
+
+             warp_noise = [noise]
+             for i in range(1, num_frames):
+                 shift_x = int(round(P[i][0] - P[0][0]))
+                 shift_y = int(round(P[i][1] - P[0][1]))
+
+                 cur_noise = roll_with_ignore_multidim(noise, [shift_y, shift_x])
+                 warp_noise.append(cur_noise)
+
+         warp_noise = np.stack(warp_noise, axis=0) # [f, h, w, 4]
+
+         if DEBUG:
+             ## visulize warp noise
+             warp_noise_vis = torch.tensor(warp_noise)[..., :3] * torch.tensor(warped_masks)
+             warp_noise_vis = (warp_noise_vis - warp_noise_vis.min()) / (warp_noise_vis.max() - warp_noise_vis.min())
+             warp_noise_vis = (warp_noise_vis * 255.0).to(torch.uint8)
+
+             torchvision.io.write_video(f'{cur_OUTPUT_PATH}/warp_noise.mp4', warp_noise_vis, fps=10, video_codec='h264', options={'crf': '10'})
+
+
+         warp_latents = torch.tensor(warp_noise).permute(0, 3, 1, 2).to(latents_org.device).to(latents_org.dtype) # [frame, 4, H, W]
+         warp_latents = warp_latents.unsqueeze(0) # [1, frame, 4, H, W]
+
+         warped_masks = torch.tensor(warped_masks).permute(0, 3, 1, 2).unsqueeze(0) # [1, frame, 3, H, W]
+         mask_extend = torch.concat([warped_masks, warped_masks[:,:,0:1]], dim=2) # [1, frame, 4, H, W]
+         mask_extend = mask_extend.to(latents_org.device).to(latents_org.dtype)
+
+         warp_latents = warp_latents * mask_extend + latents_org * (1 - mask_extend)
+         warp_latents = warp_latents.permute(0, 2, 1, 3, 4)
+         random_noise = latents_org.clone().permute(0, 2, 1, 3, 4)
+
+         filter_shape = warp_latents.shape
+
+         freq_filter = get_freq_filter(
+             filter_shape,
+             device=device,
+             filter_type='butterworth',
+             n=4,
+             d_s=ds,
+             d_t=dt
+         )
+
+         warp_latents = freq_mix_3d(warp_latents, random_noise, freq_filter)
+         warp_latents = warp_latents.permute(0, 2, 1, 3, 4)
+
+     else:
+         warp_latents = latents_org.clone()
+
+     generator.manual_seed(42)
+
+     with torch.no_grad():
+         result = pipeline(
+             image=condition_image,
+             pose_embedding=cur_plucker_embedding,
+             height=height,
+             width=width,
+             num_frames=num_frames,
+             num_inference_steps=num_inference_steps,
+             min_guidance_scale=min_guidance_scale,
+             max_guidance_scale=max_guidance_scale,
+             do_image_process=True,
+             generator=generator,
+             output_type='pt',
+             pose_features=pose_features,
+             latents=warp_latents
+         ).frames[0].cpu() #[f, c, h, w]
+
+
+     result = rearrange(result, 'f c h w -> f h w c')
+     result = (result * 255.0).to(torch.uint8)

+     video_path = tempfile.NamedTemporaryFile(suffix='.mp4').name
+     torchvision.io.write_video(video_path, result, fps=10, video_codec='h264', options={'crf': '8'})
+
+     return video_path

# -------------- UI definition --------------
with gr.Blocks() as demo:
@@ -317,14 +655,13 @@ with gr.Blocks() as demo:
    )

    select_button.click(
-         # run_segment(segmentor),
        segment,
        [canvas, original_image, mask_logits],
        [mask, mask_output, masked_original_image, mask_logits]
    )

    depth_button.click(
-         run_depth(d_model_NK),
+         get_depth,
        [original_image, selected_points],
        [depth, depth_image, org_depth_image]
    )
@@ -347,7 +684,7 @@ with gr.Blocks() as demo:
    )

    generated_button.click(
-         run(pipeline, device),
+         run_objctrl_2_5d,
        [
            original_image,
            mask,
objctrl_2_5d/utils/ui_utils.py CHANGED
@@ -1,14 +1,9 @@
- import spaces
-
import gradio as gr
from PIL import Image
import numpy as np

from copy import deepcopy
import cv2
- import torch
-
- from ZoeDepth.zoedepth.utils.misc import colorize

from objctrl_2_5d.utils.vis_camera import vis_camera_rescale
from objctrl_2_5d.utils.objmask_util import trajectory_to_camera_poses_v1
@@ -102,27 +97,6 @@ def get_points(img,
def undo_points(original_image):
    return original_image, []

- @spaces.GPU(duration=50)
- def run_depth(d_model_NK):
-     def get_depth(image, points):
-
-         depth = d_model_NK.infer_pil(image)
-         colored_depth = colorize(depth, cmap='gray_r') # [h, w, 4] 0-255
-
-         depth_img = deepcopy(colored_depth[:, :, :3])
-         if len(points) > 0:
-             for idx, point in enumerate(points):
-                 if idx % 2 == 0:
-                     cv2.circle(depth_img, tuple(point), 10, (255, 0, 0), -1)
-                 else:
-                     cv2.circle(depth_img, tuple(point), 10, (0, 0, 255), -1)
-                 if idx > 0:
-                     cv2.arrowedLine(depth_img, points[idx-1], points[idx], (255, 255, 255), 4, tipLength=0.5)
-
-         return depth, depth_img, colored_depth[:, :, :3]
-
-     return get_depth
-

def interpolate_points(points, num_points):
    x = points[:, 0]