Feat: Utils and dataset file for the app
Browse files- datasets.py +38 -0
- utils.py +201 -0
datasets.py
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
"""
|
3 |
+
Module containing wrapper classes for PyTorch Datasets
|
4 |
+
Author: Shilpaj Bhalerao
|
5 |
+
Date: Jun 25, 2023
|
6 |
+
"""
|
7 |
+
# Standard Library Imports
|
8 |
+
from typing import Tuple
|
9 |
+
|
10 |
+
# Third-Party Imports
|
11 |
+
from torchvision import datasets, transforms
|
12 |
+
|
13 |
+
|
14 |
+
class AlbumDataset(datasets.CIFAR10):
    """
    CIFAR10 dataset whose transform is invoked albumentations-style,
    i.e. called with an `image=` keyword and returning a dict with an "image" key
    """
    def __init__(self, root: str = "./data", train: bool = True, download: bool = True, transform: list = None):
        """
        Forward all arguments unchanged to the CIFAR10 constructor
        :param root: Directory at which data is stored
        :param train: Param to distinguish if data is training or test
        :param download: Param to download the dataset from source
        :param transform: Transformation(s) to be performed on each image
        """
        super().__init__(root=root, train=train, download=download, transform=transform)

    def __getitem__(self, index: int) -> Tuple:
        """
        Fetch one sample as an (image, label) pair
        :param index: Index of image and label in the dataset
        :return: Transformed image (raw image when no transform is set) and its label
        """
        image = self.data[index]
        label = self.targets[index]

        # Without a (truthy) transform the raw stored image is returned as-is
        if not self.transform:
            return image, label

        # The transform is called with a keyword and its "image" entry is the result
        return self.transform(image=image)["image"], label
