Martijn van Beers commited on
Commit
5f8002c
1 Parent(s): 8f3d1af

Make show_image_relevance dimension-agnostic

Browse files

The different CLIP models have different dimensions for the
image relevance vector (and the image size). Instead of hardcoding
for a specific model, calculate the numbers we need.

Files changed (1) hide show
  1. CLIP_explainability/utils.py +5 -3
CLIP_explainability/utils.py CHANGED
@@ -78,9 +78,11 @@ def show_image_relevance(image_relevance, image, orig_image, device):
78
  cam = cam / np.max(cam)
79
  return cam
80
 
81
- image_relevance = image_relevance.reshape(1, 1, 7, 7)
82
- image_relevance = torch.nn.functional.interpolate(image_relevance, size=224, mode='bilinear')
83
- image_relevance = image_relevance.reshape(224, 224).to(device).data.cpu().numpy()
 
 
84
  image_relevance = (image_relevance - image_relevance.min()) / (image_relevance.max() - image_relevance.min())
85
  image = image[0].permute(1, 2, 0).data.cpu().numpy()
86
  image = (image - image.min()) / (image.max() - image.min())
 
78
  cam = cam / np.max(cam)
79
  return cam
80
 
81
+ rel_shp = np.sqrt(image_relevance.shape[0]).astype(int)
82
+ img_size = image.shape[-1]
83
+ image_relevance = image_relevance.reshape(1, 1, rel_shp, rel_shp)
84
+ image_relevance = torch.nn.functional.interpolate(image_relevance, size=img_size, mode='bilinear')
85
+ image_relevance = image_relevance.reshape(img_size, img_size).data.cpu().numpy()
86
  image_relevance = (image_relevance - image_relevance.min()) / (image_relevance.max() - image_relevance.min())
87
  image = image[0].permute(1, 2, 0).data.cpu().numpy()
88
  image = (image - image.min()) / (image.max() - image.min())