Spaces:
Runtime error
Runtime error
mrrahul011
commited on
Create misclassified_image.py
Browse files- misclassified_image.py +57 -0
misclassified_image.py
ADDED
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torchvision
|
3 |
+
import torchvision.transforms as transforms
|
4 |
+
|
5 |
+
def get_cifar10_dataloaders(batch_size=64):
|
6 |
+
|
7 |
+
transform_test = transforms.Compose([
|
8 |
+
transforms.ToTensor(),
|
9 |
+
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
|
10 |
+
])
|
11 |
+
|
12 |
+
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
|
13 |
+
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2)
|
14 |
+
|
15 |
+
return testloader
|
16 |
+
|
17 |
+
#testloader = get_cifar10_dataloaders()
|
18 |
+
|
19 |
+
def get_misclassified_data(model, device, test_loader):
|
20 |
+
"""
|
21 |
+
Function to run the model on test set and return misclassified images
|
22 |
+
:param model: Network Architecture
|
23 |
+
:param device: CPU/GPU
|
24 |
+
:param test_loader: DataLoader for test set
|
25 |
+
"""
|
26 |
+
# Prepare the model for evaluation i.e. drop the dropout layer
|
27 |
+
model.eval()
|
28 |
+
|
29 |
+
# List to store misclassified Images
|
30 |
+
misclassified_data = []
|
31 |
+
|
32 |
+
# Reset the gradients
|
33 |
+
with torch.no_grad():
|
34 |
+
# Extract images, labels in a batch
|
35 |
+
for data, target in test_loader:
|
36 |
+
|
37 |
+
# Migrate the data to the device
|
38 |
+
data, target = data.to(device), target.to(device)
|
39 |
+
|
40 |
+
# Extract single image, label from the batch
|
41 |
+
for image, label in zip(data, target):
|
42 |
+
|
43 |
+
# Add batch dimension to the image
|
44 |
+
image = image.unsqueeze(0)
|
45 |
+
|
46 |
+
# Get the model prediction on the image
|
47 |
+
output = model(image)
|
48 |
+
|
49 |
+
# Convert the output from one-hot encoding to a value
|
50 |
+
pred = output.argmax(dim=1, keepdim=True)
|
51 |
+
|
52 |
+
# If prediction is incorrect, append the data
|
53 |
+
if pred != label:
|
54 |
+
misclassified_data.append((image, label, pred))
|
55 |
+
return misclassified_data
|
56 |
+
##################
|
57 |
+
#misclassified_data = get_misclassified_data(model,'cpu', testloader)
|