File size: 2,183 Bytes
88d4336
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import functools
import pickle

import gradio as gr
import numpy as np
import scipy
import tensorflow as tf
from tensorflow import keras
from keras.models import load_model

# Pickled classifiers (objects exposing .predict), keyed by the dropdown label.
_PICKLED_MODEL_FILES = {
  'KNN': 'KNN_best_model_final.pkl',
  'SoftMax': 'softmax_best_model_final.pkl',
  'SVM': 'svm_best_model_final.pkl',
  'Decision Tree': 'tree_model_final.pkl',
  'Random Forest': 'forest_model_final.pkl',
}

# Keras models saved as .h5 files, keyed by the dropdown label.
_KERAS_MODEL_FILES = {
  'Deep Neural Network': 'deep_nn_model_final.h5',
  'CNN': 'cnn_model_final.h5',
}


# The key space is bounded by the fixed file names above, so an unbounded
# cache cannot grow without limit.
@functools.lru_cache(maxsize=None)
def _load_pickled_model(path):
  """Unpickle and return the object stored at *path*, loading it only once per process."""
  # NOTE(review): pickle.load can execute arbitrary code; only ever point this
  # at trusted model files shipped with the app, never user-supplied paths.
  with open(path, 'rb') as file:
    return pickle.load(file)


@functools.lru_cache(maxsize=None)
def _load_keras_model(path):
  """Load and return the Keras model stored at *path*, loading it only once per process."""
  return load_model(path)


def mnist_prediction(test_image, model='KNN'):
  """Predict the digit in *test_image* with the classifier named by *model*.

  Args:
    test_image: image array, presumably 28x28 grayscale as produced by the
      Gradio image input (assumption; confirm). It must have 28*28 elements
      per sample, because it is reshaped to (-1, 784). None is accepted and
      yields a message instead of a prediction.
    model: one of 'KNN', 'SoftMax', 'SVM', 'Decision Tree', 'Random Forest',
      'Deep Neural Network' or 'CNN'.

  Returns:
    The first element of ``predict`` for the pickled models, or the argmax of
    the Keras model's output for the neural networks. Returns the string
    "No image provided" when *test_image* is None, and "Not found" for an
    unrecognised *model*.
  """
  # Gradio passes None when the user submits without drawing or uploading.
  if test_image is None:
    return "No image provided"
  test_image_flatten = test_image.reshape((-1, 28*28))

  if model in _PICKLED_MODEL_FILES:
    estimator = _load_pickled_model(_PICKLED_MODEL_FILES[model])
    return estimator.predict(test_image_flatten)[0]

  if model == 'Deep Neural Network':
    dnn_model = _load_keras_model(_KERAS_MODEL_FILES[model])
    return np.argmax(dnn_model.predict(test_image_flatten))

  if model == 'CNN':
    cnn_model = _load_keras_model(_KERAS_MODEL_FILES[model])
    # Batch of one un-flattened image; assumes the CNN accepts this
    # (1, H, W) layout, as the original code did. TODO confirm.
    return np.argmax(cnn_model.predict(np.asarray([test_image])))

  return "Not found"

# Input widgets: an image input configured with shape=(28, 28) and PIL mode
# 'L' (single-channel grayscale), plus a selector for which classifier to use.
input_image = gr.inputs.Image(shape=(28, 28), image_mode='L')
model_choices = ['KNN', 'SoftMax', 'Deep Neural Network', 'CNN', 'SVM', 'Decision Tree', 'Random Forest']
input_model = gr.inputs.Dropdown(model_choices)

# The prediction returned by mnist_prediction is shown in a text box.
output_label = gr.outputs.Textbox(label="Predicted Digit")

# Build the web UI around mnist_prediction and start serving it immediately.
demo = gr.Interface(
  fn=mnist_prediction,
  inputs=[input_image, input_model],
  outputs=output_label,
  title="MNIST classification",
)
demo.launch(debug=True)