# Spaces status: Runtime error
import os | |
import math | |
import numpy as np | |
import pandas as pd | |
import seaborn as sn | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import torchvision | |
import matplotlib.pyplot as plt | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from pl_bolts.datamodules import CIFAR10DataModule | |
from pl_bolts.transforms.dataset_normalizations import cifar10_normalization | |
from pytorch_lightning import LightningModule, Trainer, seed_everything | |
from pytorch_lightning.callbacks import LearningRateMonitor | |
from pytorch_lightning.callbacks.progress import TQDMProgressBar | |
from pytorch_lightning.loggers import CSVLogger | |
from torch.optim.lr_scheduler import OneCycleLR | |
from torch.optim.swa_utils import AveragedModel, update_bn | |
from torchmetrics.functional import accuracy | |
from pytorch_lightning.callbacks import ModelCheckpoint | |
from torchvision import datasets, transforms, utils | |
from PIL import Image | |
from pytorch_grad_cam import GradCAM | |
from pytorch_grad_cam.utils.image import show_cam_on_image | |
# Module-level default for GradCAM targets (None = no explicit targets).
# Note: display_gradcam_output declares its own ``targets`` parameter, so this
# global is not read inside that function.
targets = None
# NOTE: this helper is used by the gradioMisClassGradCAM app (the original
# note calls it "predecessor3" of that app); keep its signature stable.
def display_gradcam_output(data: list,
                           classes: list[str],
                           inv_normalize: transforms.Normalize,
                           model: 'DL Model',
                           target_layers: list['model_layer'],
                           targets=None,
                           number_of_samples: int = 10,
                           transparency: float = 0.60,
                           output_path: str = 'imshow_output_gradcam.png'):
    """
    Overlay GradCAM heat-maps on images from ``data`` and save them as one grid image.

    :param data: Sequence of samples; ``data[i][0]`` must be a normalized image
        tensor the model accepts, presumably shaped ``(1, C, H, W)``. A 3-D
        ``(C, H, W)`` tensor is given a batch axis automatically.
    :param classes: Names of the classes in the dataset (currently unused;
        kept for interface compatibility)
    :param inv_normalize: Transform that undoes the dataset normalization so the
        image can be displayed (it is applied to a ``(C, H, W)`` tensor)
    :param model: Model to explain
    :param target_layers: Layers whose activations GradCAM uses
    :param targets: Forwarded unchanged as ``targets`` to the GradCAM call
    :param number_of_samples: Number of images to plot, capped at ``len(data)``
    :param transparency: Weight of the original image when blended with the
        heat-map (``image_weight`` of ``show_cam_on_image``)
    :param output_path: File the grid is saved to
    :return: ``output_path``
    """
    x_count = 5  # images per row
    n_images = min(number_of_samples, len(data))
    # ceil, not floor: 7 images need 2 rows; with floor the grid would be 1x5
    # and plt.subplot(1, 5, 6) would raise for the 6th image.
    y_count = max(1, math.ceil(n_images / x_count))

    fig = plt.figure(figsize=(10, 10))
    try:
        cam = GradCAM(model=model, target_layers=target_layers)
        for i in range(n_images):
            plt.subplot(y_count, x_count, i + 1)
            input_tensor = data[i][0]
            if input_tensor.dim() == 3:
                # GradCAM runs the model on a batch; add the missing batch axis.
                input_tensor = input_tensor.unsqueeze(0)

            # Heat-map for the first (only) image of the batch.
            grayscale_cam = cam(input_tensor=input_tensor, targets=targets)[0, :]

            # Undo normalization, then convert the CHW tensor to an HWC array.
            img = inv_normalize(input_tensor.squeeze(0).detach().cpu())
            rgb_img = img.permute(1, 2, 0).numpy()
            # show_cam_on_image rejects values above 1; float round-off in the
            # inverse normalization can push pixels slightly outside [0, 1].
            rgb_img = np.clip(rgb_img, 0.0, 1.0).astype(np.float32)

            visualization = show_cam_on_image(rgb_img, grayscale_cam,
                                              use_rgb=True,
                                              image_weight=transparency)
            plt.imshow(visualization)
            plt.xticks([])
            plt.yticks([])
        fig.savefig(output_path)
    finally:
        # Release the figure; otherwise every call leaks one open figure.
        plt.close(fig)
    return output_path