# Hugging Face Spaces status: Runtime error
import functools
import pickle

import gradio as gr
import numpy as np
import scipy
import tensorflow as tf
from tensorflow import keras
from keras.models import load_model
# Pickled estimators that expose ``predict(X) -> labels``; the first label is returned.
_PICKLED_MODEL_FILES = {
    'KNN': 'KNN_best_model_final.pkl',
    'SoftMax': 'softmax_best_model_final.pkl',
    'SVM': 'svm_best_model_final.pkl',
    'Decision Tree': 'tree_model_final.pkl',
    'Random Forest': 'forest_model_final.pkl',
}

# Keras models whose ``predict`` output is reduced with ``np.argmax``.
_KERAS_MODEL_FILES = {
    'Deep Neural Network': 'deep_nn_model_final.h5',
    'CNN': 'cnn_model_final.h5',
}


# The key space is the handful of model files above, so an unbounded cache
# cannot grow without limit; it avoids re-reading a model on every request.
@functools.lru_cache(maxsize=None)
def _load_pickled_model(path):
    """Unpickle (once) and return the model stored at *path*."""
    # NOTE(review): pickle.load can execute arbitrary code; only load model
    # files that ship with this app, never user-supplied ones.
    with open(path, 'rb') as file:
        return pickle.load(file)


@functools.lru_cache(maxsize=None)
def _load_keras_model(path):
    """Load (once) and return the Keras model stored at *path*."""
    return load_model(path)


def mnist_prediction(test_image, model='KNN'):
    """Classify one digit image with the selected model.

    Args:
        test_image: A single grayscale digit image expected to hold 28*28
            pixels (array-like), or ``None`` when nothing was submitted.
            Pixel values are passed to the models without rescaling.
        model: One of 'KNN', 'SoftMax', 'SVM', 'Decision Tree',
            'Random Forest', 'Deep Neural Network' or 'CNN'.

    Returns:
        The predicted label (``predict(...)[0]`` for the pickled estimators,
        the arg-max class index for the Keras models), the string
        "No image provided" when ``test_image`` is ``None``, or the string
        "Not found" for an unrecognised model name.

    Raises:
        ValueError: If the image's pixel count is not a multiple of 28*28.
    """
    if test_image is None:
        return "No image provided"

    test_image = np.asarray(test_image)
    # Flat feature rows of 784 pixels, as the non-convolutional models take.
    test_image_flatten = test_image.reshape((-1, 28 * 28))

    pickled_path = _PICKLED_MODEL_FILES.get(model)
    if pickled_path is not None:
        estimator = _load_pickled_model(pickled_path)
        return estimator.predict(test_image_flatten)[0]

    if model == 'Deep Neural Network':
        dnn_model = _load_keras_model(_KERAS_MODEL_FILES[model])
        return np.argmax(dnn_model.predict(test_image_flatten))

    if model == 'CNN':
        cnn_model = _load_keras_model(_KERAS_MODEL_FILES[model])
        # Batch of one un-flattened image: shape (1, *test_image.shape).
        # NOTE(review): assumes the CNN accepts inputs without a trailing
        # channel axis (e.g. it starts with a Reshape layer) — TODO confirm.
        return np.argmax(cnn_model.predict(np.asarray([test_image])))

    return "Not found"
# Model names offered in the UI; each must match a branch of mnist_prediction.
MODEL_CHOICES = ['KNN', 'SoftMax', 'Deep Neural Network', 'CNN', 'SVM', 'Decision Tree', 'Random Forest']


def _resize_to_mnist(image, size=28):
    """Return *image* resized to ``size x size`` pixels (unchanged if it already is)."""
    image = np.asarray(image)
    if image.shape[:2] == (size, size):
        return image
    # Local import: only needed when Gradio did not already resize the upload.
    from scipy import ndimage
    factors = (size / image.shape[0], size / image.shape[1]) + (1,) * (image.ndim - 2)
    return ndimage.zoom(image, factors, order=1)


def _predict_from_ui(image, model):
    """Gradio callback: bring the upload to 28x28, then classify it."""
    if image is not None:
        image = _resize_to_mnist(image)
    return mnist_prediction(image, model)


# The legacy `gr.inputs` / `gr.outputs` namespaces were removed in Gradio 4;
# the top-level components used here exist in Gradio 3 and 4.
try:
    # Gradio 3.x: `shape` makes Gradio resize the upload to 28x28 itself.
    input_image = gr.Image(shape=(28, 28), image_mode='L')
except TypeError:
    # Gradio 4.x dropped `shape`; _predict_from_ui resizes instead.
    input_image = gr.Image(image_mode='L')
# Default to the same model as mnist_prediction so an untouched dropdown
# does not submit None (which would yield "Not found").
input_model = gr.Dropdown(MODEL_CHOICES, value='KNN', label="Model")
output_label = gr.Textbox(label="Predicted Digit")

demo = gr.Interface(
    fn=_predict_from_ui,
    inputs=[input_image, input_model],
    outputs=output_label,
    title="MNIST classification",
)
demo.launch(debug=True)