# shiftvit/utils/predict.py
# Last change: "Update utils/predict.py" by shivi (commit dfd9dbc); file size 2.98 kB.
import numpy as np
import tensorflow as tf
from tensorflow import keras
from huggingface_hub import from_pretrained_keras
from .lr_schedule import WarmUpCosine
from .constants import Config, class_vocab
from keras.utils import load_img, img_to_array
from tensorflow_addons.optimizers import AdamW
import matplotlib.pyplot as plt
import pandas as pd
import random
AUTO = tf.data.AUTOTUNE
config = Config()

# Load the pretrained ShiftViT checkpoint from the Hugging Face Hub.
# custom_objects maps the names of the project's warm-up cosine LR schedule and
# the tensorflow_addons AdamW optimizer so Keras can resolve them while
# deserializing the saved model.
model = from_pretrained_keras(
    "keras-io/shiftvit",
    custom_objects={"AdamW": AdamW, "WarmUpCosine": WarmUpCosine},
)

# CIFAR-10 splits, loaded once at import time; predict_batch() reads x_test/y_test.
(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()
def predict(image_path):
    """Classify a single image and return per-class confidence scores.

    The image is loaded from ``image_path`` and resized to 32x32. It is
    cast to ``uint8`` and run through the module-level ``model``. The
    model's outputs are treated as logits and converted to probabilities
    with a softmax.

    Args:
        image_path: Image location accepted by ``keras.utils.load_img``.

    Returns:
        dict mapping each class name from ``class_vocab`` to its probability
        (a Python ``float``), ordered from most to least likely.
    """
    pil_image = load_img(image_path, target_size=(32, 32))
    image = img_to_array(pil_image)
    # Add a leading batch axis and cast to uint8 (presumably to match the
    # uint8 CIFAR-10 pixels the model was trained on).
    batch = np.expand_dims(image, axis=0).astype("uint8")

    # A one-element dataset: shuffling it would be a no-op, so it is skipped.
    predict_ds = (
        tf.data.Dataset.from_tensor_slices(batch)
        .batch(config.batch_size)
        .prefetch(AUTO)
    )
    logits = model.predict(predict_ds)

    # Flatten BEFORE ranking. The batched softmax output is 2-D (a batch of
    # one), and reversing a 2-D argsort result with [::-1] flips only the
    # size-1 batch axis, which left classes in ascending order.
    probs = tf.nn.softmax(logits).numpy().flatten()
    ranked = np.argsort(probs)[::-1]
    return {class_vocab[i]: float(probs[i]) for i in ranked}
def predict_batch(image_path):
    """Plot six CIFAR-10 test images with their top prediction and tabulate
    each image's top-3 predictions.

    ``image_path`` is not used; it is kept so existing callers keep working.
    Images come from the module-level CIFAR-10 test split (``x_test``,
    ``y_test``). Only the first batch of ``config.batch_size`` examples is
    run through ``model``. Up to six consecutive images, starting at a random
    offset, are drawn on a 3x3 grid.

    Returns:
        ``(saved_plot, predictions_df)``:
        - ``saved_plot`` is the path of the saved figure (``"plot.jpg"``,
          relative to the current working directory).
        - ``predictions_df`` has one row per plotted image, with columns
          ``"image"``, ``"first"``, ``"second"`` and ``"third"`` holding
          human-readable top-3 prediction strings.
    """
    test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test))
    test_ds = test_ds.batch(config.batch_size).prefetch(AUTO)
    first_batch = test_ds.take(1)
    batch_probs = tf.nn.softmax(model.predict(first_batch))

    saved_plot = "plot.jpg"
    fig = plt.figure()
    # Collect plain dicts and build the DataFrame once at the end:
    # DataFrame.append was removed in pandas 2.0.
    rows = []
    for images, _labels in first_batch:
        batch_len = int(images.shape[0])
        n_show = min(6, batch_len)
        # The offset is drawn from [0, 50] as before, but capped so the window
        # of n_show images fits in batches smaller than 56 (previously an
        # IndexError).
        num = random.randint(0, min(50, batch_len - n_show))
        for j in range(n_show):
            i = num + j
            ax = fig.add_subplot(3, 3, j + 1)
            ax.imshow(images[i].numpy().astype("uint8"))

            output = np.argmax(batch_probs[i])
            prob_list = batch_probs[i].numpy().flatten().tolist()
            # batch_probs[i] is 1-D here, so [::-1] yields descending order.
            sorted_prob = np.argsort(batch_probs[i])[::-1].flatten()
            (name0, pct0), (name1, pct1), (name2, pct2) = [
                (class_vocab[k], round(prob_list[k] * 100, 2))
                for k in sorted_prob[:3]
            ]
            rows.append({
                "image": "image " + str(j),
                "first": f"predicted {name0} with {pct0}% confidence",
                "second": f"predicted {name1} is {pct1}% confidence",
                "third": f"predicted {name2} is {pct2}% confidence",
            })

            ax.set_title(f"image {j} : {class_vocab[output]}")
            ax.axis("off")

    fig.savefig(saved_plot, bbox_inches="tight")
    # pyplot keeps every figure alive until it is explicitly closed; without
    # this each call would leak a figure.
    plt.close(fig)
    return saved_plot, pd.DataFrame(rows)