# Scraped page metadata: uploaded by UdayPrasad ("Create new file", commit 88d4336);
# file size 2.18 kB; hosting malware scan reported "No virus".
import functools
import pickle

import gradio as gr
import numpy as np
import scipy
import tensorflow as tf
from tensorflow import keras
from keras.models import load_model
# Pickled estimators (anything exposing ``predict``), keyed by dropdown label.
_PICKLED_MODEL_FILES = {
    'KNN': 'KNN_best_model_final.pkl',
    'SoftMax': 'softmax_best_model_final.pkl',
    'SVM': 'svm_best_model_final.pkl',
    'Decision Tree': 'tree_model_final.pkl',
    'Random Forest': 'forest_model_final.pkl',
}

# Keras models saved as .h5 files, keyed by dropdown label.
_KERAS_MODEL_FILES = {
    'Deep Neural Network': 'deep_nn_model_final.h5',
    'CNN': 'cnn_model_final.h5',
}


@functools.lru_cache(maxsize=None)  # bounded: only the fixed file names above
def _load_pickled_model(path):
    """Unpickle the estimator stored at *path* once and reuse it afterwards."""
    # NOTE(review): pickle.load can execute arbitrary code; only load model
    # files shipped with the app, never user-supplied ones.
    with open(path, 'rb') as fh:
        return pickle.load(fh)


@functools.lru_cache(maxsize=None)  # bounded: only the fixed file names above
def _load_keras_model(path):
    """Load the Keras model stored at *path* once and reuse it afterwards."""
    return load_model(path)


def mnist_prediction(test_image, model='KNN'):
    """Classify a handwritten-digit image with the selected model.

    Args:
        test_image: image array (or nested list) holding 784 pixels per
            sample. It is flattened to shape ``(-1, 784)`` for all models
            except the CNN, which receives it wrapped in a batch of one.
            ``None`` (no image supplied) is answered with a message.
        model: one of 'KNN', 'SoftMax', 'SVM', 'Decision Tree',
            'Random Forest', 'Deep Neural Network' or 'CNN'.

    Returns:
        The first element of ``predict`` for the pickled estimators; the
        ``np.argmax`` of the network output for the Keras models;
        "No image provided" when *test_image* is None; "Not found" for an
        unrecognized *model*.
    """
    if test_image is None:
        return "No image provided"
    test_image_flatten = np.asarray(test_image).reshape((-1, 28 * 28))

    if model in _PICKLED_MODEL_FILES:
        estimator = _load_pickled_model(_PICKLED_MODEL_FILES[model])
        return estimator.predict(test_image_flatten)[0]
    if model == 'Deep Neural Network':
        dnn_model = _load_keras_model(_KERAS_MODEL_FILES[model])
        return np.argmax(dnn_model.predict(test_image_flatten))
    if model == 'CNN':
        cnn_model = _load_keras_model(_KERAS_MODEL_FILES[model])
        # The CNN consumes the unflattened image, batched as one sample.
        return np.argmax(cnn_model.predict(np.asarray([test_image])))
    return "Not found"
# Gradio UI: a 28x28 grayscale image plus a model selector, wired to
# mnist_prediction; the predicted digit is shown in a textbox.
input_image = gr.inputs.Image(shape=(28, 28), image_mode='L')
# NOTE: without a default the dropdown starts empty, so a submission can carry
# no model name and mnist_prediction falls through to "Not found". Default to
# 'KNN' to match mnist_prediction's own default.
input_model = gr.inputs.Dropdown(
    ['KNN', 'SoftMax', 'Deep Neural Network', 'CNN', 'SVM', 'Decision Tree', 'Random Forest'],
    default='KNN',
    label="Model",
)
output_label = gr.outputs.Textbox(label="Predicted Digit")
demo = gr.Interface(
    fn=mnist_prediction,
    inputs=[input_image, input_model],
    outputs=output_label,
    title="MNIST classification",
)
demo.launch(debug=True)