MuskanMjn's picture
Update app.py
46378b2
raw
history blame
3.11 kB
import numpy as np
import matplotlib.pyplot as plt
from sklearn import svm
import gradio as gr
from PIL import Image
def calculate_score(clf):
    """Score ``clf`` on a dense XOR grid covering [-3, 3] x [-3, 3].

    The grid has 500 evenly spaced points per axis (250,000 samples). A
    sample's label is True exactly when one of its two coordinates is
    positive, i.e. the XOR of the coordinate signs.

    Args:
        clf: A fitted classifier exposing a scikit-learn style
            ``score(X, y)`` method.

    Returns:
        Whatever ``clf.score`` returns for the grid samples and labels.
    """
    axis = np.linspace(-3, 3, 500)
    grid_x, grid_y = np.meshgrid(axis, axis)
    flat_x = grid_x.ravel()
    flat_y = grid_y.ravel()
    samples = np.c_[flat_x, flat_y]
    # Element-wise XOR of the two boolean sign masks.
    labels = (flat_x > 0) ^ (flat_y > 0)
    return clf.score(samples, labels)
def getColorMap(kernel, gamma):
    """Fit a NuSVC on a synthetic XOR dataset and plot its decision function.

    Args:
        kernel: Kernel name passed to ``svm.NuSVC`` (the UI offers
            ``'poly'``, ``'rbf'`` and ``'sigmoid'``). ``None`` or an empty
            value falls back to ``'rbf'``.
        gamma: Gamma setting passed to ``svm.NuSVC`` (the UI offers
            ``'scale'`` and ``'auto'``). ``None`` or an empty value falls
            back to ``'scale'``.

    Returns:
        A ``(figure, score)`` tuple.
        - ``figure`` is a new matplotlib ``Figure`` showing the decision
          function, its zero-level boundary (dashed) and the training points.
        - ``score`` is the value returned by ``calculate_score`` for the
          fitted classifier.
    """
    from matplotlib.figure import Figure

    # An unselected gr.Radio delivers None. Fall back to NuSVC's own defaults
    # instead of letting the estimator reject the value.
    kernel = kernel or "rbf"
    gamma = gamma or "scale"

    # Training data: 300 standard-normal 2-D points labelled by the XOR of
    # their coordinate signs. A local RandomState seeded with 0 produces the
    # same samples as np.random.seed(0) followed by np.random.randn, without
    # mutating NumPy's global random state.
    rng = np.random.RandomState(0)
    X = rng.randn(300, 2)
    Y = np.logical_xor(X[:, 0] > 0, X[:, 1] > 0)

    clf = svm.NuSVC(kernel=kernel, gamma=gamma)
    clf.fit(X, Y)

    # Evaluate the decision function on a 500 x 500 grid over [-3, 3]^2.
    xx, yy = np.meshgrid(np.linspace(-3, 3, 500), np.linspace(-3, 3, 500))
    Z = clf.decision_function(np.c_[xx.ravel(), yy.ravel()]).reshape(xx.shape)

    # Draw on a fresh Figure that pyplot does not track. Drawing on pyplot's
    # global "current figure" made every call stack its contour and scatter
    # on top of earlier ones, leaving stale boundaries from previous
    # kernel/gamma choices visible. A standalone Figure is also not retained
    # in pyplot's figure registry.
    fig = Figure()
    ax = fig.add_subplot()
    ax.imshow(
        Z,
        interpolation="nearest",
        extent=(xx.min(), xx.max(), yy.min(), yy.max()),
        aspect="auto",
        origin="lower",
        cmap=plt.cm.PuOr_r,
    )
    # Level 0 of the decision function is the classifier's decision boundary.
    ax.contour(xx, yy, Z, levels=[0], linewidths=2, linestyles="dashed")
    ax.scatter(X[:, 0], X[:, 1], s=30, c=Y, cmap=plt.cm.Paired, edgecolors="k")
    ax.set_title(
        f"Decision function for Non-Linear SVC with the {kernel} kernel and '{gamma}' gamma ",
        fontsize=14,
    )
    ax.set_xlabel("X", fontsize=13)
    ax.set_ylabel("Y", fontsize=13)
    return fig, calculate_score(clf)
with gr.Blocks() as demo:
    # Introduction: what the demo shows and where it comes from.
    gr.Markdown("## Learning the XOR function: An application of Binary Classification using Non-linear SVM")
    gr.Markdown("### This demo is based on this [scikit-learn example](https://scikit-learn.org/stable/auto_examples/svm/plot_svm_nonlinear.html#sphx-glr-auto-examples-svm-plot-svm-nonlinear-py).")
    gr.Markdown("### In this demo, we use a non-linear SVC (Support Vector Classifier) to learn the decision function of the XOR operator.")

    # Truth-table illustration; "xor.png" is resolved relative to the
    # current working directory.
    xor_image = Image.open("xor.png")
    gr.Image(xor_image, label="Table explaining the 'XOR' operator", shape=(208.5, 250))
    gr.HTML("<hr>")

    gr.Markdown("### Furthermore, we observe that we get different decision function plots by varying the Kernel and Gamma hyperparameters of the non-linear SVC.")
    gr.Markdown("### Feel free to experiment with kernel and gamma values below to see how the quality of the decision function changes with the hyperparameters.")

    # Preselect NuSVC's defaults so that clicking Submit before choosing an
    # option never sends None to getColorMap.
    inp1 = gr.Radio(['poly', 'rbf', 'sigmoid'], value="rbf", label="Kernel", info="Choose a kernel")
    inp2 = gr.Radio(['scale', 'auto'], value="scale", label="Gamma", info="Choose a gamma value")
    btn = gr.Button(value="Submit")

    with gr.Row():
        # Static label: the old f-string interpolated the Radio component
        # objects themselves (built once at layout time), not the user's
        # selection. The chosen kernel/gamma already appear in the plot title.
        plot = gr.Plot(label="Decision function plot for Non-Linear SVC")
        num = gr.Textbox(label="Test Accuracy")

    # getColorMap returns (figure, score), which fill (plot, num) in order.
    btn.click(getColorMap, inputs=[inp1, inp2], outputs=[plot, num])
if __name__ == "__main__":
    # Start the Gradio server; the leftover debug prints around this call
    # were removed.
    demo.launch()