UdayPrasad commited on
Commit
88d4336
1 Parent(s): 057afb4

Create new file

Browse files
Files changed (1) hide show
  1. app.py +57 -0
app.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import scipy
2
+ import gradio as gr
3
+ import numpy as np
4
+ import tensorflow as tf
5
+ from tensorflow import keras
6
+ from keras.models import load_model
7
+ import pickle
8
+
9
+ def mnist_prediction(test_image, model='KNN'):
10
+ test_image_flatten = test_image.reshape((-1, 28*28))
11
+ if model == 'KNN':
12
+ with open('KNN_best_model_final.pkl', 'rb') as file:
13
+ knn_loaded = pickle.load(file)
14
+ ans = knn_loaded.predict(test_image_flatten)
15
+ return ans[0]
16
+ elif model == 'SoftMax':
17
+ with open('softmax_best_model_final.pkl', 'rb') as file:
18
+ softmax_model_loaded = pickle.load(file)
19
+ ans = softmax_model_loaded.predict(test_image_flatten)
20
+ return ans[0]
21
+ elif model == 'Deep Neural Network':
22
+ dnn_model = load_model("deep_nn_model_final.h5")
23
+ ans_prediction = dnn_model.predict(np.asarray(test_image_flatten))
24
+ ans = np.argmax(ans_prediction)
25
+ return ans
26
+ elif model == 'CNN':
27
+ cnn_model = load_model("cnn_model_final.h5")
28
+ ans_prediction = cnn_model.predict(np.asarray([test_image]))
29
+ ans = np.argmax(ans_prediction)
30
+ return ans
31
+ elif model == 'SVM':
32
+ with open('svm_best_model_final.pkl', 'rb') as file:
33
+ svm_model_loaded = pickle.load(file)
34
+ ans = svm_model_loaded.predict(test_image_flatten)
35
+ return ans[0]
36
+ elif model == 'Decision Tree':
37
+ with open('tree_model_final.pkl', 'rb') as file:
38
+ tree_model_loaded = pickle.load(file)
39
+ ans = tree_model_loaded.predict(test_image_flatten)
40
+ return ans[0]
41
+ elif model == 'Random Forest':
42
+ with open('forest_model_final.pkl', 'rb') as file:
43
+ forest_model_loaded = pickle.load(file)
44
+ ans = forest_model_loaded.predict(test_image_flatten)
45
+ return ans[0]
46
+ return "Not found"
47
+
48
+ input_image = gr.inputs.Image(shape=(28, 28), image_mode='L')
49
+ input_model = gr.inputs.Dropdown(['KNN', 'SoftMax', 'Deep Neural Network', 'CNN', 'SVM', 'Decision Tree', 'Random Forest'])
50
+
51
+ output_label = gr.outputs.Textbox(label="Predicted Digit")
52
+
53
+ gr.Interface(fn=mnist_prediction,
54
+ inputs = [input_image, input_model],
55
+ outputs = output_label,
56
+ title = "MNIST classification",
57
+ ).launch(debug=True)