|
|
|
|
|
|
|
def demo1_derive_MNIST_train_test_data():
    """Fetch the ``mnist_784`` dataset from OpenML and split it in two.

    The first 60000 rows form the training split and all remaining rows the
    test split. Labels of both splits are cast to ``np.uint8``.

    Returns:
        The tuple ``(X_train, X_test, y_train, y_test)``.
    """
    import numpy as np
    from sklearn.datasets import fetch_openml

    split_at = 60000
    dataset = fetch_openml('mnist_784', version=1, as_frame=False)
    pixels = dataset["data"]
    digits = dataset["target"]

    train_pixels, test_pixels = pixels[:split_at], pixels[split_at:]
    # Labels are cast after slicing so each split owns its own uint8 array.
    train_digits = digits[:split_at].astype(np.uint8)
    test_digits = digits[split_at:].astype(np.uint8)
    return train_pixels, test_pixels, train_digits, test_digits
|
|
|
# Load the dataset once at import time; functions further down in this file
# read X_test, train_features and train_labels as module-level names.
X_train, X_test, y_train, y_test = demo1_derive_MNIST_train_test_data()

# Report the shape of every split.
for _caption, _array in (
    ("X_train.shape: ", X_train),
    ("X_test.shape: ", X_test),
    ("y_train.shape: ", y_train),
    ("y_test.shape: ", y_test),
):
    print(_caption, _array.shape)

# Inputs for the KNN demo: the whole training split, one test sample, and K.
train_features, train_labels = X_train, y_train
test_feature = X_test[0]
K = 3

for _caption, _array in (
    ("train_features: ", train_features),
    ("train_labels: ", train_labels),
    ("test_feature: ", test_feature),
):
    print(_caption, _array.shape)

# Do not leave the loop helpers behind in the module namespace.
del _caption, _array
|
|
|
|
|
|
|
import scipy |
|
import gradio as gr |
|
import numpy as np |
|
import cv2 |
|
import os |
|
|
|
def get_sample_images(num_images):
    """Export the first ``num_images`` test digits as PNG example files.

    Each row ``X_test[i]`` (module-level array) is reshaped to 28x28, cast to
    ``np.uint8`` and written to ``images_folder/local_<i>.png`` with the index
    zero-padded to five digits. Every image is paired with a K value drawn at
    random from a fixed list of candidates.

    Args:
        num_images: Number of leading rows of ``X_test`` to export.

    Returns:
        A list of ``[image_path, K]`` pairs, one per exported image, in
        index order.
    """
    outdir = "images_folder"
    # Create the folder once, before the loop. ``exist_ok=True`` replaces the
    # per-iteration exists()/mkdir() pair, which was both redundant work and
    # a check-then-create race.
    os.makedirs(outdir, exist_ok=True)

    sample_images = []
    for i in range(num_images):
        pixels = X_test[i].reshape(28, 28).astype(np.uint8)
        img_path = os.path.join(outdir, 'local_%05d.png' % (i,))
        cv2.imwrite(img_path, pixels)

        # One random K per example, drawn in image order.
        random_k = int(np.random.choice([7, 9, 11, 13, 15, 24]))
        sample_images.append([img_path, random_k])
    return sample_images
