huzey commited on
Commit
b8bf004
1 Parent(s): 48bc263

update plotting fg

Browse files
Files changed (1) hide show
  1. app.py +16 -9
app.py CHANGED
@@ -312,7 +312,8 @@ def blend_image_with_heatmap(image, heatmap, opacity1=0.5, opacity2=0.5):
312
  blended = (1 - opacity1) * image + opacity2 * heatmap
313
  return blended.astype(np.uint8)
314
 
315
-
 
316
  def segment_fg_bg(images):
317
 
318
  images = F.interpolate(images, (224, 224), mode="bilinear")
@@ -459,7 +460,7 @@ def make_cluster_plot(eigvecs, images, h=64, w=64, progess_start=0.6, advanced=F
459
  # fps_heatmaps[idx.item()] = heatmap.cpu()
460
  fps_heatmaps[idx.item()] = heatmap[mask_sort_idx[:6]].cpu()
461
  top3_image_idx[idx.item()] = mask_sort_idx[:3]
462
- top10_image_idx[idx.item()] = mask_sort_idx[:6]
463
  # do the sorting
464
  _sort_idx = torch.tensor(sort_values).argsort(descending=True)
465
  fps_idx = fps_idx[_sort_idx]
@@ -486,7 +487,7 @@ def make_cluster_plot(eigvecs, images, h=64, w=64, progess_start=0.6, advanced=F
486
  if not advanced:
487
  fig, axs = plt.subplots(3, 5, figsize=(15, 9))
488
  if advanced:
489
- fig, axs = plt.subplots(6, 5, figsize=(15, 18))
490
  for ax in axs.flatten():
491
  ax.axis("off")
492
  for j, idx in enumerate(fps_idx[i_fig*5:i_fig*5+5]):
@@ -521,7 +522,7 @@ def make_cluster_plot(eigvecs, images, h=64, w=64, progess_start=0.6, advanced=F
521
 
522
  return fig_images, ret_magnitude
523
 
524
- def make_cluster_plot_advanced(eigvecs, images, h=64, w=64):
525
  heatmap_fg, heatmap_bg = segment_fg_bg(images.clone())
526
  heatmap_bg = rearrange(heatmap_bg, 'b h w c -> b c h w')
527
  heatmap_fg = rearrange(heatmap_fg, 'b h w c -> b c h w')
@@ -542,9 +543,15 @@ def make_cluster_plot_advanced(eigvecs, images, h=64, w=64):
542
  bg_idx = torch.arange(heatmap_bg.shape[0])[bg_mask]
543
  other_idx = torch.arange(heatmap_fg.shape[0])[other_mask]
544
 
545
- fg_images, _ = make_cluster_plot(eigvecs, images, h=h, w=w, advanced=True, clusters=100, eig_idx=fg_idx, title="fg")
546
- bg_images, _ = make_cluster_plot(eigvecs, images, h=h, w=w, advanced=True, clusters=100, eig_idx=bg_idx, title="bg")
547
- other_images, _ = make_cluster_plot(eigvecs, images, h=h, w=w, advanced=True, clusters=100, eig_idx=other_idx, title="other")
 
 
 
 
 
 
548
 
549
  cluster_images = fg_images + bg_images + other_images
550
 
@@ -833,7 +840,7 @@ def ncut_run(
833
  if advanced:
834
  cluster_images, eig_magnitude = make_cluster_plot_advanced(eigvecs, _images, h=h, w=w)
835
  else:
836
- cluster_images, eig_magnitude = make_cluster_plot(eigvecs, _images, h=h, w=w, progess_start=progress_start, advanced=False)
837
  logging_str += f"plot time: {time.time() - start:.2f}s\n"
838
 
839
  norm_images = None
@@ -1859,7 +1866,7 @@ with demo:
1859
 
1860
  with gr.Column(scale=5, min_width=200):
1861
  output_gallery = make_output_images_section()
1862
- cluster_gallery = gr.Gallery(value=[], label="Clusters", show_label=True, elem_id="clusters", columns=[2], rows=[2], object_fit="contain", height="auto", show_share_button=True, preview=False, interactive=False)
1863
  [
1864
  model_dropdown, layer_slider, node_type_dropdown, num_eig_slider,
1865
  affinity_focal_gamma_slider, num_sample_ncut_slider, ncut_knn_slider, ncut_indirect_connection, ncut_make_orthogonal,
 
312
  blended = (1 - opacity1) * image + opacity2 * heatmap
313
  return blended.astype(np.uint8)
314
 
315
+ # preload the model
316
+ load_model("CLIP(ViT-B-16/openai)")
317
  def segment_fg_bg(images):
318
 
319
  images = F.interpolate(images, (224, 224), mode="bilinear")
 
460
  # fps_heatmaps[idx.item()] = heatmap.cpu()
461
  fps_heatmaps[idx.item()] = heatmap[mask_sort_idx[:6]].cpu()
462
  top3_image_idx[idx.item()] = mask_sort_idx[:3]
463
+ top10_image_idx[idx.item()] = mask_sort_idx[:5]
464
  # do the sorting
465
  _sort_idx = torch.tensor(sort_values).argsort(descending=True)
466
  fps_idx = fps_idx[_sort_idx]
 
487
  if not advanced:
488
  fig, axs = plt.subplots(3, 5, figsize=(15, 9))
489
  if advanced:
490
+ fig, axs = plt.subplots(5, 5, figsize=(15, 15))
491
  for ax in axs.flatten():
492
  ax.axis("off")
493
  for j, idx in enumerate(fps_idx[i_fig*5:i_fig*5+5]):
 
522
 
523
  return fig_images, ret_magnitude
524
 
525
+ def make_cluster_plot_advanced(eigvecs, images, h=64, w=64, num_fg=100, num_bg=10, num_other=0, small_title=True):
526
  heatmap_fg, heatmap_bg = segment_fg_bg(images.clone())
527
  heatmap_bg = rearrange(heatmap_bg, 'b h w c -> b c h w')
528
  heatmap_fg = rearrange(heatmap_fg, 'b h w c -> b c h w')
 
543
  bg_idx = torch.arange(heatmap_bg.shape[0])[bg_mask]
544
  other_idx = torch.arange(heatmap_fg.shape[0])[other_mask]
545
 
546
+ fg_images = []
547
+ if num_fg > 0:
548
+ fg_images, _ = make_cluster_plot(eigvecs, images, h=h, w=w, advanced=True, clusters=num_fg, eig_idx=fg_idx, title="fg" if small_title else "cluster")
549
+ bg_images = []
550
+ if num_bg > 0:
551
+ bg_images, _ = make_cluster_plot(eigvecs, images, h=h, w=w, advanced=True, clusters=num_bg, eig_idx=bg_idx, title="bg" if small_title else "cluster")
552
+ other_images = []
553
+ if num_other > 0:
554
+ other_images, _ = make_cluster_plot(eigvecs, images, h=h, w=w, advanced=True, clusters=num_other, eig_idx=other_idx, title="other" if small_title else "cluster")
555
 
556
  cluster_images = fg_images + bg_images + other_images
557
 
 
840
  if advanced:
841
  cluster_images, eig_magnitude = make_cluster_plot_advanced(eigvecs, _images, h=h, w=w)
842
  else:
843
+ cluster_images, eig_magnitude = make_cluster_plot_advanced(eigvecs, _images, h=h, w=w, num_fg=20, num_bg=0, num_other=0, small_title=False)
844
  logging_str += f"plot time: {time.time() - start:.2f}s\n"
845
 
846
  norm_images = None
 
1866
 
1867
  with gr.Column(scale=5, min_width=200):
1868
  output_gallery = make_output_images_section()
1869
+ cluster_gallery = gr.Gallery(value=[], label="Clusters", show_label=True, elem_id="clusters", columns=[2], rows=[2], object_fit="contain", height="auto", show_share_button=True, preview=True, interactive=True)
1870
  [
1871
  model_dropdown, layer_slider, node_type_dropdown, num_eig_slider,
1872
  affinity_focal_gamma_slider, num_sample_ncut_slider, ncut_knn_slider, ncut_indirect_connection, ncut_make_orthogonal,