File size: 4,180 Bytes
630caba
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cd26789
 
 
 
 
 
 
 
 
630caba
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import numpy as np
import gradio as gr
import matplotlib.pyplot as plt

from sklearn.datasets import make_multilabel_classification
from sklearn.multiclass import OneVsRestClassifier
from sklearn.svm import SVC
from sklearn.decomposition import PCA
from sklearn.cross_decomposition import CCA
from matplotlib import cm

plt.switch_backend('agg')


def plot_hyperplane(clf, min_x, max_x, linestyle, linecolor, label, ax=None):
    """
    Plot the decision boundary of a fitted linear binary classifier in 2-D.

    The boundary is the line ``w[0] * x + w[1] * y + b = 0`` where ``w`` is
    ``clf.coef_[0]`` and ``b`` is ``clf.intercept_[0]``.

    :param clf: a fitted linear classifier exposing ``coef_`` and ``intercept_``.
    :param min_x: the minimum value of x in the data.
    :param max_x: the maximum value of x in the data.
    :param linestyle: the line style of the boundary (e.g. ``"--"``).
    :param linecolor: the colour of the boundary line.
    :param label: the legend label for the boundary.
    :param ax: the axes to draw on; defaults to the current pyplot axes.
    """
    if ax is None:
        # Same target as a bare plt.plot(...) call: the current axes.
        ax = plt.gca()

    w = clf.coef_[0]
    b = clf.intercept_[0]

    if w[1] == 0:
        if w[0] == 0:
            # All-zero weights: there is no boundary line to draw.
            return
        # Vertical boundary x = -b / w[0]; the slope form below would divide by zero.
        ax.axvline(-b / w[0], linestyle=linestyle, color=linecolor, linewidth=2.5, label=label)
        return

    slope = -w[0] / w[1]
    # Extend the line a little past the data so it spans the visible plot.
    xx = np.linspace(min_x - 5, max_x + 5)
    yy = slope * xx - b / w[1]
    ax.plot(xx, yy, linestyle, color=linecolor, linewidth=2.5, label=label)



def multilabel_classification(n_samples: int, n_classes: int, n_labels: int, allow_unlabeled: bool, decompostion: str) -> "plt.Figure":
    """
    Generate a random multilabel dataset, project it to 2-D and plot the
    one-vs-rest linear SVM decision boundary of every class.

    :param n_samples: the number of samples.
    :param n_classes: the number of classes for the classification problem.
    :param n_labels: the average number of labels per instance.
    :param allow_unlabeled: if set to True some instances might not belong to any class.
    :param decompostion: the decomposition algorithm used to project the data to 2-D.
        ``"PCA"`` selects PCA; any other value (including ``None``) selects CCA.
        The parameter keeps its historical spelling so keyword callers still work.

    :returns: a matplotlib figure.
    """
    # UI sliders may deliver floats; scikit-learn requires integers for these.
    n_samples = int(n_samples)
    n_classes = int(n_classes)
    n_labels = int(n_labels)

    X, Y = make_multilabel_classification(
        n_samples=n_samples,
        n_classes=n_classes,
        n_labels=n_labels,
        allow_unlabeled=allow_unlabeled,
        random_state=42,
    )

    # Branch on the function argument, not on any module-level name.
    if decompostion == "PCA":
        X = PCA(n_components=2).fit_transform(X)
    else:
        X = CCA(n_components=2).fit(X, Y).transform(X)

    min_x, max_x = np.min(X[:, 0]), np.max(X[:, 0])
    min_y, max_y = np.min(X[:, 1]), np.max(X[:, 1])

    model = OneVsRestClassifier(SVC(kernel="linear"))
    model.fit(X, Y)

    # NOTE(review): pyplot keeps a reference to every figure made by plt.subplots;
    # in a long-running server consider closing figures once they are rendered.
    fig, ax = plt.subplots(1, 1, figsize=(24, 15))

    ax.scatter(X[:, 0], X[:, 1], s=40, c="gray", edgecolors=(0, 0, 0))

    # plt.get_cmap is used because matplotlib.cm.get_cmap was removed in Matplotlib 3.9.
    cmap = plt.get_cmap("tab10")
    colors = cmap(np.arange(n_classes) % cmap.N)
    # Seeded generator so the random marker sizes are reproducible between runs.
    rng = np.random.default_rng(42)

    for nc in range(n_classes):
        in_class = Y[:, nc] != 0
        ax.scatter(X[in_class, 0], X[in_class, 1], s=rng.integers(20, 200, endpoint=True),
                   edgecolors=colors[nc], facecolors="none", linewidths=2, label=f"Class {nc+1}")

        estimator = model.estimators_[nc]
        # A label that is never (or always) present gets a constant predictor without coef_.
        if hasattr(estimator, "coef_"):
            plot_hyperplane(estimator, min_x, max_x, "--", colors[nc], f"Boundary for class {nc+1}")

    ax.set_xticks(())
    ax.set_yticks(())

    # Pad each axis by a quarter of the data range; padding by a fraction of the
    # maximum would shrink or clip the view whenever that maximum is <= 0.
    x_pad = 0.25 * (max_x - min_x) or 1.0
    y_pad = 0.25 * (max_y - min_y) or 1.0
    ax.set_xlim(min_x - x_pad, max_x + x_pad)
    ax.set_ylim(min_y - y_pad, max_y + y_pad)

    ax.legend()

    return fig




# Gradio UI: the input widgets feed multilabel_classification, and the app is
# launched when this module runs.
with gr.Blocks() as demo:

    gr.Markdown(""" 
    
    # Multilabel Classification

    This space is an implementation of the scikit-learn document [Multilabel Classification](https://scikit-learn.org/stable/auto_examples/miscellaneous/plot_multilabel.html#sphx-glr-auto-examples-miscellaneous-plot-multilabel-py).
    The objective of this space is to simulate a multi-label document classification problem, where the data is generated randomly. 
    
    """)

    n_samples = gr.Slider(100, 10_000, label="n_samples", info="the number of samples")
    n_classes = gr.Slider(2, 10, label="n_classes", info="the number of classes that data should have.", step=1)
    n_labels = gr.Slider(1, 10, label="n_labels", info="the average number of labels per instance", step=1)
    allow_unlabeled = gr.Checkbox(True, label="allow_unlabeled", info="If set to True some instances might not belong to any class.")
    # Preselect PCA so clicking "Compute" never sends an empty (None) choice.
    decomposition = gr.Dropdown(['PCA', 'CCA'], value="PCA", label="decomposition", info="the type of decomposition algorithm to use.")

    output = gr.Plot(label="Plot")

    compute_btn = gr.Button("Compute")
    compute_btn.click(
        fn=multilabel_classification,
        inputs=[n_samples, n_classes, n_labels, allow_unlabeled, decomposition],
        outputs=output,
        api_name="multilabel",
    )


demo.launch()