Nahrawy's picture
Update app.py
ad3fd13
raw
history blame
2.88 kB
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.figure import Figure
from sklearn import linear_model
def _fit_decision_grid(X, y, grid, shape, random_state, sample_weight=None):
    """Fit an SGD linear classifier and return its decision function on ``grid``.

    The result is reshaped to ``shape`` so it can be passed to ``ax.contour``.
    ``sample_weight=None`` is equivalent to an unweighted fit.
    """
    clf = linear_model.SGDClassifier(alpha=0.01, max_iter=100, random_state=random_state)
    clf.fit(X, y, sample_weight=sample_weight)
    return clf.decision_function(grid).reshape(shape)


def plot(seed, num_points):
    """Plot SGD decision boundaries fitted with and without sample weights.

    Two Gaussian clusters of ``num_points // 2`` points each are drawn: the
    first is shifted by ``[1, 1]`` and labelled ``1``, the second is centred
    at the origin and labelled ``-1``. Every point gets a random weight, and
    the first cluster's weights are multiplied by 10. Marker area is
    proportional to weight. The solid contour is the unweighted fit, the
    dashed one the weighted fit.

    Args:
        seed: Text from the "Seed" box. Empty, blank, or ``None`` draws fresh
            random points; otherwise it must be an integer in [0, 2**32 - 1].
        num_points: Total number of points. Odd values are rounded up to the
            next even number so both classes get the same count.

    Returns:
        A ``matplotlib.figure.Figure`` for ``gr.Plot``.

    Raises:
        gr.Error: If the seed is not a valid integer in range, or fewer than
            two points are requested.
    """
    seed = "" if seed is None else str(seed).strip()
    if seed:
        try:
            seed = int(seed)
        except ValueError:
            raise gr.Error("Invalid seed") from None
        # RandomState only accepts seeds in [0, 2**32 - 1].
        if not 0 <= seed < 2**32:
            raise gr.Error("Invalid seed")
    else:
        seed = None  # RandomState(None) seeds from OS entropy.

    # A private generator instead of np.random.seed: the global RNG is shared
    # by all concurrent requests, which would make seeded runs irreproducible.
    # It is also handed to both classifiers so their shuffling is seeded too.
    rng = np.random.RandomState(seed)

    num_points = int(num_points)
    # Round odd counts up so both classes get exactly half the points.
    num_points += num_points % 2
    if num_points < 2:
        raise gr.Error("Number of points must be at least 2")
    half = num_points // 2

    X = np.r_[rng.randn(half, 2) + [1, 1], rng.randn(half, 2)]
    y = [1] * half + [-1] * half
    sample_weight = 100 * np.abs(rng.randn(num_points))
    # Give the first half (the shifted, class-1 cluster) a 10x larger weight.
    sample_weight[:half] *= 10

    xx, yy = np.meshgrid(np.linspace(-4, 5, 500), np.linspace(-4, 5, 500))
    grid = np.c_[xx.ravel(), yy.ravel()]

    # Build the Figure directly rather than through pyplot so repeated calls
    # do not accumulate never-closed figures in pyplot's global registry.
    fig = Figure()
    ax = fig.subplots()
    ax.scatter(
        X[:, 0],
        X[:, 1],
        c=y,
        s=sample_weight,
        alpha=0.9,
        cmap="bone",
        edgecolor="black",
    )

    # Fit order matters: both fits draw from rng in sequence.
    unweighted = ax.contour(
        xx, yy, _fit_decision_grid(X, y, grid, xx.shape, rng),
        levels=[0], linestyles=["solid"],
    )
    weighted = ax.contour(
        xx, yy, _fit_decision_grid(X, y, grid, xx.shape, rng, sample_weight=sample_weight),
        levels=[0], linestyles=["dashed"],
    )

    unweighted_handles, _ = unweighted.legend_elements()
    weighted_handles, _ = weighted.legend_elements()
    ax.legend(
        [unweighted_handles[0], weighted_handles[0]],
        ["no weights", "with weights"],
        loc="lower left",
    )
    ax.set(xticks=(), yticks=())
    return fig
# Markdown rendered at the top of the demo page.
info = ''' # SGD: Weighted samples\n
This demo fits an [SGD](https://scikit-learn.org/stable/modules/sgd.html#id5) classifier twice on the same data, once ignoring and once using the sample weights, and plots both decision boundaries. The size of each point is proportional to its weight.\n
Created by [@Nahrawy](https://huggingface.co/Nahrawy) based on [scikit-learn docs](https://scikit-learn.org/stable/auto_examples/linear_model/plot_sgd_weighted_samples.html).
'''
# Gradio UI: a seed box, a point-count slider and a Run button feeding plot(),
# whose figure is shown in the Plot output.
with gr.Blocks() as demo:
    gr.Markdown(info)
    with gr.Row():
        with gr.Column():
            seed = gr.Textbox(label="Seed", info="Leave empty to generate new random points each run ", value=None)
            # plot() rounds odd counts up to even, so an even minimum and step keep
            # the slider on exactly the counts that get plotted (6..100); value is
            # numeric and lies on that grid.
            num_points = gr.Slider(label="Number of Points", value=20, minimum=6, maximum=100, step=2)
            btn = gr.Button("Run")
        out = gr.Plot()
    btn.click(fn=plot, inputs=[seed, num_points], outputs=out)
demo.launch()