Praneeth383 commited on
Commit
6a7e2ee
1 Parent(s): 79c0d46

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +64 -0
app.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ import tensorflow as tf
4
+ from sklearn.neighbors import KNeighborsClassifier
5
+ from sklearn.tree import DecisionTreeClassifier
6
+ from sklearn.ensemble import RandomForestClassifier, AdaBoostClassifier, GradientBoostingClassifier
7
+ import joblib
8
+ import pickle
9
+
10
+ def fashion_MNIST_prediction(test_image, model='KNN'):
11
+ test_image_flatten = test_image.reshape((-1, 28*28))
12
+ fashion_mnist = tf.keras.datasets.fashion_mnist
13
+ (X_train, y_train), (X_test, y_test) = fashion_mnist.load_data()
14
+ class_names = ("T-shirt/top", "Trouser", "Pullover", "Dress", "Coat", "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot")
15
+ img_shape = X_train.shape
16
+ n_samples = img_shape[0]
17
+ width = img_shape[1]
18
+ height = img_shape[2]
19
+ x_train_flatten = X_train.reshape(n_samples, width*height)
20
+
21
+ if model == 'KNN':
22
+ with open('knn_model.pkl', 'rb') as f:
23
+ knn = pickle.load(f)
24
+ ans = knn.predict(test_image_flatten)
25
+ ans_prediction = knn.predict_proba(test_image_flatten)
26
+ return class_names[ans[0]], dict(zip(class_names, map(float, ans_prediction[0])))
27
+
28
+ elif model == 'DecisionTreeClassifier':
29
+ tree_model = joblib.load('tree_model.joblib')
30
+ ans = tree_model.predict(test_image_flatten)
31
+ ans_prediction = tree_model.predict_proba(test_image_flatten)
32
+ return class_names[ans[0]], dict(zip(class_names, map(float, ans_prediction[0])))
33
+
34
+ elif model == 'RandomForestClassifier':
35
+ best_model = joblib.load('best_model.pkl')
36
+ ans = best_model.predict(test_image_flatten)
37
+ ans_prediction = best_model.predict_proba(test_image_flatten)
38
+ return class_names[ans[0]], dict(zip(class_names, map(float, ans_prediction[0])))
39
+
40
+ elif model == 'AdaBoostClassifier':
41
+ best_estimator = joblib.load('best_adaboost_model.joblib')
42
+ ans = best_estimator.predict(test_image_flatten)
43
+ ans_prediction = best_estimator.predict_proba(test_image_flatten)
44
+ return class_names[ans[0]], dict(zip(class_names, map(float, ans_prediction[0])))
45
+
46
+ elif model == 'GradientBoostingClassifier':
47
+ best_estimator = joblib.load('best_gbc_model.joblib')
48
+ ans = best_estimator.predict(test_image_flatten)
49
+ ans_prediction = best_estimator.predict_proba(test_image_flatten)
50
+ return class_names[ans[0]], dict(zip(class_names, map(float, ans_prediction[0])))
51
+
52
+ else:
53
+ return "Invalid Model Selection"
54
+
55
+ input_image = gr.inputs.Image(shape=(28, 28), image_mode='L')
56
+ input_model = gr.inputs.Dropdown(['KNN', 'DecisionTreeClassifier', 'RandomForestClassifier', 'AdaBoostClassifier', 'GradientBoostingClassifier'])
57
+
58
+ output_label = gr.outputs.Textbox(label="Predicted Label")
59
+ output_probability = gr.outputs.Label(num_top_classes=10, label="Predicted Probability Per Class")
60
+
61
+ gr.Interface(fn=fashion_MNIST_prediction,
62
+ inputs=[input_image, input_model],
63
+ outputs=[output_label, output_probability],
64
+ title="Fashion MNIST classification").launch(debug=True)