# File-view header captured along with this source:
# author ravali-maddela; last commit 72e69fc ("Update app.py", verified); file size 1.35 kB.
import functools
from json import load

import gradio as gr
import numpy as np
import tensorflow as tf
# Input components: a single-channel ("L" mode) image and the model selector.
input_module1 = gr.Image(image_mode="L", label="test_image")
input_module2 = gr.Dropdown(
    label="Select Algorithm",
    choices=["KNN", "Softmax"],
)

# Output components: the predicted class name and the per-class probabilities.
output_module1 = gr.Textbox(label="Predicted Class")
output_module2 = gr.Label(label="Predict Probability")
# Fashion-MNIST class names, indexed by the integer label the models predict.
CLASS_NAMES = [
    "T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
    "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot",
]


@functools.lru_cache(maxsize=None)
def _load_model(path):
    """Load the joblib-serialized model at *path* once and reuse it afterwards."""
    # Local import, matching this file's local `from PIL import Image` style.
    import joblib

    return joblib.load(path)


def fashion_images(input1, input2):
    """Classify an uploaded image into one of the ten Fashion-MNIST classes.

    Args:
        input1: Image array from the ``gr.Image`` input, or ``None`` when no
            image has been supplied.
        input2: Algorithm chosen in the dropdown. ``"KNN"`` selects the KNN
            model; any other value (including ``None``) selects the logistic
            ("Softmax") model.

    Returns:
        ``(label, probabilities)``: the predicted class name for the Textbox
        output and a ``{class name: probability}`` dict for the Label output.

    Raises:
        gr.Error: if no image was provided.
    """
    from PIL import Image

    if input1 is None:
        # Without this guard, Image.fromarray(None) fails with an opaque error.
        raise gr.Error("Please upload an image.")

    # Force a single channel so the reshape below always yields 784 features.
    img_pil = Image.fromarray(input1).convert("L")
    # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter.
    img_28x28 = np.array(img_pil.resize((28, 28), Image.LANCZOS))
    # Flatten to one row of 28*28 pixels: shape (1, 784).
    numpy_image = img_28x28.reshape(1, 28 * 28)
    print("numpy_image: ", numpy_image.shape)

    # Models are cached, so each file is read from disk only on first use.
    if input2 == "KNN":
        model = _load_model("best_knn_model.joblib")
    else:
        model = _load_model("best_logistic_model.joblib")

    out = model.predict(numpy_image)[0]
    out_prob = model.predict_proba(numpy_image)[0]
    print("out: ", out)
    print("out_prob: ", out_prob.shape)

    final = CLASS_NAMES[int(out)]
    # predict_proba columns follow model.classes_, so pair each probability
    # with its actual class label rather than its column position.
    class_prob = {
        CLASS_NAMES[int(label)]: float(prob)
        for label, prob in zip(model.classes_, out_prob)
    }
    return final, class_prob
# Wire the classifier to its input/output components and start the app.
demo = gr.Interface(
    fn=fashion_images,
    inputs=[input_module1, input_module2],
    outputs=[output_module1, output_module2],
)
demo.launch(debug=True)