caliex commited on
Commit
59b499d
1 Parent(s): d28f790

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +96 -0
app.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import matplotlib.pyplot as plt
3
+ from sklearn.datasets import load_iris
4
+ from sklearn.svm import SVC
5
+ from sklearn.model_selection import StratifiedKFold, permutation_test_score
6
+ import numpy as np
7
+ import tempfile
8
+ import os
9
+
10
+ def run_permutation_test(display_option, kernel, random_state, n_permutations):
11
+ iris = load_iris()
12
+ X = iris.data
13
+ y = iris.target
14
+
15
+ n_uncorrelated_features = 20
16
+ rng = np.random.RandomState(seed=0)
17
+ X_rand = rng.normal(size=(X.shape[0], n_uncorrelated_features))
18
+
19
+ clf = SVC(kernel=kernel, random_state=random_state)
20
+ cv = StratifiedKFold(2, shuffle=True, random_state=0)
21
+
22
+ score_iris, perm_scores_iris, pvalue_iris = permutation_test_score(
23
+ clf, X, y, scoring="accuracy", cv=cv, n_permutations=n_permutations
24
+ )
25
+
26
+ score_rand, perm_scores_rand, pvalue_rand = permutation_test_score(
27
+ clf, X_rand, y, scoring="accuracy", cv=cv, n_permutations=n_permutations
28
+ )
29
+
30
+ original_plot_path = None
31
+ random_plot_path = None
32
+
33
+ if display_option in ['original', 'both']:
34
+ # Original data
35
+ fig, ax = plt.subplots()
36
+ ax.hist(perm_scores_iris, bins=20, density=True)
37
+ ax.axvline(score_iris, ls="--", color="r")
38
+ score_label = f"Score on original\ndata: {score_iris:.2f}\n(p-value: {pvalue_iris:.3f})"
39
+ ax.text(0.7, 10, score_label, fontsize=12)
40
+ ax.set_xlabel("Accuracy score")
41
+ ax.set_ylabel("Probability")
42
+ original_plot_path = os.path.join(tempfile.mkdtemp(), "original_plot.png")
43
+ plt.savefig(original_plot_path)
44
+ plt.close()
45
+
46
+ if display_option in ['random', 'both']:
47
+ # Random data
48
+ fig, ax = plt.subplots()
49
+ ax.hist(perm_scores_rand, bins=20, density=True)
50
+ ax.set_xlim(0.13)
51
+ ax.axvline(score_rand, ls="--", color="r")
52
+ score_label = f"Score on original\ndata: {score_rand:.2f}\n(p-value: {pvalue_rand:.3f})"
53
+ ax.text(0.14, 7.5, score_label, fontsize=12)
54
+ ax.set_xlabel("Accuracy score")
55
+ ax.set_ylabel("Probability")
56
+ random_plot_path = os.path.join(tempfile.mkdtemp(), "random_plot.png")
57
+ plt.savefig(random_plot_path)
58
+ plt.close()
59
+
60
+ return original_plot_path, random_plot_path
61
+
62
+ iface = gr.Interface(
63
+ fn=run_permutation_test,
64
+ inputs=[
65
+ gr.inputs.Dropdown(
66
+ choices=["original", "random", "both"],
67
+ label="Display Option",
68
+ default="both"
69
+ ),
70
+ gr.inputs.Dropdown(
71
+ choices=["linear", "rbf", "poly"],
72
+ label="Kernel",
73
+ default="linear"
74
+ ),
75
+ gr.inputs.Slider(
76
+ minimum=0, maximum=10, step=1,
77
+ label="Random State",
78
+ default=7
79
+ ),
80
+ gr.inputs.Slider(
81
+ minimum=100, maximum=2000, step=100,
82
+ label="Number of Permutations",
83
+ default=1000
84
+ )
85
+ ],
86
+ outputs=["image", "image"],
87
+ title="Test with permutations the significance of a classification score",
88
+ description="This example demonstrates the use of permutation_test_score to evaluate the significance of a cross-validated score using permutations. This operation is being performed on the Iris Dataset. See the original scikit-learn example here: https://scikit-learn.org/stable/auto_examples/model_selection/plot_permutation_tests_for_classification.html",
89
+ examples=[
90
+ ["both", "linear", 7, 1000],
91
+ ["original", "rbf", 3, 500],
92
+ ["random", "poly", 5, 1500]
93
+ ],
94
+ allow_flagging=False
95
+ )
96
+ iface.launch()