# s12/utils/utils.py
import torch
import numpy as np
import matplotlib.pyplot as plt


def return_dataset_images(train_loader, total_images):
    """Plot a grid of sample images (up to 12) from the first batch of a data loader."""
    batch_data, batch_label = next(iter(train_loader))

    fig = plt.figure()
    for i in range(total_images):
        plt.subplot(3, 4, i + 1)  # 3x4 grid, so total_images should be <= 12
        plt.tight_layout()
        plt.imshow(batch_data[i].permute(1, 2, 0), cmap='gray')  # CHW -> HWC for matplotlib
        plt.title(batch_label[i].item())
        plt.xticks([])
        plt.yticks([])
    plt.show()
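
# Example usage (a minimal sketch; assumes a `train_loader` DataLoader is
# already defined elsewhere in the project):
#
#   return_dataset_images(train_loader, total_images=12)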


def GetCorrectPredCount(pPrediction, pLabels):
    """Return the number of predictions whose argmax matches the labels."""
    return pPrediction.argmax(dim=1).eq(pLabels).sum().item()
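
# Example usage inside an eval loop (a minimal sketch; `model`, `data`, and
# `target` are assumed to come from the surrounding training/eval code):
#
#   output = model(data)
#   correct += GetCorrectPredCount(output, target)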


def get_incorrect_predictions(model, loader, device):
    """Get all incorrect predictions made by a model on a data loader.

    Args:
        model (Net): Trained model
        loader (DataLoader): Instance of data loader
        device (str): Which device to use, 'cuda' or 'cpu'

    Returns:
        list: [image, target, prediction, predicted-class score] for every
        misclassified sample
    """
    model.eval()
    incorrect = []
    with torch.no_grad():
        for data, target in loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            pred = output.argmax(dim=1)
            for d, t, p, o in zip(data, target, pred, output):
                if not p.eq(t.view_as(p)).item():
                    incorrect.append(
                        [d.cpu(), t.cpu(), p.cpu(), o[p.item()].cpu()])
    return incorrect


# Backward-compatible alias for the original (misspelled) name
get_incorrrect_predictions = get_incorrect_predictions
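
# Example usage (a minimal sketch; assumes a trained `model`, a `test_loader`,
# and a `device` string are defined elsewhere):
#
#   incorrect = get_incorrect_predictions(model, test_loader, device)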


def plot_incorrect_predictions(predictions, class_map, count=10):
    """Plot incorrect predictions.

    Args:
        predictions (list): List of all incorrect predictions
        class_map (dict): Label mapping
        count (int, optional): Number of samples to plot, a multiple of 5. Defaults to 10.
    """
    print(f'Total Incorrect Predictions {len(predictions)}')

    if count % 5 != 0:
        print("Count should be a multiple of 5")
        return

    classes = list(class_map.values())

    fig = plt.figure(figsize=(10, 5))
    for i, (d, t, p, o) in enumerate(predictions):
        ax = fig.add_subplot(count // 5, 5, i + 1, xticks=[], yticks=[])
        ax.set_title(f'{classes[t.item()]}/{classes[p.item()]}')  # actual/predicted
        plt.imshow(d.cpu().numpy().transpose(1, 2, 0))  # CHW -> HWC
        if i + 1 == count:
            break
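
# Example usage (a minimal sketch, continuing from the sketch above;
# `class_map` is an assumed dict mapping class indices to names):
#
#   class_map = {0: 'plane', 1: 'car'}  # extend for the full dataset
#   plot_incorrect_predictions(incorrect, class_map, count=10)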


def wrong_predictions(model, test_loader, norm_mean, norm_std, classes, device):
    """Collect misclassified samples (and all seen samples) from a test loader.

    Stops after roughly 100 wrong predictions have been gathered.

    Returns:
        tuple: (wrong_predictions, all_predictions), where each entry is
        (image, predicted label, true label).
    """
    wrong_images = []
    wrong_label = []
    correct_label = []
    all_images = []
    all_pred_labels = []
    all_true_labels = []

    model.eval()
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            pred = output.argmax(dim=1, keepdim=True).squeeze()  # index of the max log-probability

            wrong_pred = ~pred.eq(target.view_as(pred))  # mask of misclassified samples
            wrong_images.append(data[wrong_pred])
            wrong_label.append(pred[wrong_pred])
            correct_label.append(target.view_as(pred)[wrong_pred])

            all_images.append(data)
            all_pred_labels.append(pred)
            all_true_labels.append(target.view_as(pred))

            if sum(w.shape[0] for w in wrong_images) > 100:
                break

    wrong_predictions = list(zip(torch.cat(wrong_images), torch.cat(wrong_label), torch.cat(correct_label)))
    all_predictions = list(zip(torch.cat(all_images), torch.cat(all_pred_labels), torch.cat(all_true_labels)))

    print(f'Total wrong predictions are {len(wrong_predictions)}')

    # plot_misclassified(wrong_predictions, norm_mean, norm_std, classes)

    return wrong_predictions, all_predictions
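
# Example usage (a minimal sketch; `model`, `test_loader`, `classes`, and
# `device` are assumed to exist; norm_mean/norm_std are the dataset's
# normalisation statistics, here the commonly used CIFAR-10 values):
#
#   norm_mean = (0.4914, 0.4822, 0.4465)
#   norm_std = (0.2470, 0.2435, 0.2616)
#   wrong, everything = wrong_predictions(model, test_loader, norm_mean, norm_std, classes, device)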


def plot_misclassified(wrong_predictions, norm_mean, norm_std, classes):
    """Plot up to 20 misclassified images with their actual and predicted labels."""
    fig = plt.figure(figsize=(10, 12))
    fig.tight_layout()

    for i, (img, pred, correct) in enumerate(wrong_predictions[:20]):
        img, pred, target = img.cpu().numpy().astype(np.float32), pred.cpu(), correct.cpu()

        # Undo the normalisation, channel by channel, before displaying
        for j in range(img.shape[0]):
            img[j] = (img[j] * norm_std[j]) + norm_mean[j]
        img = np.transpose(img, (1, 2, 0))  # CHW -> HWC

        ax = fig.add_subplot(5, 5, i + 1)
        ax.axis('off')
        ax.set_title(f'\nactual : {classes[target.item()]}\npredicted : {classes[pred.item()]}', fontsize=10)
        ax.imshow(img)

    plt.show()
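
# Example usage (a minimal sketch, continuing the assumptions above):
#
#   plot_misclassified(wrong, norm_mean, norm_std, classes)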