|
from typing import Callable, Iterable, Tuple |
|
import torch |
|
import numpy as np |
|
import PIL.Image |
|
import cv2 |
|
import wandb |
|
|
|
from tqdm import tqdm |
|
from pytorch_grad_cam import GradCAM |
|
|
|
from utils.val_loop_hook import ValidationLoopHook |
|
|
|
def _get_grad_cam_target(model): |
|
""" |
|
Determines the appropriate GradCAM target. |
|
""" |
|
|
|
|
|
if hasattr(model, "features"): |
|
return getattr(model, "features") |
|
|
|
pooling = [torch.nn.AdaptiveAvgPool1d, torch.nn.AvgPool1d, torch.nn.MaxPool1d, torch.nn.AdaptiveMaxPool1d, |
|
torch.nn.AdaptiveAvgPool2d, torch.nn.AvgPool2d, torch.nn.MaxPool2d, torch.nn.AdaptiveMaxPool2d, |
|
torch.nn.AdaptiveAvgPool3d, torch.nn.AvgPool3d, torch.nn.MaxPool3d, torch.nn.AdaptiveMaxPool3d] |
|
convolutions = [torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d] |
|
|
|
|
|
inverted_modules = list(model.modules())[::-1] |
|
for i, module in enumerate(inverted_modules): |
|
if any([isinstance(module, po) for po in pooling]): |
|
|
|
return inverted_modules[i+1] |
|
elif any([isinstance(module, co) for co in convolutions]): |
|
|
|
return module |
|
elif isinstance(module, torch.nn.Sequential): |
|
|
|
for child in list(module.children())[::-1]: |
|
sequential_result = _get_grad_cam_target(child) |
|
if sequential_result is not None: |
|
return sequential_result |
|
|
|
def _show_cam_on_image(img: np.ndarray, mask: np.ndarray, use_rgb: bool = False, colormap: int = cv2.COLORMAP_JET) -> np.ndarray: |
|
""" This function overlays the cam mask on the image as an heatmap. |
|
By default the heatmap is in BGR format. |
|
:param img: The base image in RGB or BGR format. |
|
:param mask: The cam mask. |
|
:param use_rgb: Whether to use an RGB or BGR heatmap, this should be set to True if 'img' is in RGB format. |
|
:param colormap: The OpenCV colormap to be used. |
|
:returns: The default image with the cam overlay. |
|
""" |
|
heatmap = cv2.applyColorMap(np.uint8(255 * mask), colormap) |
|
if use_rgb: |
|
heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB) |
|
heatmap = np.float32(heatmap) / 255 |
|
|
|
normalize = lambda x: (x - np.min(x))/np.ptp(x) |
|
|
|
cam = 0.6 * heatmap + normalize(img) |
|
cam = cam / np.max(cam) |
|
return np.uint8(255 * cam) |
|
|
|
def _strip_image_from_grid_row(row, gap=5, bg=255): |
|
strip = torch.full( |
|
(row.shape[0] * (row.shape[3] + gap) - gap, |
|
row.shape[1] * (row.shape[3] + gap) - gap, |
|
row.shape[4]), bg, dtype=row.dtype) |
|
for i in range(0, row.shape[0] * row.shape[1]): |
|
strip[(i // row.shape[1]) * (row.shape[2] + gap) : ((i // row.shape[1])+1) * (row.shape[2] + gap) - gap, |
|
(i % row.shape[1]) * (row.shape[3] + gap) : ((i % row.shape[1])+1) * (row.shape[3] + gap) - gap, |
|
:] = row[i // row.shape[1]][i % row.shape[1]] |
|
return PIL.Image.fromarray(strip.numpy()) |
|
|
|
|
|
class GradCAMBuilder(ValidationLoopHook): |
|
def __init__(self, image_shape: Iterable[int], target_category: int = None, num_images: int = 5): |
|
self.image_shape = image_shape |
|
self.target_category = target_category |
|
self.num_images = num_images |
|
|
|
self.targets = torch.zeros(self.num_images) |
|
self.activations = torch.zeros(self.num_images) |
|
self.images = torch.zeros(torch.Size([self.num_images]) + torch.Size(self.image_shape)) |
|
|
|
def process(self, batch, target_batch, logits_batch, prediction_batch): |
|
image_batch = batch["image"] |
|
|
|
with torch.no_grad(): |
|
if self.target_category is not None: |
|
local_activations = logits_batch[:, self.target_category] |
|
else: |
|
local_activations = torch.amax(logits_batch, dim=-1) |
|
|
|
|
|
target_match = (prediction_batch == target_batch) |
|
|
|
|
|
public = torch.tensor(["verse" in id for id in batch["verse_id"]]).type_as(target_match) |
|
|
|
mask = target_match & public |
|
|
|
if torch.max(mask) == False: |
|
|
|
return |
|
|
|
|
|
local_top_idx = torch.argsort(local_activations, descending=True) |
|
|
|
local_top_idx = local_top_idx[mask[local_top_idx]] |
|
current_idx = 0 |
|
|
|
while current_idx < self.num_images and local_activations[local_top_idx[current_idx]] > torch.min(self.activations): |
|
|
|
idx_to_replace = torch.argsort(self.activations)[0] |
|
|
|
self.activations[idx_to_replace] = local_activations[local_top_idx[current_idx]] |
|
self.images[idx_to_replace] = image_batch[local_top_idx[current_idx]] |
|
self.targets[idx_to_replace] = target_batch[local_top_idx[current_idx]] |
|
|
|
current_idx += 1 |
|
|
|
def trigger(self, module): |
|
model = module.backbone |
|
module.eval() |
|
|
|
|
|
grad_cam_target = _get_grad_cam_target(model) |
|
|
|
cam = GradCAM(model, [grad_cam_target], use_cuda=torch.cuda.is_available()) |
|
|
|
|
|
sorted_idx = torch.argsort(self.activations, descending=True) |
|
|
|
self.activations = self.activations[sorted_idx] |
|
self.images = self.images[sorted_idx] |
|
self.targets = self.targets[sorted_idx] |
|
|
|
|
|
|
|
grad_cams = cam(input_tensor=self.images, target_category=self.target_category) |
|
|
|
module.train() |
|
|
|
if len(self.images.shape) == 5: |
|
|
|
ld_res = grad_cams.shape[-1] |
|
img_res = self.images.shape[-1] |
|
img_slices = torch.linspace(int(img_res/ld_res/2), img_res-int(img_res/ld_res/2), ld_res, dtype=torch.long) |
|
|
|
|
|
grad_cams_image = _strip_image_from_grid_row( |
|
torch.stack([ |
|
torch.stack([ |
|
torch.tensor( |
|
_show_cam_on_image((self.images[i, 0, ..., img_slices[s]]).unsqueeze(-1).repeat(1, 1, 3).numpy(), grad_cams[i, ..., s], use_rgb=True) |
|
) |
|
for s in range(grad_cams.shape[-1])]) |
|
for i in range(self.num_images if self.num_images < grad_cams.shape[0] else grad_cams.shape[0])]) |
|
) |
|
|
|
elif len(self.images.shape) == 4: |
|
|
|
grad_cams_image = _strip_image_from_grid_row( |
|
torch.stack([ |
|
torch.stack([ |
|
torch.tensor( |
|
_show_cam_on_image((self.images[i, 0, ...]).unsqueeze(-1).repeat(1, 1, 3).numpy(), grad_cams[i, ...], use_rgb=True) |
|
) |
|
]) |
|
for i in range(self.num_images if self.num_images < grad_cams.shape[0] else grad_cams.shape[0])]) |
|
) |
|
|
|
else: |
|
raise RuntimeError("Attempting to build Grad-CAMs for data that is neither 2D nor 3D") |
|
|
|
module.logger.experiment.log({ |
|
"val/grad_cam": wandb.Image(grad_cams_image) |
|
}) |
|
|
|
def reset(self): |
|
self.targets = torch.zeros(self.num_images) |
|
self.activations = torch.zeros(self.num_images) |
|
self.images = torch.zeros(torch.Size([self.num_images]) + torch.Size(self.image_shape)) |