mrrahul011 commited on
Commit
27a4c47
·
verified ·
1 Parent(s): f5b7f38

Create misclassified_image.py

Browse files
Files changed (1) hide show
  1. misclassified_image.py +57 -0
misclassified_image.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torchvision
3
+ import torchvision.transforms as transforms
4
+
5
+ def get_cifar10_dataloaders(batch_size=64):
6
+
7
+ transform_test = transforms.Compose([
8
+ transforms.ToTensor(),
9
+ transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
10
+ ])
11
+
12
+ testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
13
+ testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2)
14
+
15
+ return testloader
16
+
17
+ #testloader = get_cifar10_dataloaders()
18
+
19
+ def get_misclassified_data(model, device, test_loader):
20
+ """
21
+ Function to run the model on test set and return misclassified images
22
+ :param model: Network Architecture
23
+ :param device: CPU/GPU
24
+ :param test_loader: DataLoader for test set
25
+ """
26
+ # Prepare the model for evaluation i.e. drop the dropout layer
27
+ model.eval()
28
+
29
+ # List to store misclassified Images
30
+ misclassified_data = []
31
+
32
+ # Reset the gradients
33
+ with torch.no_grad():
34
+ # Extract images, labels in a batch
35
+ for data, target in test_loader:
36
+
37
+ # Migrate the data to the device
38
+ data, target = data.to(device), target.to(device)
39
+
40
+ # Extract single image, label from the batch
41
+ for image, label in zip(data, target):
42
+
43
+ # Add batch dimension to the image
44
+ image = image.unsqueeze(0)
45
+
46
+ # Get the model prediction on the image
47
+ output = model(image)
48
+
49
+ # Convert the output from one-hot encoding to a value
50
+ pred = output.argmax(dim=1, keepdim=True)
51
+
52
+ # If prediction is incorrect, append the data
53
+ if pred != label:
54
+ misclassified_data.append((image, label, pred))
55
+ return misclassified_data
56
+ ##################
57
+ #misclassified_data = get_misclassified_data(model,'cpu', testloader)