Martijn van Beers committed on
Commit
8f3d1af
1 Parent(s): cf1865f

Remove code for jupyter notebooks

Browse files

There was some partially commented out code to create matplotlib
figures. Remove it altogether.

Files changed (2) hide show
  1. CLIP_explainability/utils.py +7 -25
  2. app.py +2 -2
CLIP_explainability/utils.py CHANGED
@@ -69,7 +69,7 @@ def interpret(image, texts, model, device):
69
  return text_relevance, image_relevance
70
 
71
 
72
- def show_image_relevance(image_relevance, image, orig_image, device, show=True):
73
  # create heatmap from mask on image
74
  def show_cam_on_image(img, mask):
75
  heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
@@ -78,15 +78,6 @@ def show_image_relevance(image_relevance, image, orig_image, device, show=True):
78
  cam = cam / np.max(cam)
79
  return cam
80
 
81
- # plt.axis('off')
82
- # f, axarr = plt.subplots(1,2)
83
- # axarr[0].imshow(orig_image)
84
-
85
- if show:
86
- fig, axs = plt.subplots(1, 2)
87
- axs[0].imshow(orig_image);
88
- axs[0].axis('off');
89
-
90
  image_relevance = image_relevance.reshape(1, 1, 7, 7)
91
  image_relevance = torch.nn.functional.interpolate(image_relevance, size=224, mode='bilinear')
92
  image_relevance = image_relevance.reshape(224, 224).to(device).data.cpu().numpy()
@@ -97,16 +88,10 @@ def show_image_relevance(image_relevance, image, orig_image, device, show=True):
97
  vis = np.uint8(255 * vis)
98
  vis = cv2.cvtColor(np.array(vis), cv2.COLOR_RGB2BGR)
99
 
100
- if show:
101
- # axar[1].imshow(vis)
102
- axs[1].imshow(vis);
103
- axs[1].axis('off');
104
- # plt.imshow(vis)
105
-
106
  return image_relevance
107
 
108
 
109
- def show_heatmap_on_text(text, text_encoding, R_text, show=True):
110
  CLS_idx = text_encoding.argmax(dim=-1)
111
  R_text = R_text[CLS_idx, 1:CLS_idx]
112
  text_scores = R_text / R_text.sum()
@@ -115,19 +100,16 @@ def show_heatmap_on_text(text, text_encoding, R_text, show=True):
115
  text_tokens=_tokenizer.encode(text)
116
  text_tokens_decoded=[_tokenizer.decode([a]) for a in text_tokens]
117
  vis_data_records = [visualization.VisualizationDataRecord(text_scores,0,0,0,0,0,text_tokens_decoded,1)]
118
-
119
- if show:
120
- visualization.visualize_text(vis_data_records)
121
 
122
  return text_scores, text_tokens_decoded
123
 
124
 
125
- def show_img_heatmap(image_relevance, image, orig_image, device, show=True):
126
- return show_image_relevance(image_relevance, image, orig_image, device, show=show)
127
 
128
 
129
- def show_txt_heatmap(text, text_encoding, R_text, show=True):
130
- return show_heatmap_on_text(text, text_encoding, R_text, show=show)
131
 
132
 
133
  def load_dataset():
@@ -149,4 +131,4 @@ class color:
149
  RED = '\033[91m'
150
  BOLD = '\033[1m'
151
  UNDERLINE = '\033[4m'
152
- END = '\033[0m'
 
69
  return text_relevance, image_relevance
70
 
71
 
72
+ def show_image_relevance(image_relevance, image, orig_image, device):
73
  # create heatmap from mask on image
74
  def show_cam_on_image(img, mask):
75
  heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
 
78
  cam = cam / np.max(cam)
79
  return cam
80
 
 
 
 
 
 
 
 
 
 
81
  image_relevance = image_relevance.reshape(1, 1, 7, 7)
82
  image_relevance = torch.nn.functional.interpolate(image_relevance, size=224, mode='bilinear')
83
  image_relevance = image_relevance.reshape(224, 224).to(device).data.cpu().numpy()
 
88
  vis = np.uint8(255 * vis)
89
  vis = cv2.cvtColor(np.array(vis), cv2.COLOR_RGB2BGR)
90
 
 
 
 
 
 
 
91
  return image_relevance
92
 
93
 
94
+ def show_heatmap_on_text(text, text_encoding, R_text):
95
  CLS_idx = text_encoding.argmax(dim=-1)
96
  R_text = R_text[CLS_idx, 1:CLS_idx]
97
  text_scores = R_text / R_text.sum()
 
100
  text_tokens=_tokenizer.encode(text)
101
  text_tokens_decoded=[_tokenizer.decode([a]) for a in text_tokens]
102
  vis_data_records = [visualization.VisualizationDataRecord(text_scores,0,0,0,0,0,text_tokens_decoded,1)]
 
 
 
103
 
104
  return text_scores, text_tokens_decoded
105
 
106
 
107
+ def show_img_heatmap(image_relevance, image, orig_image, device):
108
+ return show_image_relevance(image_relevance, image, orig_image, device)
109
 
110
 
111
+ def show_txt_heatmap(text, text_encoding, R_text):
112
+ return show_heatmap_on_text(text, text_encoding, R_text)
113
 
114
 
115
  def load_dataset():
 
131
  RED = '\033[91m'
132
  BOLD = '\033[1m'
133
  UNDERLINE = '\033[4m'
134
+ END = '\033[0m'
app.py CHANGED
@@ -59,10 +59,10 @@ def run_demo(image, text):
59
 
60
  R_text, R_image = interpret(model=model, image=img, texts=text_input, device=device)
61
 
62
- image_relevance = show_img_heatmap(R_image[0], img, orig_image=orig_image, device=device, show=False)
63
  overlapped = overlay_relevance_map_on_image(image, image_relevance)
64
 
65
- text_scores, text_tokens_decoded = show_heatmap_on_text(text, text_input, R_text[0], show=False)
66
 
67
  highlighted_text = []
68
  for i, token in enumerate(text_tokens_decoded):
 
59
 
60
  R_text, R_image = interpret(model=model, image=img, texts=text_input, device=device)
61
 
62
+ image_relevance = show_img_heatmap(R_image[0], img, orig_image=orig_image, device=device)
63
  overlapped = overlay_relevance_map_on_image(image, image_relevance)
64
 
65
+ text_scores, text_tokens_decoded = show_heatmap_on_text(text, text_input, R_text[0])
66
 
67
  highlighted_text = []
68
  for i, token in enumerate(text_tokens_decoded):