File size: 4,172 Bytes
6f11e8c
 
 
 
 
 
 
 
 
 
f4448ff
6f11e8c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f4448ff
6f11e8c
 
 
 
 
 
 
 
 
 
 
f4448ff
 
6f11e8c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f4448ff
6f11e8c
 
f4448ff
6f11e8c
 
 
 
 
 
f4448ff
6f11e8c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f4448ff
6f11e8c
 
 
 
f4448ff
 
 
6f11e8c
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
from __future__ import annotations

import numpy as np
import gradio as gr
from sklearn.svm import SVC
import plotly.graph_objects as go

def plot_decision(
        clf: SVC,
        X: np.ndarray,
        y: np.ndarray | list,
        x_range: np.ndarray,
        y_range: np.ndarray,
        weights: np.ndarray,
        title: str
    ) -> go.Figure:
    """Plot a fitted classifier's decision function with its training points.

    The decision function of ``clf`` is evaluated on the grid spanned by
    ``x_range`` and ``y_range`` and drawn as a filled contour. The samples in
    ``X`` are overlaid as markers coloured by label and sized by weight.

    Args:
        clf: Fitted classifier exposing ``decision_function`` (an ``SVC`` in
            this app) that was trained on two input features.
        X: Sample coordinates; columns 0 and 1 are used as x and y.
        y: Class label per sample, used as the marker colour.
        x_range: Grid coordinates along the horizontal axis.
        y_range: Grid coordinates along the vertical axis.
        weights: Per-sample weights (array-like). Marker size is
            ``(weight + 5) * 2`` so zero-weight samples stay visible.
        title: Figure title.

    Returns:
        A plotly ``Figure`` with one contour trace and one scatter trace.
    """
    # Accept any array-like (e.g. a plain list) so the size arithmetic works.
    weights = np.asarray(weights, dtype=float)

    # Evaluate the decision function at every grid point. meshgrid returns
    # arrays shaped (len(y_range), len(x_range)), the row/column layout that
    # go.Contour expects for z when given x=x_range, y=y_range.
    xx, yy = np.meshgrid(x_range, y_range)
    Z = clf.decision_function(np.c_[xx.ravel(), yy.ravel()])
    Z = Z.reshape(xx.shape)

    fig = go.Figure()

    fig.add_trace(
        go.Contour(
            x=x_range,
            y=y_range,
            z=Z,
            colorscale="Viridis",
            opacity=0.75,
            showscale=False,
        )
    )

    fig.add_trace(
        go.Scatter(
            x=X[:, 0],
            y=X[:, 1],
            mode="markers",
            marker=dict(
                color=y,
                colorscale="viridis",
                # +5 offset keeps zero-weight samples visible.
                size=(weights + 5) * 2,
            ),
        )
    )

    # Tick labels carry no meaning for synthetic data; hide them.
    fig.update_xaxes(showticklabels=False)
    fig.update_yaxes(showticklabels=False)
    fig.update_layout(title=title)

    return fig

def app_fn(seed: int, weight_1: int, weight_2: int) -> tuple[go.Figure, go.Figure]:
    """Build the unweighted and weighted SVM figures shown by the demo.

    A 20-point, two-class dataset is drawn from ``seed``. The first 10 points
    are shifted by (1, 1) and labelled ``1``; the last 10 are labelled ``-1``.
    Each point receives a random positive weight. The weights of the last five
    points are multiplied by ``weight_1``, and the weight of the tenth point
    (index 9) by ``weight_2``. Two ``SVC(gamma=1)`` models are fitted: one
    ignores the weights, the other uses them.

    Args:
        seed: Seed for the data and weight generator. It is coerced to
            ``int`` because numpy's seeding rejects float seeds, in case the
            UI delivers one.
        weight_1: Multiplier for the weights of the last five samples.
        weight_2: Multiplier for the weight of sample 10 (index 9).

    Returns:
        ``(fig_no_weights, fig_weights)``, each built by ``plot_decision``.
    """
    # Use a private generator instead of np.random.seed so concurrent requests
    # do not clobber each other through numpy's global RNG. A fresh
    # RandomState(seed) yields exactly the sequence np.random.seed(seed) +
    # np.random.randn would, so the generated data is unchanged.
    rng = np.random.RandomState(int(seed))

    # 20 points: 10 shifted by (1, 1) in class 1, 10 centred at 0 in class -1.
    X = np.r_[rng.randn(10, 2) + [1, 1], rng.randn(10, 2)]
    y = [1] * 10 + [-1] * 10

    sample_weight = np.abs(rng.randn(len(X)))
    sample_weight_constant = np.ones(len(X))

    sample_weight[15:] *= weight_1  # last five samples
    sample_weight[9] *= weight_2    # sample 10

    # This model does not take the sample weights into account.
    clf_no_weights = SVC(gamma=1)
    clf_no_weights.fit(X, y)

    # This model is trained with the dedicated sample weights.
    clf_weights = SVC(gamma=1)
    clf_weights.fit(X, y, sample_weight=sample_weight)

    # Same square grid on both axes.
    grid = np.arange(-4, 5, 0.1)

    fig_no_weights = plot_decision(
        clf_no_weights,
        X,
        y,
        grid,
        grid,
        sample_weight_constant,
        "SVM without Weights"
    )

    fig_weights = plot_decision(
        clf_weights,
        X,
        y,
        grid,
        grid,
        sample_weight,
        "SVM with Weights"
    )

    return fig_no_weights, fig_weights

title = "SVM with Weighted Samples"

with gr.Blocks(title=title) as demo:
    gr.Markdown(f"# {title}")
    gr.Markdown(
        """
        ### This is a demo of how SVMs can be trained with weighted samples \
        and the impact on the decision boundary. To represent that a synthetic \
        dataset is generated with 20 points, 10 of which are assigned to the \
        positive class and 10 to the negative class. A weight is assigned to \
        each sample, which is the importance of that sample in the dataset. \
        A model with and without weights is trained and the decision boundary \
        is plotted. The size of the points increases with the weight of \
        the sample.

        Created by [@eduardopacheco](https://huggingface.co/EduardoPacheco) based on [scikit-learn-docs](https://scikit-learn.org/stable/auto_examples/svm/plot_weighted_samples.html#sphx-glr-auto-examples-svm-plot-weighted-samples-py)
        """
    )
    with gr.Row():
        # gr.Slider replaces the deprecated gr.inputs.Slider; its `value`
        # argument is the successor of the legacy `default`.
        seed = gr.Slider(minimum=0, maximum=100, step=1, value=0, label="Seed")
        weight_1 = gr.Slider(minimum=0, maximum=20, step=1, value=5, label="Weight for last 5 Samples")
        weight_2 = gr.Slider(minimum=0, maximum=20, step=1, value=15, label="Weight for Sample 10")
    with gr.Row():
        fig_no_weights = gr.Plot(label="SVM without Weights")
        fig_weights = gr.Plot(label="SVM with Weights")

    controls = [seed, weight_1, weight_2]
    figures = [fig_no_weights, fig_weights]
    # Re-render both figures whenever any slider changes, and once on page load.
    for control in controls:
        control.change(fn=app_fn, outputs=figures, inputs=controls)
    demo.load(fn=app_fn, outputs=figures, inputs=controls)

demo.launch()