ibrim committed on
Commit
71c714a
1 Parent(s): 0d4900d

Upload 8 files

Files changed (7)
  1. augmentations.py +56 -0
  2. datasets.py +38 -0
  3. gradio_utils.py +151 -0
  4. resnet.py +76 -0
  5. training_utils.py +126 -0
  6. utils.py +216 -0
  7. visualize.py +396 -0
augmentations.py ADDED
@@ -0,0 +1,56 @@
+ #!/usr/bin/env python3
+ """
+ Augmentation strategies used for model training
+ Author: Shilpaj Bhalerao
+ Date: Jul 23, 2023
+ """
+ # Third-Party Imports
+ import torch
+ import albumentations as A
+ from albumentations.pytorch import ToTensorV2
+
+
+ # Train Phase transformations
+ train_set_transforms = {
+     'randomcrop': A.RandomCrop(height=32, width=32, p=0.2),
+     'horizontalflip': A.HorizontalFlip(),
+     'cutout': A.CoarseDropout(max_holes=1, max_height=16, max_width=16, min_holes=1, min_height=1, min_width=1, fill_value=[0.49139968*255, 0.48215827*255, 0.44653124*255], mask_fill_value=None),
+     'normalize': A.Normalize((0.49139968, 0.48215827, 0.44653124), (0.24703233, 0.24348505, 0.26158768)),
+     'standardize': ToTensorV2(),
+ }
+
+ # Test Phase transformations
+ test_set_transforms = {
+     'normalize': A.Normalize((0.49139968, 0.48215827, 0.44653124), (0.24703233, 0.24348505, 0.26158768)),
+     'standardize': ToTensorV2()
+ }
+
+
+ class AddGaussianNoise(object):
+     """
+     Class for a custom augmentation strategy: add Gaussian noise to a tensor
+     """
+     def __init__(self, mean=0., std=1.):
+         """
+         Constructor
+         """
+         self.std = std
+         self.mean = mean
+
+     def __call__(self, tensor):
+         """
+         Augmentation strategy to be applied when called
+         """
+         return tensor + torch.randn(tensor.size()) * self.std + self.mean
+
+     def __repr__(self):
+         """
+         Method to print more info about the strategy
+         """
+         return f"{self.__class__.__name__}(mean={self.mean}, std={self.std})"
+
+ # Usage details
+ # transforms = transforms.Compose([
+ #     transforms.ToTensor(),
+ #     AddGaussianNoise(0., 1.0),
+ # ])
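+
+ # Usage sketch for the albumentations dicts above (illustrative; the dicts
+ # are composed into a pipeline by the downstream helpers):
+ # train_transforms = A.Compose(list(train_set_transforms.values()))
+ # test_transforms = A.Compose(list(test_set_transforms.values()))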
datasets.py ADDED
@@ -0,0 +1,38 @@
+ #!/usr/bin/env python3
+ """
+ Module containing wrapper classes for PyTorch Datasets
+ Author: Shilpaj Bhalerao
+ Date: Jun 25, 2023
+ """
+ # Standard Library Imports
+ from typing import Tuple
+
+ # Third-Party Imports
+ from torchvision import datasets
+
+
+ class AlbumDataset(datasets.CIFAR10):
+     """
+     Wrapper class to use the albumentations library with a PyTorch Dataset
+     """
+     def __init__(self, root: str = "./data", train: bool = True, download: bool = True, transform: list = None):
+         """
+         Constructor
+         :param root: Directory at which data is stored
+         :param train: Param to distinguish if data is training or test
+         :param download: Param to download the dataset from source
+         :param transform: List of transformations to be performed on the dataset
+         """
+         super().__init__(root=root, train=train, download=download, transform=transform)
+
+     def __getitem__(self, index: int) -> Tuple:
+         """
+         Method to return an image and its label
+         :param index: Index of the image and label in the dataset
+         """
+         image, label = self.data[index], self.targets[index]
+
+         if self.transform:
+             transformed = self.transform(image=image)
+             image = transformed["image"]
+         return image, label
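+
+ # Usage sketch (assumes the transform dicts from augmentations.py):
+ # import albumentations as A
+ # from augmentations import train_set_transforms
+ # train_data = AlbumDataset(root='./data', train=True, download=True,
+ #                           transform=A.Compose(list(train_set_transforms.values())))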
gradio_utils.py ADDED
@@ -0,0 +1,151 @@
+ # Import all the required modules
+ import os
+ os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
+ import math
+ import random
+ import sys
+ from collections import OrderedDict
+
+ import numpy as np
+ import matplotlib.pyplot as plt
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import torch.optim as optim
+ from torch.utils.data import DataLoader, Dataset, random_split
+ from torchvision import datasets, transforms
+ from torchvision.datasets import CIFAR10
+ import albumentations as A
+ from albumentations.pytorch import ToTensorV2
+ from torch_lr_finder import LRFinder
+ from pytorch_grad_cam import GradCAM
+ from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
+ from pytorch_grad_cam.utils.image import show_cam_on_image
+ from pytorch_lightning import LightningModule, Trainer
+ from torchmetrics import Accuracy
+ import gradio as gr
+ from PIL import Image
+
+ from utils import display_gradcam_output, get_misclassified_data
+ from visualize import display_cifar_misclassified_data
+ from resnet import *
+
+ PATH_DATASETS = os.environ.get("PATH_DATASETS", ".")
+ classes = ('plane', 'car', 'bird', 'cat', 'deer',
+            'dog', 'frog', 'horse', 'ship', 'truck')
+
+
+ class LitCIFAR(LightningModule):
+     def __init__(self, data_dir=PATH_DATASETS, hidden_size=16, learning_rate=2e-4):
+         super().__init__()
+
+         # Set our init args as class attributes
+         self.data_dir = data_dir
+         self.hidden_size = hidden_size
+         self.learning_rate = learning_rate
+
+         # Hardcode some dataset-specific attributes
+         self.num_classes = 10
+
+         self.misclassified_indices = []
+
+         # Define the PyTorch model
+         self.model = ResNet18()
+
+         self.accuracy = Accuracy(num_classes=self.num_classes, task='multiclass')
+
+     def forward(self, x):
+         x = self.model(x)
+         x = x.view(-1, 10)
+         return F.log_softmax(x, dim=1)
+
+     def training_step(self, batch, batch_idx):
+         x, y = batch
+         logits = self(x)
+         loss = F.nll_loss(logits, y)
+         return loss
+
+     def validation_step(self, batch, batch_idx):
+         x, y = batch
+         logits = self(x)
+         loss = F.nll_loss(logits, y)
+         preds = torch.argmax(logits, dim=1)
+         self.accuracy(preds, y)
+         self.log("val_loss", loss, prog_bar=True)
+         self.log("val_acc", self.accuracy, prog_bar=True)
+         return loss
+
+     # def test_step(self, batch, batch_idx):
+     #     # Here we just reuse the validation_step for testing
+     #     return self.validation_step(batch, batch_idx)
+
+     def configure_optimizers(self):
+         optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
+         return optimizer
+
+
+ # Use the GPU when available, otherwise fall back to the CPU
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+ modelfin = LitCIFAR()
+ lit_cifar_instance = LitCIFAR()
+ # Load the state dictionary from the checkpoint file
+ modelfin.load_state_dict(torch.load("submission_gradio/models/model.ckpt", map_location=device))
+ # Set the model to evaluation mode
+ modelfin.eval()
+ # Move the model to the selected device after loading the state dict and setting eval mode
+ modelfin = modelfin.to(device)
+
+ # Inverse of the CIFAR-10 normalization, used to recover displayable images
+ inv_normalize = transforms.Normalize(
+     mean=[-1.9899, -1.9844, -1.7111],
+     std=[4.0486, 4.1152, 3.8314]
+ )
+
+
+ class CIFAR10Dataset(Dataset):
+     def __init__(self, data_dir, train=True, transform=None):
+         self.data = CIFAR10(data_dir, train=train, download=True, transform=None)
+         self.transform = transform
+
+     def __len__(self):
+         return len(self.data)
+
+     def __getitem__(self, idx):
+         image, label = self.data[idx]
+         if self.transform:
+             # Apply the transformation in the format albumentations expects
+             image = self.transform(image=np.array(image))['image']
+         return image, label
+
+
+ transform = A.Compose(
+     [
+         A.RandomCrop(height=32, width=32, p=0.2),
+         A.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261)),
+         A.HorizontalFlip(),
+         A.CoarseDropout(max_holes=1, max_height=16, max_width=16, min_holes=1, min_height=1, min_width=1, fill_value=[0.49139968*255, 0.48215827*255, 0.44653124*255], mask_fill_value=None),
+         ToTensorV2(),
+     ]
+ )
+ test_transform = A.Compose(
+     [
+         A.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261)),
+         ToTensorV2(),
+     ]
+ )
+
+ cifar_full = CIFAR10Dataset(PATH_DATASETS, train=True, transform=transform)
+ cifar_test = CIFAR10Dataset(PATH_DATASETS, train=False, transform=test_transform)
+
+ test_loader = DataLoader(cifar_test, batch_size=512, num_workers=os.cpu_count())
+
+ # misclassified_data, classes, inv_normalize, modelfin, target_layers, targets, number_of_samples=2, transparency=0.7
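+ # Usage sketch for the commented signature above (the target-layer choice is an
+ # assumption; the last block of ResNet18's layer4 is a common pick for GradCAM):
+ # misclassified_data = get_misclassified_data(modelfin, device, test_loader)
+ # target_layers = [modelfin.model.layer4[-1]]
+ # img = display_gradcam_output(misclassified_data, classes, inv_normalize, modelfin,
+ #                              target_layers, targets=None, number_of_samples=2, transparency=0.7)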
resnet.py ADDED
@@ -0,0 +1,76 @@
+ """
+ ResNet in PyTorch.
+ For Pre-activation ResNet, see 'preact_resnet.py'.
+
+ Reference:
+ [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
+     Deep Residual Learning for Image Recognition. arXiv:1512.03385
+ """
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+
+ class BasicBlock(nn.Module):
+     expansion = 1
+
+     def __init__(self, in_planes, planes, stride=1):
+         super(BasicBlock, self).__init__()
+         self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
+         self.bn1 = nn.BatchNorm2d(planes)
+         self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
+         self.bn2 = nn.BatchNorm2d(planes)
+
+         self.shortcut = nn.Sequential()
+         if stride != 1 or in_planes != self.expansion*planes:
+             self.shortcut = nn.Sequential(
+                 nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
+                 nn.BatchNorm2d(self.expansion*planes)
+             )
+
+     def forward(self, x):
+         out = F.relu(self.bn1(self.conv1(x)))
+         out = self.bn2(self.conv2(out))
+         out += self.shortcut(x)
+         out = F.relu(out)
+         return out
+
+
+ class ResNet(nn.Module):
+     def __init__(self, block, num_blocks, num_classes=10):
+         super(ResNet, self).__init__()
+         self.in_planes = 64
+
+         self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
+         self.bn1 = nn.BatchNorm2d(64)
+         self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
+         self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
+         self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
+         self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
+         self.linear = nn.Linear(512*block.expansion, num_classes)
+
+     def _make_layer(self, block, planes, num_blocks, stride):
+         strides = [stride] + [1]*(num_blocks-1)
+         layers = []
+         for stride in strides:
+             layers.append(block(self.in_planes, planes, stride))
+             self.in_planes = planes * block.expansion
+         return nn.Sequential(*layers)
+
+     def forward(self, x):
+         out = F.relu(self.bn1(self.conv1(x)))
+         out = self.layer1(out)
+         out = self.layer2(out)
+         out = self.layer3(out)
+         out = self.layer4(out)
+         out = F.avg_pool2d(out, 4)
+         out = out.view(out.size(0), -1)
+         out = self.linear(out)
+         return out
+
+
+ def ResNet18():
+     return ResNet(BasicBlock, [2, 2, 2, 2])
+
+
+ def ResNet34():
+     return ResNet(BasicBlock, [3, 4, 6, 3])
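+
+
+ # Usage sketch: a quick shape check on a CIFAR-10-sized input
+ # import torch
+ # net = ResNet18()
+ # out = net(torch.randn(2, 3, 32, 32))  # -> torch.Size([2, 10])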
training_utils.py ADDED
@@ -0,0 +1,126 @@
+ #!/usr/bin/env python3
+ """
+ Utilities for Model Training
+ Author: Shilpaj Bhalerao
+ Date: Jun 21, 2023
+ """
+ # Third-Party Imports
+ from tqdm import tqdm
+ import torch
+
+
+ def get_correct_predictions(prediction, labels):
+     """
+     Function to return the total number of correct predictions
+     :param prediction: Model predictions on a given sample of data
+     :param labels: Correct labels of a given sample of data
+     :return: Number of correct predictions
+     """
+     return prediction.argmax(dim=1).eq(labels).sum().item()
+
+
+ def train(model, device, train_loader, optimizer, criterion, scheduler=None):
+     """
+     Function to train the model on the training dataset
+     :param model: Model architecture
+     :param device: Device on which training is to be done (GPU/CPU)
+     :param train_loader: DataLoader for the training dataset
+     :param optimizer: Optimization algorithm to be used for updating weights
+     :param criterion: Loss function for training
+     :param scheduler: Scheduler for the learning rate
+     """
+     # Enable layers like Dropout for model training
+     model.train()
+
+     # Utility to display training progress
+     pbar = tqdm(train_loader)
+
+     # Variables to track loss and accuracy during training
+     train_loss = 0
+     correct = 0
+     processed = 0
+
+     # Iterate over each batch and fetch images and labels from the batch
+     for batch_idx, (data, target) in enumerate(pbar):
+
+         # Put the images and labels on the selected device
+         data, target = data.to(device), target.to(device)
+
+         # Reset the gradients for each batch
+         optimizer.zero_grad()
+
+         # Predict
+         pred = model(data)
+
+         # Calculate loss
+         loss = criterion(pred, target)
+         train_loss += loss.item()
+
+         # Backpropagation
+         loss.backward()
+         optimizer.step()
+
+         # Use the learning rate scheduler if defined
+         if scheduler:
+             scheduler.step()
+
+         # Get the total number of correct predictions
+         correct += get_correct_predictions(pred, target)
+         processed += len(data)
+
+         # Display the training information
+         pbar.set_description(
+             desc=f'Train: Loss={loss.item():0.4f} Batch_id={batch_idx} Accuracy={100 * correct / processed:0.2f}')
+
+     return correct, processed, train_loss
+
+
+ def test(model, device, test_loader, criterion):
+     """
+     Function to test the model training progress on the test dataset
+     :param model: Model architecture
+     :param device: Device on which inference is to be done (GPU/CPU)
+     :param test_loader: DataLoader for the test dataset
+     :param criterion: Loss function for the test dataset
+     """
+     # Disable layers like Dropout for model inference
+     model.eval()
+
+     # Variables to track loss and accuracy
+     test_loss = 0
+     correct = 0
+
+     # Disable gradient computation
+     with torch.no_grad():
+         # Iterate over each batch and fetch images and labels from the batch
+         for batch_idx, (data, target) in enumerate(test_loader):
+
+             # Put the images and labels on the selected device
+             data, target = data.to(device), target.to(device)
+
+             # Pass the images to the model and get the predictions
+             output = model(data)
+             test_loss += criterion(output, target).item()  # sum up batch loss
+
+             # Sum up batch correct predictions
+             correct += get_correct_predictions(output, target)
+
+     # Calculate the average test loss for an epoch
+     test_loss /= len(test_loader.dataset)
+
+     print('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
+         test_loss, correct, len(test_loader.dataset),
+         100. * correct / len(test_loader.dataset)))
+
+     return correct, test_loss
+
+
+ def get_lr(optimizer):
+     """
+     Function to track the learning rate during model training
+     :param optimizer: Optimizer used for training
+     """
+     for param_group in optimizer.param_groups:
+         return param_group['lr']
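+
+
+ # Usage sketch (assumed names; the model and loaders come from the other modules):
+ # optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
+ # criterion = torch.nn.CrossEntropyLoss()
+ # for epoch in range(1, epochs + 1):
+ #     print(f"Epoch {epoch}, LR {get_lr(optimizer)}")
+ #     train(model, device, train_loader, optimizer, criterion)
+ #     test(model, device, test_loader, criterion)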
utils.py ADDED
@@ -0,0 +1,216 @@
+ #!/usr/bin/env python3
+ """
+ Utility Script containing functions to be used for training
+ Author: Shilpaj Bhalerao
+ """
+ # Standard Library Imports
+ import math
+ import io
+ from typing import NoReturn
+
+ # Third-Party Imports
+ import numpy as np
+ import matplotlib.pyplot as plt
+ import torch
+ from PIL import Image
+ from torchsummary import summary
+ from torchvision import transforms
+ from pytorch_grad_cam import GradCAM
+ from pytorch_grad_cam.utils.image import show_cam_on_image
+
+
+ def get_summary(model, input_size: tuple) -> NoReturn:
+     """
+     Function to get the summary of the model architecture
+     :param model: Object of the model architecture class
+     :param input_size: Input data shape (Channels, Height, Width)
+     """
+     use_cuda = torch.cuda.is_available()
+     device = torch.device("cuda" if use_cuda else "cpu")
+     network = model.to(device)
+     summary(network, input_size=input_size)
+
+
+ def get_misclassified_data(model, device, test_loader):
+     """
+     Function to run the model on the test set and return misclassified images
+     :param model: Network architecture
+     :param device: CPU/GPU
+     :param test_loader: DataLoader for the test set
+     """
+     # Prepare the model for evaluation, i.e. disable layers like Dropout
+     model.eval()
+
+     # List to store misclassified images
+     misclassified_data = []
+
+     # Disable gradient computation
+     with torch.no_grad():
+         # Extract images and labels in a batch
+         for data, target in test_loader:
+
+             # Migrate the data to the device
+             data, target = data.to(device), target.to(device)
+
+             # Extract a single image and label from the batch
+             for image, label in zip(data, target):
+
+                 # Add a batch dimension to the image
+                 image = image.unsqueeze(0)
+
+                 # Get the model prediction on the image
+                 output = model(image)
+
+                 # Convert the output from one-hot encoding to a value
+                 pred = output.argmax(dim=1, keepdim=True)
+
+                 # If the prediction is incorrect, append the data
+                 if pred != label:
+                     misclassified_data.append((image, label, pred))
+     return misclassified_data
+
+
+ # -------------------- DATA STATISTICS --------------------
+ def get_mnist_statistics(data_set, data_set_type='Train'):
+     """
+     Function to return the statistics of the training data
+     :param data_set: Training dataset
+     :param data_set_type: Type of dataset [Train/Test/Val]
+     """
+     # We'd need to convert it into Numpy! Remember above we have converted it into tensors already
+     train_data = data_set.train_data
+     train_data = data_set.transform(train_data.numpy())
+
+     print(f'[{data_set_type}]')
+     print(' - Numpy Shape:', data_set.train_data.cpu().numpy().shape)
+     print(' - Tensor Shape:', data_set.train_data.size())
+     print(' - min:', torch.min(train_data))
+     print(' - max:', torch.max(train_data))
+     print(' - mean:', torch.mean(train_data))
+     print(' - std:', torch.std(train_data))
+     print(' - var:', torch.var(train_data))
+
+     dataiter = next(iter(data_set))
+     images, labels = dataiter[0], dataiter[1]
+
+     print(images.shape)
+     print(labels)
+
+     # Let's visualize some of the images
+     plt.imshow(images[0].numpy().squeeze(), cmap='gray')
+
+
+ def get_cifar_property(images, operation):
+     """
+     Get the property on each channel of the CIFAR images
+     :param images: Images on which the property value is computed
+     :param operation: Mean, std, variance, etc.
+     """
+     # Resolve the method by name instead of using eval()
+     param_r = getattr(images[:, 0, :, :], operation)()
+     param_g = getattr(images[:, 1, :, :], operation)()
+     param_b = getattr(images[:, 2, :, :], operation)()
+     return param_r, param_g, param_b
+
+
+ def get_cifar_statistics(data_set, data_set_type='Train'):
+     """
+     Function to get the statistical information of the CIFAR dataset
+     :param data_set: Training set of CIFAR
+     :param data_set_type: Training or Test data
+     """
+     # Images in the dataset
+     images = [item[0] for item in data_set]
+     images = torch.stack(images, dim=0).numpy()
+
+     # Calculate mean over each channel
+     mean_r, mean_g, mean_b = get_cifar_property(images, 'mean')
+
+     # Calculate standard deviation over each channel
+     std_r, std_g, std_b = get_cifar_property(images, 'std')
+
+     # Calculate min value over each channel
+     min_r, min_g, min_b = get_cifar_property(images, 'min')
+
+     # Calculate max value over each channel
+     max_r, max_g, max_b = get_cifar_property(images, 'max')
+
+     # Calculate variance over each channel
+     var_r, var_g, var_b = get_cifar_property(images, 'var')
+
+     print(f'[{data_set_type}]')
+     print(f' - Total {data_set_type} Images: {len(data_set)}')
+     print(f' - Tensor Shape: {images[0].shape}')
+     print(f' - min: {min_r, min_g, min_b}')
+     print(f' - max: {max_r, max_g, max_b}')
+     print(f' - mean: {mean_r, mean_g, mean_b}')
+     print(f' - std: {std_r, std_g, std_b}')
+     print(f' - var: {var_r, var_g, var_b}')
+
+     # Let's visualize some of the images
+     plt.imshow(np.transpose(images[1].squeeze(), (1, 2, 0)))
+
+
+ # -------------------- GradCam --------------------
+ def display_gradcam_output(data: list,
+                            classes,
+                            inv_normalize: transforms.Normalize,
+                            model,
+                            target_layers,
+                            targets=None,
+                            number_of_samples: int = 10,
+                            transparency: float = 0.60):
+     """
+     Function to visualize GradCam output on the data
+     :param data: List[Tuple(image, label, prediction)]
+     :param classes: Names of the classes in the dataset
+     :param inv_normalize: Inverse of the normalization applied to the dataset
+     :param model: Model architecture
+     :param target_layers: Layers on which GradCam should be executed
+     :param targets: Classes to be focused on for GradCam
+     :param number_of_samples: Number of images to print
+     :param transparency: Weight of the normal image when mixed with the activations
+     """
+     # Plot configuration
+     fig = plt.figure(figsize=(10, 10))
+     x_count = 5
+     y_count = math.ceil(number_of_samples / x_count)
+
+     # Create an object for GradCam
+     cam = GradCAM(model=model, target_layers=target_layers)
+
+     # Iterate over the specified number of images
+     for i in range(number_of_samples):
+         plt.subplot(y_count, x_count, i + 1)
+         input_tensor = data[i][0]
+
+         # Get the activations of the layer for the images
+         grayscale_cam = cam(input_tensor=input_tensor, targets=targets)
+         grayscale_cam = grayscale_cam[0, :]
+
+         # Get back the original image as an HxWxC float32 numpy array
+         img = input_tensor.squeeze(0).to('cpu')
+         img = inv_normalize(img)
+         rgb_img = np.transpose(img.numpy(), (1, 2, 0)).astype(np.float32)
+
+         # Ensure the image data is within the [0, 1] range
+         rgb_img = np.clip(rgb_img, 0, 1)
+
+         # Mix the activations on the original image
+         visualization = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True, image_weight=transparency)
+
+         # Display the images on the plot
+         plt.imshow(visualization)
+         plt.title(r"Correct: " + classes[data[i][1].item()] + '\n' + 'Output: ' + classes[data[i][2].item()])
+         plt.xticks([])
+         plt.yticks([])
+
+     plt.tight_layout()
+
+     # Save the entire figure to a BytesIO object
+     buf = io.BytesIO()
+     plt.savefig(buf, format='png')
+     buf.seek(0)
+     img_var = Image.open(buf)
+
+     return img_var
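+ # Usage sketch (assumes the ResNet defined in resnet.py):
+ # from resnet import ResNet18
+ # get_summary(ResNet18(), (3, 32, 32))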
visualize.py ADDED
@@ -0,0 +1,396 @@
+ #!/usr/bin/env python3
+ """
+ Functions used for visualization of data and results
+ Author: Shilpaj Bhalerao
+ Date: Jun 21, 2023
+ """
+ # Standard Library Imports
+ import math
+ import io
+ from dataclasses import dataclass
+ from typing import NoReturn
+
+ # Third-Party Imports
+ import numpy as np
+ import matplotlib.pyplot as plt
+ import pandas as pd
+ import seaborn as sn
+ import torch
+ import torch.nn as nn
+ from PIL import Image
+ from torchvision import transforms
+ from sklearn.metrics import confusion_matrix
+
+
+ # ---------------------------- DATA SAMPLES ----------------------------
+ def display_mnist_data_samples(dataset, number_of_samples: int) -> NoReturn:
+     """
+     Function to display samples from the dataloader
+     :param dataset: Train or Test dataset transformed to Tensor
+     :param number_of_samples: Number of samples to be displayed
+     """
+     # Get a batch from the data_set
+     batch_data = []
+     batch_label = []
+     for count, item in enumerate(dataset):
+         if count > number_of_samples:
+             break
+         batch_data.append(item[0])
+         batch_label.append(item[1])
+
+     # Plot the samples from the batch
+     fig = plt.figure()
+     x_count = 5
+     y_count = 1 if number_of_samples <= 5 else math.floor(number_of_samples / x_count)
+
+     for i in range(number_of_samples):
+         plt.subplot(y_count, x_count, i + 1)
+         plt.tight_layout()
+         plt.imshow(batch_data[i].squeeze(), cmap='gray')
+         plt.title(batch_label[i])
+         plt.xticks([])
+         plt.yticks([])
+
+
+ def display_cifar_data_samples(data_set, number_of_samples: int, classes: list):
+     """
+     Function to display samples from the data_set
+     :param data_set: Train or Test data_set transformed to Tensor
+     :param number_of_samples: Number of samples to be displayed
+     :param classes: Name of classes to be displayed
+     """
+     # Get a batch from the data_set
+     batch_data = []
+     batch_label = []
+     for count, item in enumerate(data_set):
+         if count > number_of_samples:
+             break
+         batch_data.append(item[0])
+         batch_label.append(item[1])
+     batch_data = torch.stack(batch_data, dim=0).numpy()
+
+     # Plot the samples from the batch
+     fig = plt.figure()
+     x_count = 5
+     y_count = 1 if number_of_samples <= 5 else math.floor(number_of_samples / x_count)
+
+     for i in range(number_of_samples):
+         plt.subplot(y_count, x_count, i + 1)
+         plt.tight_layout()
+         plt.imshow(np.transpose(batch_data[i].squeeze(), (1, 2, 0)))
+         plt.title(classes[batch_label[i]])
+         plt.xticks([])
+         plt.yticks([])
+
+
+ # ---------------------------- MISCLASSIFIED DATA ----------------------------
+ def display_cifar_misclassified_data(data: list,
+                                      classes,
+                                      inv_normalize: transforms.Normalize,
+                                      number_of_samples: int = 10):
+     """
+     Function to plot images with labels
+     :param data: List[Tuple(image, label, prediction)]
+     :param classes: Names of the classes in the dataset
+     :param inv_normalize: Inverse of the normalization applied to the dataset
+     :param number_of_samples: Number of images to print
+     """
+     fig = plt.figure(figsize=(10, 10))
+
+     x_count = 5
+     y_count = math.ceil(number_of_samples / x_count)
+
+     for i in range(number_of_samples):
+         plt.subplot(y_count, x_count, i + 1)
+         img = data[i][0].squeeze().to('cpu')
+         img = inv_normalize(img)
+         plt.imshow(np.transpose(img, (1, 2, 0)))
+         plt.title(r"Correct: " + classes[data[i][1].item()] + '\n' + 'Output: ' + classes[data[i][2].item()])
+         plt.xticks([])
+         plt.yticks([])
+
+     plt.tight_layout()
+
+     # Save the entire figure to a BytesIO object
+     buftwo = io.BytesIO()
+     plt.savefig(buftwo, format='png')
+     buftwo.seek(0)
+     img_var = Image.open(buftwo)
+
+     return img_var
+
+
+ def display_mnist_misclassified_data(data: list,
+                                      number_of_samples: int = 10):
+     """
+     Function to plot images with labels
+     :param data: List[Tuple(image, label, prediction)]
+     :param number_of_samples: Number of images to print
+     """
+     fig = plt.figure(figsize=(8, 5))
+
+     x_count = 5
+     y_count = 1 if number_of_samples <= 5 else math.floor(number_of_samples / x_count)
+
+     for i in range(number_of_samples):
+         plt.subplot(y_count, x_count, i + 1)
+         img = data[i][0].squeeze(0).to('cpu')
+         plt.imshow(np.transpose(img, (1, 2, 0)), cmap='gray')
+         plt.title(r"Correct: " + str(data[i][1].item()) + '\n' + 'Output: ' + str(data[i][2].item()))
+         plt.xticks([])
+         plt.yticks([])
+
+
+ # ---------------------------- AUGMENTATION SAMPLES ----------------------------
+ def visualize_cifar_augmentation(data_set, data_transforms):
+     """
+     Function to visualize the augmented data
+     :param data_set: Dataset without transformations
+     :param data_transforms: Dictionary of transforms
+     """
+     sample, label = data_set[6]
+     total_augmentations = len(data_transforms)
+
+     fig = plt.figure(figsize=(10, 5))
+     for count, (key, trans) in enumerate(data_transforms.items()):
+         # Skip the final transform (ToTensorV2), which does not produce a displayable image
+         if count == total_augmentations - 1:
+             break
+         plt.subplot(math.ceil(total_augmentations / 5), 5, count + 1)
+         augmented = trans(image=sample)['image']
+         plt.imshow(augmented)
+         plt.title(key)
+         plt.xticks([])
+         plt.yticks([])
+
+
+ def visualize_mnist_augmentation(data_set, data_transforms):
+     """
+     Function to visualize the augmented data
+     :param data_set: Dataset to visualize the augmentations
+     :param data_transforms: Dictionary of transforms
+     """
+     sample, label = data_set[6]
+     total_augmentations = len(data_transforms)
+
+     fig = plt.figure(figsize=(10, 5))
+     for count, (key, trans) in enumerate(data_transforms.items()):
+         if count == total_augmentations - 1:
+             break
+         plt.subplot(math.ceil(total_augmentations / 5), 5, count + 1)
+         img = trans(sample).to('cpu')
+         plt.imshow(np.transpose(img, (1, 2, 0)), cmap='gray')
+         plt.title(key)
+         plt.xticks([])
+         plt.yticks([])
+
+
+ # ---------------------------- LOSS AND ACCURACIES ----------------------------
+ def display_loss_and_accuracies(train_losses: list,
+                                 train_acc: list,
+                                 test_losses: list,
+                                 test_acc: list,
+                                 plot_size: tuple = (10, 10)) -> NoReturn:
+     """
+     Function to display training and test information (losses and accuracies)
+     :param train_losses: List containing the training loss of each epoch
+     :param train_acc: List containing the training accuracy of each epoch
+     :param test_losses: List containing the test loss of each epoch
+     :param test_acc: List containing the test accuracy of each epoch
+     :param plot_size: Size of the plot
+     """
+     # Create a 2x2 plot of the given size
+     fig, axs = plt.subplots(2, 2, figsize=plot_size)
+
+     # Plot the training loss and accuracy for each epoch
+     axs[0, 0].plot(train_losses)
+     axs[0, 0].set_title("Training Loss")
+     axs[1, 0].plot(train_acc)
+     axs[1, 0].set_title("Training Accuracy")
+
+     # Plot the test loss and accuracy for each epoch
+     axs[0, 1].plot(test_losses)
+     axs[0, 1].set_title("Test Loss")
+     axs[1, 1].plot(test_acc)
+     axs[1, 1].set_title("Test Accuracy")
+
+
+ # ---------------------------- Feature Maps and Kernels ----------------------------
+ @dataclass
+ class ConvLayerInfo:
+     """
+     Data Class to store a Conv layer's information
+     """
+     layer_number: int
+     weights: torch.nn.parameter.Parameter
+     layer_info: torch.nn.modules.conv.Conv2d
+
+
+ class FeatureMapVisualizer:
+     """
+     Class to visualize the feature maps of the layers
+     """
+
+     def __init__(self, model):
+         """
+         Constructor
+         :param model: Model Architecture
+         """
+         self.conv_layers = []
+         self.outputs = []
+         self.layerwise_kernels = None
+
+         # Dissect the model
+         counter = 0
+         model_children = model.children()
+         for children in model_children:
+             if type(children) == nn.Sequential:
+                 for child in children:
+                     if type(child) == nn.Conv2d:
+                         counter += 1
+                         self.conv_layers.append(ConvLayerInfo(layer_number=counter,
+                                                               weights=child.weight,
+                                                               layer_info=child))
+
+     def get_model_weights(self):
+         """
+         Method to get the model weights
+         """
+         model_weights = [layer.weights for layer in self.conv_layers]
+         return model_weights
+
+     def get_conv_layers(self):
+         """
+         Get the convolution layers
+         """
+         conv_layers = [layer.layer_info for layer in self.conv_layers]
+         return conv_layers
+
+     def get_total_conv_layers(self) -> int:
+         """
+         Get the total number of convolution layers
+         """
+         out = self.get_conv_layers()
+         return len(out)
+
+     def feature_maps_of_all_kernels(self, image: torch.Tensor) -> dict:
+         """
+         Get feature maps from all the kernels of all the layers
+         :param image: Image to be passed to the network
+         """
+         image = image.unsqueeze(0)
+         image = image.to('cpu')
+
+         outputs = {}
+
+         layers = self.get_conv_layers()
+         for index, layer in enumerate(layers):
+             image = layer(image)
+             outputs[str(layer)] = image
+         self.outputs = outputs
+         return outputs
+
+     def visualize_feature_map_of_kernel(self, image: torch.Tensor, kernel_number: int) -> None:
+         """
+         Function to visualize the feature map of a kernel number from each layer
+         :param image: Image passed to the network
+         :param kernel_number: Number of the kernel in each layer (should be less than or equal to the minimum number of kernels in the network)
+         """
+         # List to store processed feature maps
+         processed = []
+
+         # Get feature maps from all kernels of all the conv layers
+         outputs = self.feature_maps_of_all_kernels(image)
+
+         # Extract the n_th kernel's output from each layer and convert it to grayscale
+         for feature_map in outputs.values():
+             try:
+                 feature_map = feature_map[0][kernel_number]
+             except IndexError:
+                 print("Filter number should be less than the minimum number of channels in a network")
+                 break
+             # Append only when the indexing succeeded
+             gray_scale = feature_map / feature_map.shape[0]
+             processed.append(gray_scale.data.numpy())
+
+         # Plot the feature maps with the layer and kernel number
+         x_range = len(outputs) // 5 + 4
+         fig = plt.figure(figsize=(10, 10))
+         for i in range(len(processed)):
+             a = fig.add_subplot(x_range, 5, i + 1)
+             imgplot = plt.imshow(processed[i])
+             a.axis("off")
+             title = f"{list(outputs.keys())[i].split('(')[0]}_l{i + 1}_k{kernel_number}"
+             a.set_title(title, fontsize=10)
+
+     def get_max_kernel_number(self):
+         """
+         Function to get the maximum number of kernels in the network (for a layer)
+         """
+         layers = self.get_conv_layers()
+         channels = [layer.out_channels for layer in layers]
+         self.layerwise_kernels = channels
+         return max(channels)
+
+     def visualize_kernels_from_layer(self, layer_number: int):
+         """
+         Visualize the kernels from a layer
+         :param layer_number: Number of the layer from which kernels are to be visualized
+         """
+         # Get the kernel count for each layer
+         self.get_max_kernel_number()
+
+         # Zero indexing
+         layer_number = layer_number - 1
+         _kernels = self.layerwise_kernels[layer_number]
+
+         grid = math.ceil(math.sqrt(_kernels))
+
+         plt.figure(figsize=(5, 4))
+         model_weights = self.get_model_weights()
+         _layer_weights = model_weights[layer_number].cpu()
+         for i, filter in enumerate(_layer_weights):
+             plt.subplot(grid, grid, i + 1)
+             plt.imshow(filter[0, :, :].detach(), cmap='gray')
+             plt.axis('off')
+         plt.show()
+
+
+ # ---------------------------- Confusion Matrix ----------------------------
+ def visualize_confusion_matrix(classes, device: str, model,
+                                test_loader: torch.utils.data.DataLoader):
+     """
+     Function to generate and visualize the confusion matrix
+     :param classes: List of class names
+     :param device: cuda/cpu
+     :param model: Model Architecture
+     :param test_loader: DataLoader for the test set
+     """
+     nb_classes = len(classes)
+     cm = torch.zeros(nb_classes, nb_classes)
+
+     model = model.to(device)
+     model.eval()
+
+     # Collect predictions and labels over the whole test set
+     all_preds = []
+     all_labels = []
+     with torch.no_grad():
+         for inputs, labels in test_loader:
+             inputs = inputs.to(device)
+             labels = labels.to(device)
+
+             preds = model(inputs)
+             preds = preds.argmax(dim=1)
+
+             for t, p in zip(labels.view(-1), preds.view(-1)):
+                 cm[t, p] = cm[t, p] + 1
+
+             all_preds.append(preds.to('cpu'))
+             all_labels.append(labels.to('cpu'))
+
+     # Build the confusion matrix from all batches and normalize each row
+     cf_matrix = confusion_matrix(torch.cat(all_labels), torch.cat(all_preds))
+     df_cm = pd.DataFrame(cf_matrix / np.sum(cf_matrix, axis=1)[:, None],
+                          index=[i for i in classes],
+                          columns=[i for i in classes])
+     plt.figure(figsize=(12, 7))
+     sn.heatmap(df_cm, annot=True)
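+
+
+ # Usage sketch (assumed names; the lists are collected over training epochs):
+ # display_loss_and_accuracies(train_losses, train_acc, test_losses, test_acc)
+ # visualize_confusion_matrix(classes, 'cuda', model, test_loader)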