jwyang committed
Commit ad7aaa6
1 Parent(s): a574e10

add heatmap visualization

app.py CHANGED
@@ -118,11 +118,20 @@ def recognize_image(image, texts):
     text_embeddings = model.get_text_embeddings(texts.split(';'))
 
     # compute output
-    feat_img = model.encode_image(img_t.unsqueeze(0))
+    feat_img, feat_map = model.encode_image(img_t.unsqueeze(0), output_map=True)
     output = model.logit_scale.exp() * feat_img @ text_embeddings.t()
     prediction = output.softmax(-1).flatten()
 
-    return {texts.split(';')[i]: float(prediction[i]) for i in range(len(texts.split(';')))}
+    # generate feat map given the top matched text
+    output_map = (feat_map * text_embeddings[prediction.argmax()].unsqueeze(-1)).sum(1).softmax(-1)
+    output_map = output_map.view(1, 1, 7, 7)
+
+    output_map = nn.Upsample(size=img_t.shape[1:], mode='bilinear')(output_map)
+    output_map = output_map.squeeze(1).detach().permute(1, 2, 0).numpy()
+    output_map = (output_map - output_map.min()) / (output_map.max() - output_map.min())
+    heatmap = show_cam_on_image(img_d, output_map, use_rgb=True)
+
+    return Image.fromarray(heatmap), {texts.split(';')[i]: float(prediction[i]) for i in range(len(texts.split(';')))}
 
 
 image = gr.inputs.Image()
@@ -132,8 +141,11 @@ gr.Interface(
     description="UniCL for Zero-shot Image Recognition Demo (https://github.com/microsoft/unicl)",
     fn=recognize_image,
     inputs=["image", "text"],
-    outputs=[
-        label,
+    outputs=[
+        gr.outputs.Image(
+            type="pil",
+            label="zero-shot heat map"),
+        label,
     ],
     examples=[
         ["./elephants.png", "an elephant; an elephant walking in the river; four elephants walking in the river"],
model/image_encoder/swin_transformer.py CHANGED
@@ -557,7 +557,7 @@ class SwinTransformer(nn.Module):
     def no_weight_decay_keywords(self):
         return {'relative_position_bias_table'}
 
-    def forward_features(self, x):
+    def forward_features(self, x, output_map=False):
         x = self.patch_embed(x)
         if self.ape:
             x = x + self.absolute_pos_embed
@@ -566,10 +566,14 @@ class SwinTransformer(nn.Module):
         for layer in self.layers:
             x = layer(x)
 
-        x = self.norm(x)  # B L C
-        x = self.avgpool(x.transpose(1, 2))  # B C 1
+        x_map = self.norm(x).transpose(1, 2)  # B C L
+        x = self.avgpool(x_map)  # B C 1
         x = torch.flatten(x, 1)
-        return x
+
+        if output_map:
+            return x, x_map
+        else:
+            return x
 
     def forward(self, x):
         x = self.forward_features(x)
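A quick shape walk-through of the new `forward_features` path, assuming the demo's standard Swin configuration (224x224 input, 32x total downsampling, so L = 7*7 = 49 final-stage tokens) and that `self.avgpool` is `nn.AdaptiveAvgPool1d(1)` as in the reference Swin implementation:

```python
# Shape sketch for the modified forward_features. The B, L, C values are
# assumptions (a Swin-T-like final stage); avgpool mirrors the reference
# implementation's nn.AdaptiveAvgPool1d(1).
import torch
import torch.nn as nn

B, L, C = 1, 49, 768
x = torch.randn(B, L, C)                   # tokens after the last Swin stage

norm = nn.LayerNorm(C)
avgpool = nn.AdaptiveAvgPool1d(1)

x_map = norm(x).transpose(1, 2)            # B x C x L, kept for the heatmap
pooled = torch.flatten(avgpool(x_map), 1)  # B x C global feature, as before

# pooling the map is exactly the mean over the L tokens, so the global
# feature is unchanged by this refactor; only x_map is new
assert torch.allclose(pooled, x_map.mean(-1), atol=1e-6)
print(x_map.shape, pooled.shape)           # (1, 768, 49) and (1, 768)
```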
model/model.py CHANGED
@@ -153,14 +153,25 @@ class UniCLModel(nn.Module):
         imnet_text_embeddings = torch.stack(clss_embeddings, dim=0)
         return imnet_text_embeddings
 
-    def encode_image(self, image, norm=True):
-        x = self.image_encoder.forward_features(image)
+    def encode_image(self, image, norm=True, output_map=False):
+        x = self.image_encoder.forward_features(image, output_map=output_map)
+        if output_map:
+            x, x_map = x
+
         x = x @ self.image_projection
 
+        if output_map:
+            x_map = self.image_projection.unsqueeze(0).transpose(1, 2) @ x_map
+
         if norm:
             x = x / x.norm(dim=-1, keepdim=True)
-
-        return x
+        if output_map:
+            x_map = x_map / x_map.norm(dim=1, keepdim=True)
+
+        if output_map:
+            return x, x_map
+        else:
+            return x
 
     def encode_text(self, text, norm=True):
         x = self.text_encoder(**text)
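Putting the three diffs together: assuming `image_projection` is a C x D matrix, the new projection line maps the B x C x L token map to B x D x L, so each token lives in the same space as the text embeddings. A hypothetical caller (`model`, `img_t`, and the prompts are stand-ins; the method names all come from the diffs above) would use the new flag like this:

```python
# Hypothetical end-to-end use of the new output_map flag. `model` is a loaded
# UniCLModel and `img_t` a preprocessed 3 x 224 x 224 tensor (both stand-ins).
feat_img, feat_map = model.encode_image(img_t.unsqueeze(0), output_map=True)
# feat_img: B x D pooled embedding, L2-normalized over D
# feat_map: B x D x L token embeddings, projected by the same image_projection
#           and L2-normalized over dim=1, so tokens score directly against text

text_embeddings = model.get_text_embeddings('an elephant; a zebra'.split(';'))
output = model.logit_scale.exp() * feat_img @ text_embeddings.t()
prediction = output.softmax(-1).flatten()

# per-token affinity to the best-matching prompt -> the 7x7 heatmap source
token_scores = (feat_map * text_embeddings[prediction.argmax()].unsqueeze(-1)).sum(1)
print(token_scores.shape)                  # e.g. torch.Size([1, 49])
```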