|
utils.py
ADDED
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
"""
|
3 |
+
Utility Script containing functions to be used for training
|
4 |
+
Author: Shilpaj Bhalerao
|
5 |
+
"""
|
6 |
+
# Standard Library Imports
|
7 |
+
import math
|
8 |
+
from typing import NoReturn
|
9 |
+
|
10 |
+
# Third-Party Imports
|
11 |
+
import numpy as np
|
12 |
+
import matplotlib.pyplot as plt
|
13 |
+
import torch
|
14 |
+
from torchsummary import summary
|
15 |
+
from torchvision import transforms
|
16 |
+
from pytorch_grad_cam import GradCAM
|
17 |
+
from pytorch_grad_cam.utils.image import show_cam_on_image
|
18 |
+
|
19 |
+
|
20 |
+
def get_summary(model: torch.nn.Module, input_size: tuple) -> None:
    """
    Function to get the summary of the model architecture

    The model is moved to the GPU when CUDA is available (CPU otherwise) and
    then handed to torchsummary's `summary` together with `input_size`.
    :param model: Object of model architecture class
    :param input_size: Input data shape (Channels, Height, Width)
    :return: None (`NoReturn` is reserved for functions that never return normally)
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    summary(model.to(device), input_size=input_size)
|
30 |
+
|
31 |
+
|
32 |
+
def get_misclassified_data(model, device, test_loader):
    """
    Function to run the model on test set and return misclassified images

    Each batch is pushed through the model in a single forward pass; the
    misclassified samples are then collected one by one, in loader order.
    :param model: Network Architecture
    :param device: CPU/GPU
    :param test_loader: DataLoader for test set, yielding (data, target) batches
                        (assumes one class index per sample in `target`)
    :return: List of (image, label, prediction) tuples where `image` keeps a
             leading batch dimension of 1, `label` is the target entry of the
             sample and `prediction` is a 1x1 tensor holding the predicted class
    """
    # Put the model in evaluation mode
    model.eval()

    # List to store misclassified Images
    misclassified_data = []

    # Disable gradient tracking while predicting
    with torch.no_grad():
        for data, target in test_loader:

            # Migrate the data to the device
            data, target = data.to(device), target.to(device)

            # One forward pass for the whole batch; class index = position of the largest output
            predictions = model(data).argmax(dim=1)

            # Per-sample flags telling whether the prediction differs from the label
            mismatches = (predictions != target.reshape(-1)).tolist()

            for idx, is_wrong in enumerate(mismatches):
                if is_wrong:
                    # Keep the same shapes callers already rely on:
                    # image with a batch dimension, prediction as a 1x1 tensor
                    misclassified_data.append(
                        (data[idx].unsqueeze(0), target[idx], predictions[idx].reshape(1, 1))
                    )
    return misclassified_data
|
69 |
+
|
70 |
+
|
71 |
+
# -------------------- DATA STATISTICS --------------------
|
72 |
+
def get_mnist_statistics(data_set, data_set_type='Train'):
    """
    Function to print the statistics of the dataset and show its first image
    :param data_set: Dataset exposing its raw image tensor and a `transform` callable
    :param data_set_type: Type of dataset [Train/Test/Val]
    """
    # Recent torchvision releases expose the raw tensor as `data` and keep
    # `train_data` only as a deprecated alias; fall back for older versions
    raw_data = data_set.data if hasattr(data_set, 'data') else data_set.train_data

    # The transform is applied to the numpy view of the raw tensor
    transformed = data_set.transform(raw_data.numpy())

    print(f'[{data_set_type}]')
    print(' - Numpy Shape:', raw_data.cpu().numpy().shape)
    print(' - Tensor Shape:', raw_data.size())
    print(' - min:', torch.min(transformed))
    print(' - max:', torch.max(transformed))
    print(' - mean:', torch.mean(transformed))
    print(' - std:', torch.std(transformed))
    print(' - var:', torch.var(transformed))

    # First (image, label) sample of the dataset
    first_sample = next(iter(data_set))
    images, labels = first_sample[0], first_sample[1]

    print(images.shape)
    print(labels)

    # Let's visualize the first image
    plt.imshow(images[0].numpy().squeeze(), cmap='gray')
|
99 |
+
|
100 |
+
|
101 |
+
def get_cifar_property(images, operation):
    """
    Get the property on each channel of the CIFAR
    :param images: Array indexed as [image, channel, height, width] with at least three channels
    :param operation: Name of a no-argument method of the array, e.g. 'mean', 'std', 'var', 'min', 'max'
    :return: Tuple with the value of the operation on the first three channels (R, G, B)
    :raises AttributeError: If the channel slice has no attribute named `operation`
    """
    # Look the method up by name instead of building a string for eval():
    # `operation` is only ever used as an attribute name, never executed as code
    return tuple(getattr(images[:, channel, :, :], operation)() for channel in range(3))
|
111 |
+
|
112 |
+
|
113 |
+
def get_cifar_statistics(data_set, data_set_type='Train'):
    """
    Print per-channel statistics of the dataset and show its second image
    :param data_set: Dataset whose items carry the image tensor at position 0
    :param data_set_type: Training or Test data
    """
    # Stack the image of every sample into a single numpy array
    images = torch.stack([sample[0] for sample in data_set], dim=0).numpy()

    # One (R, G, B) triple per statistic
    mean_rgb, std_rgb, min_rgb, max_rgb, var_rgb = (
        get_cifar_property(images, name) for name in ('mean', 'std', 'min', 'max', 'var')
    )

    print(f'[{data_set_type}]')
    print(f' - Total {data_set_type} Images: {len(data_set)}')
    print(f' - Tensor Shape: {images[0].shape}')
    print(f' - min: {min_rgb}')
    print(f' - max: {max_rgb}')
    print(f' - mean: {mean_rgb}')
    print(f' - std: {std_rgb}')
    print(f' - var: {var_rgb}')

    # Show the second image, moving the channel axis last for plotting
    channel_last = np.transpose(images[1].squeeze(), (1, 2, 0))
    plt.imshow(channel_last)
|
149 |
+
|
150 |
+
|
151 |
+
# -------------------- GradCam --------------------
|
152 |
+
def display_gradcam_output(data: list,
                           classes: list[str],
                           inv_normalize: transforms.Normalize,
                           model: 'DL Model',
                           target_layers: list['model_layer'],
                           targets=None,
                           number_of_samples: int = 10,
                           transparency: float = 0.60):
    """
    Function to visualize GradCam output on the data
    :param data: List[Tuple(image, label, prediction)] where image carries a leading batch dimension
    :param classes: Name of classes in the dataset
    :param inv_normalize: Transform that undoes the normalization applied to the images
    :param model: Model architecture
    :param target_layers: Layers on which GradCam should be executed
    :param targets: Classes to be focused on for GradCam
    :param number_of_samples: Number of images to print (capped at the number of entries in `data`)
    :param transparency: Weight of Normal image when mixed with activations
    """
    # Never index past the end of the provided data
    sample_count = min(number_of_samples, len(data))

    # Plot configuration: 5 columns, and enough rows to hold every sample.
    # Rounding up matters: rounding down leaves no slot for a partially filled last row
    plt.figure(figsize=(10, 10))
    x_count = 5
    y_count = max(1, math.ceil(sample_count / x_count))

    # Create an object for GradCam; only request CUDA when it is actually available
    cam = GradCAM(model=model, target_layers=target_layers, use_cuda=torch.cuda.is_available())

    # Iterate over number of specified images
    for i in range(sample_count):
        plt.subplot(y_count, x_count, i + 1)
        input_tensor = data[i][0]

        # Get the activations of the layer for the image and keep the first map
        grayscale_cam = cam(input_tensor=input_tensor, targets=targets)
        grayscale_cam = grayscale_cam[0, :]

        # Get back the original image: drop the batch dimension, undo the normalization
        img = input_tensor.squeeze(0).to('cpu')
        img = inv_normalize(img)

        # Channel axis goes last for plotting
        rgb_img = np.transpose(img, (1, 2, 0))
        rgb_img = rgb_img.numpy()

        # Mix the activations on the original image
        visualization = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True, image_weight=transparency)

        # Display the image with its true and predicted class names
        plt.imshow(visualization)
        plt.title(r"Correct: " + classes[data[i][1].item()] + '\n' + 'Output: ' + classes[data[i][2].item()])
        plt.xticks([])
        plt.yticks([])
|