TSAIGradcam / utils.py
ibrim's picture
Upload 8 files
71c714a verified
raw
history blame
No virus
7.5 kB
#!/usr/bin/env python3
"""
Utility Script containing functions to be used for training
Author: Shilpaj Bhalerao
"""
# Standard Library Imports
import math
from typing import NoReturn
import io
from PIL import Image
# Third-Party Imports
import numpy as np
import matplotlib.pyplot as plt
import torch
from torchsummary import summary
from torchvision import transforms
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
def get_summary(model, input_size: tuple) -> None:
    """
    Function to print the summary of the model architecture (via torchsummary)
    :param model: Object of model architecture class (a torch.nn.Module)
    :param input_size: Input data shape (Channels, Height, Width)
    """
    # Build the summary on the GPU when one is available, otherwise on the CPU
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # NOTE: nn.Module.to() moves the module in place, so the caller's model
    # stays on `device` after this call
    network = model.to(device)
    summary(network, input_size=input_size)
def get_misclassified_data(model, device, test_loader):
    """
    Function to run the model on test set and return misclassified images
    :param model: Network Architecture (not moved here; it is expected to already live on `device`)
    :param device: CPU/GPU device each batch is moved to
    :param test_loader: Iterable of (data, target) batches, e.g. the DataLoader for the test set
    :return: List of (image, label, pred) tuples, one per wrong prediction, where `image`
             keeps a leading batch dimension of 1, `label` is the matching target entry and
             `pred` is the arg-max class index with shape (1, 1)
    """
    # Evaluation mode: disables dropout and makes batch-norm use its running statistics
    model.eval()
    # List to store misclassified Images
    misclassified_data = []
    # Inference only: disable autograd tracking
    with torch.no_grad():
        for data, target in test_loader:
            # Migrate the data to the device
            data, target = data.to(device), target.to(device)
            # One forward pass per batch instead of one per image; keepdim gives (B, 1).
            # NOTE: assumes each sample's output does not depend on the rest of the
            # batch (true for standard layers in eval mode)
            preds = model(data).argmax(dim=1, keepdim=True)
            for image, label, pred in zip(data, target, preds):
                if pred != label:
                    # Re-add the batch dimension so shapes match per-image inference:
                    # image -> (1, ...), pred -> (1, 1)
                    misclassified_data.append((image.unsqueeze(0), label, pred.unsqueeze(0)))
    return misclassified_data
# -------------------- DATA STATISTICS --------------------
def get_mnist_statistics(data_set, data_set_type='Train'):
    """
    Function to print the statistics of an MNIST-style dataset and plot its first image
    :param data_set: Dataset exposing the raw images as `data` (or the older `train_data`),
                     a `transform`, and yielding (image, label) samples when iterated
    :param data_set_type: Type of dataset [Train/Test/Val]; only used in the printed header
    """
    # torchvision renamed `train_data` to `data` and keeps the old name only as a
    # deprecated alias that warns, so prefer the new attribute when it exists
    raw_data = data_set.data if hasattr(data_set, 'data') else data_set.train_data
    # Run the dataset's own transform over the whole raw array so the statistics
    # describe what the model is actually fed.
    # NOTE(review): assumes the transform accepts one stacked (N, H, W) numpy array
    # (e.g. ToTensor, which treats it as H x W x C) — confirm for other transforms
    transformed = data_set.transform(raw_data.numpy())

    print(f'[{data_set_type}]')
    print(' - Numpy Shape:', raw_data.cpu().numpy().shape)
    print(' - Tensor Shape:', raw_data.size())
    print(' - min:', torch.min(transformed))
    print(' - max:', torch.max(transformed))
    print(' - mean:', torch.mean(transformed))
    print(' - std:', torch.std(transformed))
    print(' - var:', torch.var(transformed))

    # Iterating the dataset yields one sample at a time; inspect the first one
    first_sample = next(iter(data_set))
    image, label = first_sample[0], first_sample[1]
    print(image.shape)
    print(label)
    # Let's visualize the first image (first channel, singleton dims squeezed)
    plt.imshow(image[0].numpy().squeeze(), cmap='gray')