|
|
|
|
|
def plot_digits(instances, images_per_row=3):
    """Draw digit images in a grid and save the figure as ``results.png``.

    Each image is titled ``"Neighbor <n>"`` (1-based, in input order) and
    shown without axes using the ``binary`` colormap.

    Args:
        instances: Sequence of images; every element must support
            ``reshape(28, 28)``.
        images_per_row: Maximum number of images per grid row. The grid is
            narrowed when fewer images are supplied.

    Returns:
        The string ``'results.png'``, the path the figure was saved to.
    """
    import matplotlib.pyplot as plt
    import matplotlib as mpl

    size = 28
    count = len(instances)
    # Keep at least one column and one row so that empty input yields a blank
    # figure instead of a ZeroDivisionError.
    images_per_row = max(1, min(count, images_per_row))
    n_rows = max(1, (count - 1) // images_per_row + 1)

    fig = plt.figure(figsize=(15, 8))
    try:
        for i, instance in enumerate(instances):
            # Draw through the Axes object rather than pyplot's implicit
            # "current axes" so the output does not depend on global state.
            ax = fig.add_subplot(n_rows, images_per_row, i + 1)
            ax.imshow(instance.reshape(size, size), cmap=mpl.cm.binary)
            ax.axis("off")
            ax.set_title("Neighbor " + str(i + 1), size=20)
        fig.tight_layout()
        fig.savefig('results.png', dpi=300)
    finally:
        # pyplot keeps every figure alive until it is closed; without this,
        # each call would leak one figure in a long-running process.
        plt.close(fig)

    return 'results.png'
|
|
|
|
|
|
|
def KNN_predict(train_features, train_labels, test_feature, K):
    """Classify one sample by a vote among its K nearest training samples.

    The Euclidean distance from ``test_feature`` to every row of
    ``train_features`` is computed, and the K closest rows are selected.
    Rows at equal distance keep their training order. The labels of those
    rows are then summarised, and the closest rows are drawn with
    ``plot_digits``.

    Args:
        train_features: 2-D array-like with one training sample per row.
        train_labels: Labels aligned with the rows of ``train_features``.
        test_feature: Sample to classify. It is flattened before comparison,
            so it must hold as many values as one training row.
        K: Number of neighbours that vote. Values larger than the number of
            training samples are clamped to that number.

    Returns:
        A tuple ``(final_prediction, class_freq, image_path)``:
        ``final_prediction`` is the result object of ``scipy.stats.mode``
        applied to the neighbour labels; ``class_freq`` maps ``'Digit 0'``
        through ``'Digit 9'`` to the fraction of neighbours with that label;
        ``image_path`` is the path returned by ``plot_digits`` for the
        nearest neighbours (at most 24 are drawn).

    Raises:
        ValueError: If ``K`` is smaller than 1 or there are no training
            samples.
    """
    # Import the submodules explicitly: on older SciPy releases a bare
    # ``import scipy`` does not make ``scipy.spatial`` / ``scipy.stats``
    # available as attributes.
    from scipy import stats
    from scipy.spatial import distance

    max_shown = 24  # upper bound on neighbour images passed to plot_digits

    train_matrix = np.asarray(train_features)
    n_train = len(train_matrix)
    K = int(K)
    if K < 1:
        raise ValueError("K must be at least 1, got %d" % K)
    if n_train == 0:
        raise ValueError("train_features must contain at least one sample")
    K = min(K, n_train)

    # Flatten the query once and compute all distances in a single call
    # instead of one Python-level scipy call per training row.
    query = np.asarray(test_feature).reshape(1, -1)
    dists = distance.cdist(query, train_matrix, metric="euclidean")[0]

    # A stable sort keeps equal-distance rows in training order, matching a
    # stable sort of (distance, ...) records keyed on distance only.
    nearest = np.argsort(dists, kind="stable")[:K]

    major_class = [train_labels[idx] for idx in nearest]
    final_prediction = stats.mode(major_class)

    # Fraction of the K votes received by each digit 0..9.
    class_freq = {
        'Digit ' + str(digit): float(major_class.count(digit)) / len(major_class)
        for digit in range(10)
    }

    neighbor_imgs = train_matrix[nearest[:max_shown]]
    image_path = plot_digits(neighbor_imgs, images_per_row=6)

    return final_prediction, class_freq, image_path
|
|
|
|
|
def call_our_KNN(test_image, K=7):
    """Run ``KNN_predict`` on one image against the module-level training data.

    Args:
        test_image: Array that can be reshaped to rows of 28*28 values.
        K: Number of neighbours; converted with ``int`` before use.

    Returns:
        The ``(prediction, class_probabilities, image_path)`` triple produced
        by ``KNN_predict``.
    """
    flat_pixels = test_image.reshape((-1, 28 * 28))
    return KNN_predict(train_features, train_labels, flat_pixels, int(K))
|
|
|
|
|
|
|
# Example inputs offered by the interface: [image path, K] pairs.
sample_images = get_sample_images(10)

# NOTE(review): gr.inputs / gr.outputs are the older Gradio component
# namespaces; confirm the pinned Gradio version still provides them.

# Input widgets: a 28x28 single-channel image and a slider for K.
set_image = gr.inputs.Image(shape=(28, 28), image_mode='L')
set_K = gr.inputs.Slider(1, 24, step=1, default=7)

# Output widgets, in the order call_our_KNN returns its values.
set_label = gr.outputs.Textbox(label="Predicted Digit")
set_probability = gr.outputs.Label(
    num_top_classes=10,
    label="Predicted Probability Per Class",
)
set_out_images = gr.outputs.Image(label="Closest Neighbors")

interface = gr.Interface(
    fn=call_our_KNN,
    inputs=[set_image, set_K],
    outputs=[set_label, set_probability, set_out_images],
    examples=sample_images,
    examples_per_page=2,
    title="CSCI4750/5750 Demo 1: Digit classification using KNN algorithm",
    description="Click examples below for a quick demo",
    theme='huggingface',
    layout='vertical',
)

interface.launch(debug=True)