# Spaces:
# Runtime error
# Runtime error
| import pandas as pd | |
| import cv2 | |
| import numpy as np | |
| from joblib import load | |
| import gradio as gr | |
# Gradio input widgets: the image to classify plus the model selector.
# The choice strings below are matched verbatim by classifier_picker, so keep
# them in sync with its lookup.
_MODEL_CHOICES = [
    "Quadratic Disciminat Analysis",
    "Gaussian Naive Bayes Classifier",
    "K-Nearest-Neighbors",
    "Linear discriminant Analysis",
]
input_modules = [
    gr.components.Image(label="Input Image"),
    gr.components.Dropdown(label="Pick a Model", choices=_MODEL_CHOICES),
]

# Gradio output widgets: the predicted class name and the per-class probabilities.
output_modules = [
    gr.components.Textbox(label="Prediction"),
    gr.components.Label(label="Prediction Probs"),
]
# Dropdown choice -> saved model file. The keys must match the Dropdown
# choices byte-for-byte (including their existing spelling); any other value,
# including no selection (None), falls back to the KNN model.
_MODEL_FILES = {
    "Quadratic Disciminat Analysis": "QDA_save.joblib",
    "Gaussian Naive Bayes Classifier": "GNB_save.joblib",
    "Linear discriminant Analysis": "fashionMNIST_LDA.joblib",
}
_DEFAULT_MODEL_FILE = "KNN_fashionMNIST.joblib"

# Estimators already read from disk, keyed by file path, so each joblib file
# is loaded at most once per process instead of on every request.
_MODEL_CACHE = {}


def _load_model(path):
    """Return the estimator saved at *path*, loading it from disk on first use only."""
    model = _MODEL_CACHE.get(path)
    if model is None:
        model = load(path)
        _MODEL_CACHE[path] = model
    return model


def _to_feature_vector(input_img):
    """Shrink an image to 28x28 grayscale and flatten it to 784 pixel values."""
    img = np.asarray(input_img)
    # INTER_AREA is OpenCV's recommended interpolation for shrinking; the
    # default bilinear filter aliases badly when downscaling to 28x28.
    small = cv2.resize(img, (28, 28), interpolation=cv2.INTER_AREA)
    if small.ndim == 3:
        # Gradio hands over RGB(A)-ordered numpy arrays, not OpenCV's native BGR.
        code = cv2.COLOR_RGBA2GRAY if small.shape[2] == 4 else cv2.COLOR_RGB2GRAY
        small = cv2.cvtColor(small, code)
    # A 2-D array is already single-channel and is used as-is.
    return small.reshape(784)


def classifier_picker(input_img, input_model):
    """Classify *input_img* into one of the ten clothing classes with the chosen model.

    Args:
        input_img: image array from the Gradio Image input, or None when the
            user submitted without an image.
        input_model: the Dropdown choice naming the model to use; unrecognized
            values (including None) select the KNN model.

    Returns:
        A ``(label, probabilities)`` pair: the predicted class name for the
        Textbox output, and a ``{class name: probability}`` dict for the Label
        output (all zeros when no image was given).
    """
    class_names = ["T-shirt/top", "Trouser", "Pullover", "Dress", "Coat", "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot"]
    empty_probs = dict.fromkeys(class_names, 0.0)
    if input_img is None:
        # Nothing to classify; avoid crashing inside cv2.resize.
        return "No image provided", empty_probs

    loaded_model = _load_model(_MODEL_FILES.get(input_model, _DEFAULT_MODEL_FILE))
    features = [_to_feature_vector(input_img)]
    predicted = loaded_model.predict(features)[0]
    probs = loaded_model.predict_proba(features)[0]
    # predict_proba columns follow the estimator's classes_ order; map them
    # through it rather than assuming the labels are exactly 0..9 in order.
    labels = getattr(loaded_model, "classes_", range(len(probs)))
    output2 = dict(empty_probs)
    output2.update({class_names[int(label)]: float(p) for label, p in zip(labels, probs)})
    return class_names[int(predicted)], output2
# Build the web UI from the input/output widgets and the prediction callback.
demo = gr.Interface(fn=classifier_picker, inputs=input_modules, outputs=output_modules)

# Start the server only when this file is executed as a script, so merely
# importing the module does not launch the app.
if __name__ == "__main__":
    demo.launch()