import numpy as np import tensorflow as tf from tensorflow import keras from huggingface_hub import from_pretrained_keras from .lr_schedule import WarmUpCosine from .constants import Config, class_vocab from keras.utils import load_img, img_to_array from tensorflow_addons.optimizers import AdamW import matplotlib.pyplot as plt import pandas as pd import random config = Config() ##Load Model model = from_pretrained_keras("keras-io/shiftvit", custom_objects={"WarmUpCosine":WarmUpCosine, "AdamW": AdamW}) (x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data() AUTO = tf.data.AUTOTUNE def predict(image_path): """ This function is used for fetching predictions corresponding to input_image. It outputs confidence scores corresponding to each class on which the model was trained """ test_image1 = load_img(image_path,target_size =(32,32)) test_image = img_to_array(test_image1) test_image = np.expand_dims(test_image, axis =0) test_image = test_image.astype('uint8') predict_ds = tf.data.Dataset.from_tensor_slices(test_image) predict_ds = predict_ds.shuffle(config.buffer_size).batch(config.batch_size).prefetch(AUTO) logits = model.predict(predict_ds) prob = tf.nn.softmax(logits) confidences = {} prob_list = prob.numpy().flatten().tolist() sorted_prob = np.argsort(prob)[::-1].flatten() for i in sorted_prob: confidences[class_vocab[i]] = float(prob_list[i]) return confidences def predict_batch(image_path): test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test)) test_ds = test_ds.batch(config.batch_size).prefetch(AUTO) slice = test_ds.take(1) slice_pred = model.predict(slice) slice_pred = tf.nn.softmax(slice_pred) saved_plot = "plot.jpg" fig = plt.figure() predictions_df = pd.DataFrame() num = random.randint(0,50) for images, labels in slice: for i, j in zip(range(num,num+6), range(6)): ax = plt.subplot(3, 3, j + 1) plt.imshow(images[i].numpy().astype("uint8")) output = np.argmax(slice_pred[i]) prob_list = slice_pred[i].numpy().flatten().tolist() sorted_prob = np.argsort(slice_pred[i])[::-1].flatten() prob_scores = {"image": "image "+ str(j), "first": f"predicted {class_vocab[sorted_prob[0]]} with {round(prob_list[sorted_prob[0]] * 100,2)}% confidence", "second": f"predicted {class_vocab[sorted_prob[1]]} is {round(prob_list[sorted_prob[1]] * 100,2)}% confidence", "third": f"predicted {class_vocab[sorted_prob[2]]} is {round(prob_list[sorted_prob[2]] * 100,2)}% confidence"} predictions_df = predictions_df.append(prob_scores,ignore_index=True) plt.title(f"image {j} : {class_vocab[output]}") plt.axis("off") plt.savefig(saved_plot,bbox_inches='tight') return saved_plot, predictions_df