shivi commited on
Commit
dfd9dbc
1 Parent(s): b256351

Update utils/predict.py

Browse files
Files changed (1) hide show
  1. utils/predict.py +4 -4
utils/predict.py CHANGED
@@ -61,19 +61,19 @@ def predict_batch(image_path):
61
  predictions_df = pd.DataFrame()
62
  num = random.randint(0,50)
63
  for images, labels in slice:
64
- for i in range(6):
65
- ax = plt.subplot(3, 3, i + 1)
66
  plt.imshow(images[i].numpy().astype("uint8"))
67
  output = np.argmax(slice_pred[i])
68
 
69
  prob_list = slice_pred[i].numpy().flatten().tolist()
70
  sorted_prob = np.argsort(slice_pred[i])[::-1].flatten()
71
- prob_scores = {"image": "image "+ str(i), "first": f"predicted {class_vocab[sorted_prob[0]]} with {round(prob_list[sorted_prob[0]] * 100,2)}% confidence",
72
  "second": f"predicted {class_vocab[sorted_prob[1]]} is {round(prob_list[sorted_prob[1]] * 100,2)}% confidence",
73
  "third": f"predicted {class_vocab[sorted_prob[2]]} is {round(prob_list[sorted_prob[2]] * 100,2)}% confidence"}
74
  predictions_df = predictions_df.append(prob_scores,ignore_index=True)
75
 
76
- plt.title(f"image {i} : {class_vocab[output]}")
77
  plt.axis("off")
78
  plt.savefig(saved_plot,bbox_inches='tight')
79
 
 
61
  predictions_df = pd.DataFrame()
62
  num = random.randint(0,50)
63
  for images, labels in slice:
64
+ for i, j in zip(range(num,num+6), range(6)):
65
+ ax = plt.subplot(3, 3, j + 1)
66
  plt.imshow(images[i].numpy().astype("uint8"))
67
  output = np.argmax(slice_pred[i])
68
 
69
  prob_list = slice_pred[i].numpy().flatten().tolist()
70
  sorted_prob = np.argsort(slice_pred[i])[::-1].flatten()
71
+ prob_scores = {"image": "image "+ str(j), "first": f"predicted {class_vocab[sorted_prob[0]]} with {round(prob_list[sorted_prob[0]] * 100,2)}% confidence",
72
  "second": f"predicted {class_vocab[sorted_prob[1]]} is {round(prob_list[sorted_prob[1]] * 100,2)}% confidence",
73
  "third": f"predicted {class_vocab[sorted_prob[2]]} is {round(prob_list[sorted_prob[2]] * 100,2)}% confidence"}
74
  predictions_df = predictions_df.append(prob_scores,ignore_index=True)
75
 
76
+ plt.title(f"image {j} : {class_vocab[output]}")
77
  plt.axis("off")
78
  plt.savefig(saved_plot,bbox_inches='tight')
79