AnnonSubmission commited on
Commit
ff3029d
1 Parent(s): 2a82b5d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -9
app.py CHANGED
@@ -70,13 +70,10 @@ def load_or_augment_images(img1_input, img2_input, use_aug):
70
 
71
  def run_occlusion(w_size, stride):
72
  heatmap1, heatmap2 = occlusion(img1, img2, ssl_model, w_size = 64, stride = 8, batch_size = 32)
73
- heatmap1_ca, heatmap2_ca = occlusion_context_agnositc(img1, img2, ssl_model, w_size = 64, stride = 8, batch_size = 32)
74
  heatmap1_po, heatmap2_po = pairwise_occlusion(img1, img2, ssl_model, batch_size = 32, erase_scale = (0.1, 0.3), erase_ratio = (1, 1.5), num_erases = 100)
75
 
76
  added_image1 = overlay_heatmap(img1, heatmap1, denormalize = denorm)
77
  added_image2 = overlay_heatmap(img2, heatmap2, denormalize = denorm)
78
- added_image1_ca = overlay_heatmap(img1, heatmap1_ca, denormalize = denorm)
79
- added_image2_ca = overlay_heatmap(img2, heatmap2_ca, denormalize = denorm)
80
 
81
  fig, axs = plt.subplots(2, 4, figsize=(20,10))
82
  np.vectorize(lambda ax:ax.axis('off'))(axs)
@@ -84,14 +81,11 @@ def run_occlusion(w_size, stride):
84
  axs[0, 0].imshow(show_image(img1, denormalize = denorm))
85
  axs[0, 1].imshow(added_image1)
86
  axs[0, 1].set_title("Conditional Occlusion")
87
- axs[0, 2].imshow(added_image1_ca)
88
- axs[0, 2].set_title("CA Cond. Occlusion")
89
- axs[0, 3].imshow((deprocess(img1, denormalize = denorm) * heatmap1_po[:,:,None]).astype('uint8'))
90
- axs[0, 3].set_title("Pairwise Occlusion")
91
  axs[1, 0].imshow(show_image(img2, denormalize = denorm))
92
  axs[1, 1].imshow(added_image2)
93
- axs[1, 2].imshow(added_image2_ca)
94
- axs[1, 3].imshow((deprocess(img2, denormalize = denorm) * heatmap2_po[:,:,None]).astype('uint8'))
95
  plt.subplots_adjust(wspace=0, hspace = 0.01)
96
  pil_output = fig2img(fig)
97
  return pil_output
 
70
 
71
  def run_occlusion(w_size, stride):
72
  heatmap1, heatmap2 = occlusion(img1, img2, ssl_model, w_size = 64, stride = 8, batch_size = 32)
 
73
  heatmap1_po, heatmap2_po = pairwise_occlusion(img1, img2, ssl_model, batch_size = 32, erase_scale = (0.1, 0.3), erase_ratio = (1, 1.5), num_erases = 100)
74
 
75
  added_image1 = overlay_heatmap(img1, heatmap1, denormalize = denorm)
76
  added_image2 = overlay_heatmap(img2, heatmap2, denormalize = denorm)
 
 
77
 
78
  fig, axs = plt.subplots(2, 4, figsize=(20,10))
79
  np.vectorize(lambda ax:ax.axis('off'))(axs)
 
81
  axs[0, 0].imshow(show_image(img1, denormalize = denorm))
82
  axs[0, 1].imshow(added_image1)
83
  axs[0, 1].set_title("Conditional Occlusion")
84
+ axs[0, 2].imshow((deprocess(img1, denormalize = denorm) * heatmap1_po[:,:,None]).astype('uint8'))
85
+ axs[0, 2].set_title("Pairwise Occlusion")
 
 
86
  axs[1, 0].imshow(show_image(img2, denormalize = denorm))
87
  axs[1, 1].imshow(added_image2)
88
+ axs[1, 2].imshow((deprocess(img2, denormalize = denorm) * heatmap2_po[:,:,None]).astype('uint8'))
 
89
  plt.subplots_adjust(wspace=0, hspace = 0.01)
90
  pil_output = fig2img(fig)
91
  return pil_output