def get_cifar_property(images, operation):
    """
    Get the property on each channel of the CIFAR images
    :param images: Batch of images indexed as [sample, channel, height, width]
                   (numpy array or torch tensor)
    :param operation: Name of a reduction method of `images`, e.g. 'mean', 'std', 'min', 'max', 'var'
    :return: Tuple (r, g, b) holding the reduction over each of the first three channels
    :raises AttributeError: If `operation` is not a method name of the channel slices
    """
    # Look the reduction up by name instead of eval()-ing a string built from
    # `operation`, which would execute arbitrary code supplied by the caller
    return tuple(getattr(images[:, channel, :, :], operation)() for channel in range(3))
def get_cifar_statistics(data_set, data_set_type='Train'):
    """
    Print per-channel statistics of a CIFAR-style dataset and plot one of its images
    :param data_set: Dataset whose samples have the image tensor at index 0
    :param data_set_type: Name of the split (e.g. Train/Test); only used in the printed text
    """
    # Collect every image of the dataset into a single
    # [sample, channel, height, width] numpy array
    stacked = torch.stack([sample[0] for sample in data_set], dim=0).numpy()

    # Per-channel (r, g, b) result of each reduction, keyed by reduction name
    channel_stats = {}
    for reduction in ('mean', 'std', 'min', 'max', 'var'):
        channel_stats[reduction] = get_cifar_property(stacked, reduction)

    print(f'[{data_set_type}]')
    print(f' - Total {data_set_type} Images: {len(data_set)}')
    print(f' - Tensor Shape: {stacked[0].shape}')
    print(f' - min: {channel_stats["min"]}')
    print(f' - max: {channel_stats["max"]}')
    print(f' - mean: {channel_stats["mean"]}')
    print(f' - std: {channel_stats["std"]}')
    print(f' - var: {channel_stats["var"]}')

    # Plot the image at index 1, reordered from channel-first to height x width x channel
    plt.imshow(stacked[1].squeeze().transpose(1, 2, 0))
# -------------------- GradCam --------------------
def display_gradcam_output(data: list,
                           classes,
                           inv_normalize: transforms.Normalize,
                           model,
                           target_layers,
                           targets=None,
                           number_of_samples: int = 10,
                           transparency: float = 0.60):
    """
    Function to visualize GradCam output on the data and return the plot as an image
    :param data: List[Tuple(image, label, prediction)], e.g. the output of get_misclassified_data;
                 each image presumably carries a leading batch dimension of 1 (it is squeezed
                 away before display)
    :param classes: Name of classes in the dataset, indexed by the label / prediction value
    :param inv_normalize: Transform that undoes the dataset normalization so images display correctly
    :param model: Model architecture
    :param target_layers: Layers on which GradCam should be executed
    :param targets: Classes to be focused on for GradCam; passed straight through to GradCAM
    :param number_of_samples: Maximum number of images to plot (capped at len(data))
    :param transparency: Weight of Normal image when mixed with activations
    :return: PIL Image of the whole figure, PNG-encoded
    """
    # Never index past the end of `data`
    number_of_samples = min(number_of_samples, len(data))

    # Plot configuration: 5 columns, as many rows as the samples need
    fig = plt.figure(figsize=(10, 10))
    x_count = 5
    y_count = math.ceil(number_of_samples / x_count)

    # Create an object for GradCam
    cam = GradCAM(model=model, target_layers=target_layers)

    for i in range(number_of_samples):
        ax = fig.add_subplot(y_count, x_count, i + 1)
        input_tensor = data[i][0]

        # Activation map of the target layers for this image (first/only item of the batch)
        grayscale_cam = cam(input_tensor=input_tensor, targets=targets)
        grayscale_cam = grayscale_cam[0, :]

        # Get back the original image, channel-last as expected by show_cam_on_image
        img = inv_normalize(input_tensor.squeeze(0).to('cpu'))
        rgb_img = img.permute(1, 2, 0).detach().numpy().astype(np.float32)
        # Ensure the image data is within the [0, 1] range
        rgb_img = np.clip(rgb_img, 0, 1)

        # Mix the activations on the original image
        visualization = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True, image_weight=transparency)

        ax.imshow(visualization)
        ax.set_title(r"Correct: " + classes[data[i][1].item()] + '\n' + 'Output: ' + classes[data[i][2].item()])
        ax.set_xticks([])
        ax.set_yticks([])

    fig.tight_layout()

    # Render the entire figure into an in-memory PNG, then release the figure so
    # repeated calls do not accumulate open matplotlib figures
    buf = io.BytesIO()
    fig.savefig(buf, format='png')
    plt.close(fig)
    buf.seek(0)
    img_var = Image.open(buf)
    return img_var