Shilpaj committed on
Commit
01f7053
·
1 Parent(s): 95a89e1

Feat: Utils and dataset file for the app

Browse files
Files changed (2) hide show
  1. datasets.py +38 -0
  2. utils.py +201 -0
datasets.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Module containing wrapper classes for PyTorch Datasets
4
+ Author: Shilpaj Bhalerao
5
+ Date: Jun 25, 2023
6
+ """
7
+ # Standard Library Imports
8
+ from typing import Tuple
9
+
10
+ # Third-Party Imports
11
+ from torchvision import datasets, transforms
12
+
13
+
14
class AlbumDataset(datasets.CIFAR10):
    """
    CIFAR10 dataset whose transform is invoked the albumentations way:
    called with the keyword ``image=`` and expected to return a mapping
    that holds the result under the ``"image"`` key
    """

    def __init__(self, root: str = "./data", train: bool = True, download: bool = True, transform: list = None):
        """
        Pass every argument straight through to the CIFAR10 constructor
        :param root: Directory at which data is stored
        :param train: Param to distinguish if data is training or test
        :param download: Param to download the dataset from source
        :param transform: Transformation pipeline; despite the ``list`` annotation it is
                          called as ``transform(image=...)`` in ``__getitem__``
        """
        super().__init__(root=root, train=train, download=download, transform=transform)

    def __getitem__(self, index: int) -> Tuple:
        """
        Fetch a single sample
        :param index: Index of image and label in the dataset
        :return: ``(image, label)``; the image is transformed when a truthy transform is set,
                 otherwise it is returned exactly as stored in ``self.data``
        """
        image = self.data[index]
        label = self.targets[index]

        # No (or falsy) transform configured: hand back the stored sample untouched
        if not self.transform:
            return image, label

        # albumentations-style call: keyword in, dict out
        return self.transform(image=image)["image"], label
utils.py ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Utility Script containing functions to be used for training
4
+ Author: Shilpaj Bhalerao
5
+ """
6
+ # Standard Library Imports
7
+ import math
8
+ from typing import NoReturn
9
+
10
+ # Third-Party Imports
11
+ import numpy as np
12
+ import matplotlib.pyplot as plt
13
+ import torch
14
+ from torchsummary import summary
15
+ from torchvision import transforms
16
+ from pytorch_grad_cam import GradCAM
17
+ from pytorch_grad_cam.utils.image import show_cam_on_image
18
+
19
+
20
def get_summary(model: torch.nn.Module, input_size: tuple) -> None:
    """
    Function to get the summary of the model architecture
    The model is moved to the GPU when CUDA is available (CPU otherwise) before
    being handed to ``torchsummary.summary``, which prints the report
    :param model: Object of model architecture class
    :param input_size: Input data shape (Channels, Height, Width)
    :return: None (``NoReturn`` would have meant "never returns", which is not the case here)
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    summary(model.to(device), input_size=input_size)
30
+
31
+
32
def get_misclassified_data(model, device, test_loader):
    """
    Function to run the model on test set and return misclassified images
    :param model: Network Architecture
    :param device: CPU/GPU
    :param test_loader: DataLoader for test set; iterating it must yield (data, target) batches
    :return: List of ``(image, label, pred)`` tuples, one per misclassified sample, where
             ``image`` keeps a leading batch dimension of 1, ``label`` is the matching
             element of ``target`` and ``pred`` is the predicted class index with shape (1, 1)
    """
    # Prepare the model for evaluation i.e. drop the dropout layer
    model.eval()

    # List to store misclassified Images
    misclassified_data = []

    # Disable gradient tracking: only forward passes are needed here
    with torch.no_grad():
        for data, target in test_loader:

            # Migrate the data to the device
            data, target = data.to(device), target.to(device)

            # One forward pass for the whole batch instead of one per image
            output = model(data)

            # Index of the highest score per sample, shape (batch, 1)
            pred = output.argmax(dim=1, keepdim=True)

            # Positions in the batch where prediction and label disagree
            mismatch = pred.ne(target.reshape(pred.shape)).squeeze(1)
            wrong_indices = torch.nonzero(mismatch, as_tuple=True)[0].tolist()

            for idx in wrong_indices:
                # Same per-sample layout as before: image with batch dim, scalar label, (1, 1) prediction
                misclassified_data.append((data[idx].unsqueeze(0), target[idx], pred[idx].unsqueeze(0)))
    return misclassified_data
69
+
70
+
71
+ # -------------------- DATA STATISTICS --------------------
72
def get_mnist_statistics(data_set, data_set_type='Train'):
    """
    Print summary statistics of a dataset and show one of its images
    The raw ``data_set.train_data`` is converted to NumPy, passed through
    ``data_set.transform`` and reduced with torch min/max/mean/std/var
    :param data_set: Dataset exposing ``train_data`` and ``transform``
    :param data_set_type: Type of dataset [Train/Test/Val]
    """
    # The transform is fed the NumPy view of the raw data
    raw = data_set.train_data
    transformed = data_set.transform(raw.numpy())

    print(f'[{data_set_type}]')
    print(' - Numpy Shape:', data_set.train_data.cpu().numpy().shape)
    print(' - Tensor Shape:', data_set.train_data.size())

    # Caption / reduction pairs, printed in this order
    reductions = (
        (' - min:', torch.min),
        (' - max:', torch.max),
        (' - mean:', torch.mean),
        (' - std:', torch.std),
        (' - var:', torch.var),
    )
    for caption, reduce_fn in reductions:
        print(caption, reduce_fn(transformed))

    # Pull the first sample out of the dataset
    first_sample = next(iter(data_set))
    images = first_sample[0]
    labels = first_sample[1]

    print(images.shape)
    print(labels)

    # Let's visualize some of the images
    preview = images[0].numpy().squeeze()
    plt.imshow(preview, cmap='gray')
99
+
100
+
101
def get_cifar_property(images, operation):
    """
    Get the property on each channel of the CIFAR
    The images are indexed as ``images[:, channel, :, :]`` for channels 0, 1 and 2, and the
    zero-argument method named by ``operation`` is called on each slice
    :param images: Batch of images indexable as [sample, channel, row, column]
    :param operation: Name of the method to call on each channel slice (mean, std, var, min, max, ...)
    :return: Tuple ``(value_r, value_g, value_b)``
    :raises AttributeError: If the channel slice has no method called ``operation``
    """
    # Attribute lookup replaces eval(): same result without executing constructed source code
    return tuple(getattr(images[:, channel, :, :], operation)() for channel in range(3))
111
+
112
+
113
def get_cifar_statistics(data_set, data_set_type='Train'):
    """
    Print per-channel statistics of a CIFAR dataset and show its second image
    :param data_set: Dataset whose items carry the image tensor at position 0
    :param data_set_type: Training or Test data
    """
    # Stack every image tensor into one array of shape (N, ...) and move to NumPy
    images = torch.stack([sample[0] for sample in data_set], dim=0).numpy()

    # One (R, G, B) triple per statistic, computed in the order listed
    stats = {}
    for name in ('mean', 'std', 'min', 'max', 'var'):
        red, green, blue = get_cifar_property(images, name)
        stats[name] = (red, green, blue)

    print(f'[{data_set_type}]')
    print(f' - Total {data_set_type} Images: {len(data_set)}')
    print(f' - Tensor Shape: {images[0].shape}')
    print(f' - min: {stats["min"]}')
    print(f' - max: {stats["max"]}')
    print(f' - mean: {stats["mean"]}')
    print(f' - std: {stats["std"]}')
    print(f' - var: {stats["var"]}')

    # Let's visualize some of the images (channels moved last for matplotlib)
    sample_image = images[1].squeeze()
    plt.imshow(sample_image.transpose(1, 2, 0))
149
+
150
+
151
+ # -------------------- GradCam --------------------
152
def display_gradcam_output(data: list,
                           classes: list[str],
                           inv_normalize: transforms.Normalize,
                           model: 'DL Model',
                           target_layers: list['model_layer'],
                           targets=None,
                           number_of_samples: int = 10,
                           transparency: float = 0.60):
    """
    Function to visualize GradCam output on the data
    :param data: List of (image, label, prediction) entries; the image is passed to GradCAM as is
                 and ``.item()`` is called on the label and the prediction
    :param classes: Name of classes in the dataset
    :param inv_normalize: Transform applied to each image before it is displayed
    :param model: Model architecture
    :param target_layers: Layers on which GradCam should be executed
    :param targets: Classes to be focused on for GradCam
    :param number_of_samples: Number of images to print (capped at ``len(data)``)
    :param transparency: Weight of Normal image when mixed with activations
    """
    # Never index past the end of the supplied data
    number_of_samples = min(number_of_samples, len(data))

    # Plot configuration: 5 columns, and enough rows to hold every sample.
    # Rounding up matters: rounding down left too few cells for counts such as 7 or 12.
    plt.figure(figsize=(10, 10))
    x_count = 5
    y_count = max(1, math.ceil(number_of_samples / x_count))

    # Create an object for GradCam; only request CUDA when it is actually present
    cam = GradCAM(model=model, target_layers=target_layers, use_cuda=torch.cuda.is_available())

    # Iterate over number of specified images
    for i in range(number_of_samples):
        sample = data[i]
        input_tensor = sample[0]
        plt.subplot(y_count, x_count, i + 1)

        # Get the activations of the layer for the image (first entry of the returned batch)
        grayscale_cam = cam(input_tensor=input_tensor, targets=targets)[0, :]

        # Get back the original image: drop the batch dim, undo normalization, channels last
        img = inv_normalize(input_tensor.squeeze(0).detach().to('cpu'))
        rgb_img = img.permute(1, 2, 0).numpy()

        # Mix the activations on the original image
        visualization = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True, image_weight=transparency)

        # Display the images on the plot
        plt.imshow(visualization)
        plt.title("Correct: " + classes[sample[1].item()] + '\n' + 'Output: ' + classes[sample[2].item()])
        plt.xticks([])
        plt.yticks([])