|
import os |
|
from tqdm import tqdm |
|
import argparse |
|
import cv2 |
|
import numpy as np |
|
from torchvision import transforms |
|
from datasets import Dataset, concatenate_datasets |
|
from pytorch_grad_cam import GradCAM |
|
|
|
from timm.models import create_model, load_checkpoint |
|
from timm.data import create_transform |
|
from pytorch_grad_cam.utils.image import show_cam_on_image |
|
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget |
|
|
|
|
|
# Make sure both output folders exist before any image is written.
# os.makedirs also creates missing parent folders ('results/' and
# 'results/grad_cam/'), which a plain os.mkdir would fail on, and
# exist_ok=True turns a rerun into a no-op instead of an error.
os.makedirs('results/grad_cam/correct', exist_ok=True)
os.makedirs('results/grad_cam/incorrect', exist_ok=True)
|
|
|
def get_args():
    """Parse the command-line options of the Grad-CAM visualisation script.

    Reads ``sys.argv`` through ``argparse`` and prints which device will be
    used.

    Returns:
        argparse.Namespace: with the attributes ``model``, ``checkpoint``,
        ``idx`` (the raw string, not parsed here), ``use_cuda``,
        ``aug_smooth`` and ``eigen_smooth``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', default='tpmlp_s', type=str, metavar='MODEL',
                        help='name of the model architecture to create (default: %(default)s)')
    parser.add_argument('--checkpoint', default='/home/daa5724/tpmlp-s-300-ema/last.pth.tar', type=str, metavar='PATH',
                        help='path to the checkpoint to load (default: %(default)s)')
    parser.add_argument('--idx', default='[1374, 27826, 14327, 1828, 31787, 21083, 38902, 7912, 10089, 16915, 20986, 35716, 15233, 20648, 30566, 20150, 45538, 42359, 39683, 20329, 20868, 48557, 10569, 37167, 11163, 6688, 21910, 44528, 10660, 13919, 10098, 46981, 36560, 14231, 45372, 6262, 23684, 16895, 17036, 15670, 35393, 26758, 18572, 48064, 29773, 25437, 5494, 12825, 25737, 45244, 16877, 29958, 38519, 5338, 46210, 15154, 15040, 15783, 13640, 14420, 26836, 38155, 45094, 33282, 13362, 42975, 38779, 24298, 20632, 48373, 28662, 21869, 37940, 25953, 29360, 9428, 22352, 6498, 2014, 9666, 30364, 21129, 43259, 16148, 31559, 4508, 42773, 8180, 17194, 46614, 23580, 3039, 36980, 35809, 860, 35940, 9670, 33552, 35731, 23777, 15272, 47792, 20589, 12044, 24154, 24852, 2090, 16158, 12333, 4109, 7612, 22611, 12808, 38787, 41688, 23714, 17498, 29326, 12237, 28137, 38521, 24060, 31545, 46094, 34674, 18182, 28380, 34046]', type=str, metavar='IDX',
                        help='list of indices to use (default: [...])')
    parser.add_argument('--use-cuda', action='store_true', default=False,
                        help='Use NVIDIA GPU acceleration')
    parser.add_argument('--aug_smooth', action='store_true',
                        help='Apply test time augmentation to smooth the CAM')
    parser.add_argument(
        '--eigen_smooth',
        action='store_true',
        help='Reduce noise by taking the first principal component '
             'of cam_weights*activations')

    args = parser.parse_args()

    # NOTE(review): the GPU is forced on here regardless of --use-cuda, so the
    # flag currently has no effect and the CPU message below is never printed.
    # Kept as-is to preserve the script's behaviour.
    args.use_cuda = True
    if args.use_cuda:
        print('Using GPU for acceleration')
    else:
        print('Using CPU for computation')

    return args
|
|
|
|
|
def _split_indices(indices):
    """Split ``indices`` into its first and second half.

    The script treats the first half as samples expected to be classified
    correctly and the second half as samples expected to be misclassified.
    With an odd length the extra element ends up in the second half.
    """
    half = len(indices) // 2
    return indices[:half], indices[half:]


def _load_sample(dataset, transform, sample_idx, use_cuda):
    """Fetch one dataset record and prepare it for Grad-CAM.

    Returns:
        (rgb_img, input_tensor, label) where ``rgb_img`` is the un-normalised
        image after ``permute(1, 2, 0)`` (presumably channel-first to
        channel-last), ``input_tensor`` is the normalised batch of one
        (moved to the GPU when ``use_cuda`` is true) and ``label`` is the
        record's ``'label'`` entry.
    """
    record = dataset[int(sample_idx)]
    image, label = record['image'], record['label']
    rgb_img, input_tensor = transform(image)
    rgb_img = rgb_img.permute(1, 2, 0)
    if use_cuda:
        input_tensor = input_tensor.cuda()
    return rgb_img, input_tensor, label


def _run_grad_cam(model, target_layers, input_tensor, targets, args):
    """Run Grad-CAM once; return (heat map of the first sample, predictions).

    ``targets=None`` is forwarded unchanged to the CAM object.
    NOTE(review): the two-value return of ``cam(...)`` suggests a customised
    pytorch_grad_cam build - confirm against the installed package.
    """
    with GradCAM(model=model,
                 target_layers=target_layers,
                 use_cuda=args.use_cuda) as cam:
        grayscale_cam, pred = cam(input_tensor=input_tensor,
                                  targets=targets,
                                  aug_smooth=args.aug_smooth,
                                  eigen_smooth=args.eigen_smooth)
    return grayscale_cam[0, :], pred


def _render_overlay(rgb_img, grayscale_cam):
    """Blend the heat map over the image and return it in BGR order for cv2.imwrite."""
    overlay = show_cam_on_image(rgb_img.detach().cpu().numpy(), grayscale_cam, use_rgb=True, image_weight=0.5, colormap=cv2.COLORMAP_TWILIGHT_SHIFTED)
    return cv2.cvtColor(overlay, cv2.COLOR_RGB2BGR)


def _to_bgr_uint8(rgb_img):
    """Scale the image by 255, cast to uint8 and convert RGB -> BGR.

    Assumes pixel values lie in [0, 1] - TODO confirm against the transform.
    """
    return cv2.cvtColor((rgb_img * 255).numpy().astype(np.uint8), cv2.COLOR_RGB2BGR)


if __name__ == '__main__':
    args = get_args()

    model = create_model(
        args.model,
        num_classes=1000,
        in_chans=3,
    )
    # Third positional argument is passed as True (presumably timm's
    # ``use_ema`` switch - confirm against the installed timm version).
    load_checkpoint(model, args.checkpoint, True)

    # Grad-CAM is computed on the fourth entry of ``model.layers``.
    target_layers = [model.layers[3]]

    # The validation split is stored as 13 arrow shards; merge them into one dataset.
    dataset = concatenate_datasets([Dataset.from_file(f"../../imagenet-1k/imagenet-1k-validation-{i:05d}-of-00013.arrow",) for i in range(13)])

    augs = create_transform(
        input_size=(3, 224, 224),
        is_training=False,
        use_prefetcher=False,
        crop_pct=0.9,
    )
    # Split the pipeline so the image can be kept before its last step is
    # applied: everything but the final transform goes into ``resize``,
    # the final transform is kept as ``normalize``.
    resize = transforms.Compose(augs.transforms[:-1])
    normalize = augs.transforms[-1]

    def transform(img):
        """Return (image before the final transform, final-transformed batch of one)."""
        img = resize(img.convert("RGB"))
        tensor = normalize(img)
        return img, tensor[None]

    # NOTE(security): eval() executes arbitrary Python taken from the command
    # line. Acceptable only while --idx comes from a trusted operator; prefer
    # ast.literal_eval if the value can ever come from an untrusted source.
    indices = eval(args.idx)
    correct_idx, incorrect_idx = _split_indices(indices)

    # Samples expected to be classified correctly: one CAM for the label.
    for sample_idx in tqdm(correct_idx):
        rgb_img, input_tensor, label = _load_sample(dataset, transform, sample_idx, args.use_cuda)
        grayscale_cam, pred = _run_grad_cam(
            model, target_layers, input_tensor, [ClassifierOutputTarget(label)], args)

        if pred[0] != label:
            print(f"`pred != gdth` in correct_idx: {pred[0]} != {label}. Skipping idx {sample_idx}.")
            # Actually skip, as the message says, so a misclassified sample
            # is not saved into the 'correct' folder.
            continue

        cv2.imwrite(f'results/grad_cam/correct/grad_cam_{sample_idx}.png',
                    _render_overlay(rgb_img, grayscale_cam))
        cv2.imwrite(f'results/grad_cam/correct/image_{sample_idx}[{label}].png',
                    _to_bgr_uint8(rgb_img))

    # Samples expected to be misclassified: one CAM for the ground-truth
    # label and one with targets=None (saved under the predicted class).
    for sample_idx in tqdm(incorrect_idx):
        rgb_img, input_tensor, label = _load_sample(dataset, transform, sample_idx, args.use_cuda)
        grayscale_cam, pred = _run_grad_cam(
            model, target_layers, input_tensor, [ClassifierOutputTarget(label)], args)

        if pred[0] == label:
            print(f"`pred == gdth` in incorrect_idx: {pred[0]} == {label}. Skipping idx {sample_idx}.")
            # Actually skip, as the message says, so a correctly classified
            # sample is not saved into the 'incorrect' folder.
            continue

        cv2.imwrite(f'results/grad_cam/incorrect/grad_cam_gdth_{sample_idx}.png',
                    _render_overlay(rgb_img, grayscale_cam))
        cv2.imwrite(f'results/grad_cam/incorrect/image_{sample_idx}[{label}].png',
                    _to_bgr_uint8(rgb_img))

        grayscale_cam, pred = _run_grad_cam(model, target_layers, input_tensor, None, args)
        cv2.imwrite(f'results/grad_cam/incorrect/grad_cam_pred_{sample_idx}[{pred[0]}].png',
                    _render_overlay(rgb_img, grayscale_cam))
|
|