# FASHION_IMAGES / app.py
# Author: Praneeth383 - commit 6a7e2ee ("Create app.py"), 3.06 kB.
import functools
import pickle

import gradio as gr
import joblib
import numpy as np
import tensorflow as tf
from sklearn.ensemble import RandomForestClassifier, AdaBoostClassifier, GradientBoostingClassifier
from sklearn.neighbors import KNeighborsClassifier
from sklearn.tree import DecisionTreeClassifier
# Human-readable Fashion-MNIST labels, indexed by integer class id (0-9).
CLASS_NAMES = ("T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
               "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot")


def _load_pickle(path):
    """Unpickle a model that was saved with the standard-library ``pickle``."""
    with open(path, 'rb') as f:
        return pickle.load(f)


def _load_joblib(path):
    """Load a model that was saved with ``joblib.dump``."""
    return joblib.load(path)


# Model choice -> (serialized model file, loader for that file's format).
# SECURITY: pickle/joblib can execute arbitrary code while loading. These
# files ship with the app and are trusted; never point these loaders at
# user-supplied files.
_MODEL_FILES = {
    'KNN': ('knn_model.pkl', _load_pickle),
    'DecisionTreeClassifier': ('tree_model.joblib', _load_joblib),
    'RandomForestClassifier': ('best_model.pkl', _load_joblib),
    'AdaBoostClassifier': ('best_adaboost_model.joblib', _load_joblib),
    'GradientBoostingClassifier': ('best_gbc_model.joblib', _load_joblib),
}


# Unbounded cache is safe: callers only pass keys of _MODEL_FILES (a fixed,
# small set), and failed loads raise instead of being cached.
@functools.lru_cache(maxsize=None)
def _load_model(model_name):
    """Load the fitted classifier for *model_name* once and reuse it afterwards."""
    path, loader = _MODEL_FILES[model_name]
    return loader(path)


def fashion_MNIST_prediction(test_image, model='KNN'):
    """Classify a 28x28 grayscale image with one of the pre-trained models.

    Args:
        test_image: Array-like image with 28*28 pixels, or None when nothing
            was submitted.
        model: Name of the classifier to use; must be a key of
            ``_MODEL_FILES``.

    Returns:
        A ``(label, probabilities)`` tuple: the predicted class name and a
        dict mapping every class name to its predicted probability. For an
        unknown model or a missing image, the label is an explanatory
        message and the dict is empty, so both UI outputs always receive
        a value.

    Raises:
        OSError: If the selected model's file cannot be read.
    """
    if model not in _MODEL_FILES:
        return "Invalid Model Selection", {}
    if test_image is None:
        return "No image provided", {}

    # One row of 784 pixel features per image.
    # Assumes the classifiers were fit on the same raw, flattened pixel layout.
    # TODO confirm against the training notebook.
    features = np.asarray(test_image).reshape((-1, 28 * 28))
    classifier = _load_model(model)

    predicted = classifier.predict(features)[0]
    probabilities = classifier.predict_proba(features)[0]
    # Pairing columns with CLASS_NAMES assumes classifier.classes_ is 0..9
    # in order (true when trained on the integer Fashion-MNIST labels).
    return (CLASS_NAMES[int(predicted)],
            dict(zip(CLASS_NAMES, map(float, probabilities))))
# Gradio UI: a 28x28 grayscale image plus a classifier choice in, the
# predicted label and per-class probabilities out.
# NOTE(review): ``gr.inputs`` / ``gr.outputs`` are the legacy component
# namespaces; newer Gradio releases deprecate/remove them. Migrating also
# means handling the 28x28 resize that ``shape=`` requests here.
input_image = gr.inputs.Image(shape=(28, 28), image_mode='L')
# Pre-select the same model the prediction function defaults to, so a
# submit without touching the dropdown does not send ``None``.
input_model = gr.inputs.Dropdown(
    ['KNN', 'DecisionTreeClassifier', 'RandomForestClassifier',
     'AdaBoostClassifier', 'GradientBoostingClassifier'],
    default='KNN',
    label="Model",
)
output_label = gr.outputs.Textbox(label="Predicted Label")
output_probability = gr.outputs.Label(num_top_classes=10, label="Predicted Probability Per Class")

demo = gr.Interface(
    fn=fashion_MNIST_prediction,
    inputs=[input_image, input_model],
    outputs=[output_label, output_probability],
    title="Fashion MNIST classification",
)

# Launch only when run as a script (e.g. ``python app.py``), not on import.
if __name__ == "__main__":
    demo.launch(debug=True)