caliex commited on
Commit
353cba6
1 Parent(s): 28f35c4

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +99 -0
app.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ from matplotlib import pyplot as plt
4
+ from sklearn.datasets import make_hastie_10_2
5
+ from sklearn.model_selection import GridSearchCV
6
+ from sklearn.metrics import make_scorer, accuracy_score
7
+ from sklearn.tree import DecisionTreeClassifier
8
+ import tempfile
9
+
10
+
11
+ def grid_search(min_samples_split, max_depth):
12
+ X, y = make_hastie_10_2(n_samples=8000, random_state=42)
13
+ scoring = {"AUC": "roc_auc", "Accuracy": make_scorer(accuracy_score)}
14
+
15
+ gs = GridSearchCV(
16
+ DecisionTreeClassifier(random_state=42),
17
+ param_grid={"min_samples_split": range(min_samples_split, max_depth + 1, 20)},
18
+ scoring=scoring,
19
+ refit="AUC",
20
+ n_jobs=2,
21
+ return_train_score=True,
22
+ )
23
+ gs.fit(X, y)
24
+ results = gs.cv_results_
25
+
26
+ plt.figure(figsize=(13, 13))
27
+ plt.title("GridSearchCV evaluating using multiple scorers simultaneously", fontsize=16)
28
+ plt.xlabel("min_samples_split")
29
+ plt.ylabel("Score")
30
+ ax = plt.gca()
31
+ ax.set_xlim(min_samples_split, max_depth)
32
+ ax.set_ylim(0.73, 1)
33
+
34
+ X_axis = np.array(results["param_min_samples_split"].data, dtype=float)
35
+
36
+ for scorer, color in zip(sorted(scoring), ["g", "k"]):
37
+ for sample, style in (("train", "--"), ("test", "-")):
38
+ sample_score_mean = results["mean_%s_%s" % (sample, scorer)]
39
+ sample_score_std = results["std_%s_%s" % (sample, scorer)]
40
+ ax.fill_between(
41
+ X_axis,
42
+ sample_score_mean - sample_score_std,
43
+ sample_score_mean + sample_score_std,
44
+ alpha=0.1 if sample == "test" else 0,
45
+ color=color,
46
+ )
47
+ ax.plot(
48
+ X_axis,
49
+ sample_score_mean,
50
+ style,
51
+ color=color,
52
+ alpha=1 if sample == "test" else 0.7,
53
+ label="%s (%s)" % (scorer, sample),
54
+ )
55
+
56
+ best_index = np.nonzero(results["rank_test_%s" % scorer] == 1)[0][0]
57
+ best_score = results["mean_test_%s" % scorer][best_index]
58
+
59
+ ax.plot(
60
+ [
61
+ X_axis[best_index],
62
+ ]
63
+ * 2,
64
+ [0, best_score],
65
+ linestyle="-.",
66
+ color=color,
67
+ marker="x",
68
+ markeredgewidth=3,
69
+ ms=8,
70
+ )
71
+
72
+ ax.annotate("%0.2f" % best_score, (X_axis[best_index], best_score + 0.005))
73
+
74
+ plt.legend(loc="best")
75
+ plt.grid(False)
76
+
77
+ # Save the plot as an image file
78
+ with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_file:
79
+ temp_filename = temp_file.name
80
+ plt.savefig(temp_filename)
81
+
82
+ # Return the path to the image file
83
+ return temp_filename
84
+
85
+
86
+ min_samples_split_input = gr.inputs.Slider(minimum=2, maximum=402, default=2, step=20, label="min_samples_split")
87
+ max_depth_input = gr.inputs.Slider(minimum=2, maximum=402, default=402, step=20, label="max_depth")
88
+ outputs = gr.outputs.Image(type="pil", label="Score Plot")
89
+
90
+ title = "Multi-Metric Evaluation on Cross_Val_Score and GridSearchCV"
91
+ description = "This app allows users to explore the performance of a Decision Tree Classifier by adjusting the parameters 'min_samples_split' and 'max_depth'. The app performs a grid search and evaluates the classifier using multiple scoring metrics. The resulting score plot provides insights into the impact of parameter variations on model performance. Users can interactively modify the parameter values using sliders and observe the corresponding changes in the score plot. See the original scikit-learn example here: https://scikit-learn.org/stable/auto_examples/model_selection/plot_multi_metric_evaluation.html"
92
+ examples = [
93
+ [42, 402],
94
+ [130, 340],
95
+ [88, 240],
96
+ ]
97
+
98
+ gr.Interface(fn=grid_search, inputs=[min_samples_split_input, max_depth_input], outputs=outputs,
99
+ title=title, description=description, examples=examples).launch()