Spaces:
Runtime error
Runtime error
File size: 2,183 Bytes
88d4336 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 |
import scipy
import gradio as gr
import numpy as np
import tensorflow as tf
from tensorflow import keras
from keras.models import load_model
import pickle
# Model name -> file containing a pickled model that exposes ``predict``.
# NOTE(review): pickle.load can execute arbitrary code; these must be trusted
# artifacts shipped with the app, never user-supplied files.
_PICKLED_MODEL_FILES = {
    'KNN': 'KNN_best_model_final.pkl',
    'SoftMax': 'softmax_best_model_final.pkl',
    'SVM': 'svm_best_model_final.pkl',
    'Decision Tree': 'tree_model_final.pkl',
    'Random Forest': 'forest_model_final.pkl',
}

# Model name -> Keras model file read with ``load_model``.
_KERAS_MODEL_FILES = {
    'Deep Neural Network': 'deep_nn_model_final.h5',
    'CNN': 'cnn_model_final.h5',
}

# Loaded models keyed by file path. Each file is read from disk once per
# process instead of on every prediction request.
_MODEL_CACHE = {}


def _unpickle(path):
    """Load and return the pickled object stored at *path*."""
    with open(path, 'rb') as file:
        return pickle.load(file)


def _cached_model(path, loader):
    """Return the model at *path*, calling ``loader(path)`` only on first use."""
    if path not in _MODEL_CACHE:
        _MODEL_CACHE[path] = loader(path)
    return _MODEL_CACHE[path]


def mnist_prediction(test_image, model='KNN'):
    """Predict the digit shown in *test_image* using the selected model.

    Args:
        test_image: Array whose element count is a multiple of 784 (28*28).
            The Gradio interface in this file passes a 28x28 greyscale image.
            ``None`` (no image submitted) returns "No image provided".
        model: One of 'KNN', 'SoftMax', 'SVM', 'Decision Tree',
            'Random Forest', 'Deep Neural Network' or 'CNN'.

    Returns:
        The first element of ``predict`` for pickled models, the argmax of
        the Keras model output for the neural networks, or "Not found" when
        *model* is not a recognised name.
    """
    if test_image is None:
        return "No image provided"
    test_image_flatten = test_image.reshape((-1, 28 * 28))

    if model in _PICKLED_MODEL_FILES:
        classifier = _cached_model(_PICKLED_MODEL_FILES[model], _unpickle)
        return classifier.predict(test_image_flatten)[0]

    if model == 'Deep Neural Network':
        network = _cached_model(_KERAS_MODEL_FILES[model], load_model)
        return np.argmax(network.predict(np.asarray(test_image_flatten)))

    if model == 'CNN':
        network = _cached_model(_KERAS_MODEL_FILES[model], load_model)
        # Adds a leading batch axis only. NOTE(review): if the CNN was trained
        # on (28, 28, 1) inputs, a channel axis may also be needed — confirm.
        return np.argmax(network.predict(np.asarray([test_image])))

    return "Not found"
# Gradio UI: a 28x28 greyscale image and a model selector, with the predicted
# digit shown as text. Launched at import time, as in the original script.
# NOTE(review): gr.inputs / gr.outputs is the legacy component namespace and
# was removed in Gradio 4.0 (gr.Image there also no longer accepts ``shape``).
# Confirm the pinned gradio version before migrating to gr.Image / gr.Dropdown
# / gr.Textbox.
input_image = gr.inputs.Image(shape=(28, 28), image_mode='L')
input_model = gr.inputs.Dropdown(['KNN', 'SoftMax', 'Deep Neural Network', 'CNN', 'SVM', 'Decision Tree', 'Random Forest'])
output_label = gr.outputs.Textbox(label="Predicted Digit")
gr.Interface(fn=mnist_prediction,
             inputs=[input_image, input_model],
             outputs=output_label,
             title="MNIST classification",
             ).launch(debug=True)