Nahrawy commited on
Commit
9b56006
1 Parent(s): 38f9c4e

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +80 -0
app.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ import matplotlib.pyplot as plt
4
+ from sklearn import linear_model
5
+
6
+ def plot(seed, num_points):
7
+ # Error handling of non-numeric seeds
8
+ if seed and not seed.isnumeric():
9
+ raise gr.Error("Invalid seed")
10
+
11
+ # Setting the seed
12
+ if seed:
13
+ seed = int(seed)
14
+ np.random.seed(seed)
15
+ num_points = int(num_points)
16
+
17
+ #Ensuring the number of points is even
18
+ if num_points%2 != 0:
19
+ num_points +=1
20
+ half_num_points = int(num_points/2)
21
+
22
+ X = np.r_[np.random.randn(half_num_points, 2) + [1, 1], np.random.randn(half_num_points, 2)]
23
+ y = [1] * half_num_points + [-1] * half_num_points
24
+ sample_weight = 100 * np.abs(np.random.randn(num_points))
25
+ # and assign a bigger weight to the last 10 samples
26
+ sample_weight[:half_num_points] *= 10
27
+
28
+ # plot the weighted data points
29
+ xx, yy = np.meshgrid(np.linspace(-4, 5, 500), np.linspace(-4, 5, 500))
30
+ fig, ax = plt.subplots()
31
+ ax.scatter(
32
+ X[:, 0],
33
+ X[:, 1],
34
+ c=y,
35
+ s=sample_weight,
36
+ alpha=0.9,
37
+ cmap=plt.cm.bone,
38
+ edgecolor="black",
39
+ )
40
+
41
+ # fit the unweighted model
42
+ clf = linear_model.SGDClassifier(alpha=0.01, max_iter=100)
43
+ clf.fit(X, y)
44
+ Z = clf.decision_function(np.c_[xx.ravel(), yy.ravel()])
45
+ Z = Z.reshape(xx.shape)
46
+ no_weights = ax.contour(xx, yy, Z, levels=[0], linestyles=["solid"])
47
+
48
+ # fit the weighted model
49
+ clf = linear_model.SGDClassifier(alpha=0.01, max_iter=100)
50
+ clf.fit(X, y, sample_weight=sample_weight)
51
+ Z = clf.decision_function(np.c_[xx.ravel(), yy.ravel()])
52
+ Z = Z.reshape(xx.shape)
53
+ samples_weights = ax.contour(xx, yy, Z, levels=[0], linestyles=["dashed"])
54
+
55
+ no_weights_handles, _ = no_weights.legend_elements()
56
+ weights_handles, _ = samples_weights.legend_elements()
57
+ ax.legend(
58
+ [no_weights_handles[0], weights_handles[0]],
59
+ ["no weights", "with weights"],
60
+ loc="lower left",
61
+ )
62
+
63
+ ax.set(xticks=(), yticks=())
64
+ return fig
65
+
66
+ info = ''' # SGD: Weighted samples\n
67
+ This is a demonstration of a modified version of [SGD](https://scikit-learn.org/stable/modules/sgd.html#id5) that takes into account the weights of the samples. Where the size of points is proportional to its weight.\n
68
+ Created by [@Nahrawy](https://huggingface.co/Nahrawy) based on [scikit-learn docs](https://scikit-learn.org/stable/auto_examples/linear_model/plot_sgd_weighted_samples.html).
69
+ '''
70
+
71
+ with gr.Blocks() as demo:
72
+ gr.Markdown(info)
73
+ with gr.Row():
74
+ with gr.Column():
75
+ seed = gr.Textbox(label="Seed", info="Leave empty to generate new random points each run ",value=None)
76
+ num_points = gr.Slider(label="Number of Points", value="20", minimum=5, maximum=100, step=2)
77
+ btn = gr.Button("Run")
78
+ out = gr.Plot()
79
+ btn.click(fn=plot, inputs=[seed,num_points] , outputs=out)
80
+ demo.launch()