dominguesm commited on
Commit
d79693f
1 Parent(s): 2976bfe

App gradio

Browse files
Files changed (2) hide show
  1. app.py +240 -0
  2. requirements.txt +5 -0
app.py ADDED
@@ -0,0 +1,240 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+
3
+ import gradio as gr
4
+ import numpy as np
5
+ import pandas as pd
6
+ import plotly.express as px
7
+ from sklearn.datasets import fetch_20newsgroups
8
+ from sklearn.feature_extraction.text import TfidfVectorizer
9
+ from sklearn.model_selection import RandomizedSearchCV
10
+ from sklearn.naive_bayes import ComplementNB
11
+ from sklearn.pipeline import Pipeline
12
+
13
+ CATEGORIES = [
14
+ "alt.atheism",
15
+ "comp.graphics",
16
+ "comp.os.ms-windows.misc",
17
+ "comp.sys.ibm.pc.hardware",
18
+ "comp.sys.mac.hardware",
19
+ "comp.windows.x",
20
+ "misc.forsale",
21
+ "rec.autos",
22
+ "rec.motorcycles",
23
+ "rec.sport.baseball",
24
+ "rec.sport.hockey",
25
+ "sci.crypt",
26
+ "sci.electronics",
27
+ "sci.med",
28
+ "sci.space",
29
+ "soc.religion.christian",
30
+ "talk.politics.guns",
31
+ "talk.politics.mideast",
32
+ "talk.politics.misc",
33
+ "talk.religion.misc",
34
+ ]
35
+
36
+
37
+ PARAMETER_GRID = {
38
+ "vect__max_df": (0.2, 0.4, 0.6, 0.8, 1.0),
39
+ "vect__min_df": (1, 3, 5, 10),
40
+ "vect__ngram_range": ((1, 1), (1, 2)), # unigrams or bigrams
41
+ "vect__norm": ("l1", "l2"),
42
+ "clf__alpha": np.logspace(-6, 6, 13),
43
+ }
44
+
45
+
46
+ def shorten_param(param_name):
47
+ """Remove components' prefixes in param_name."""
48
+ if "__" in param_name:
49
+ return param_name.rsplit("__", 1)[1]
50
+ return param_name
51
+
52
+
53
+ def train_model(categories):
54
+ pipeline = Pipeline(
55
+ [
56
+ ("vect", TfidfVectorizer()),
57
+ ("clf", ComplementNB()),
58
+ ]
59
+ )
60
+
61
+ data_train = fetch_20newsgroups(
62
+ subset="train",
63
+ categories=categories,
64
+ shuffle=True,
65
+ random_state=42,
66
+ remove=("headers", "footers", "quotes"),
67
+ )
68
+
69
+ data_test = fetch_20newsgroups(
70
+ subset="test",
71
+ categories=categories,
72
+ shuffle=True,
73
+ random_state=42,
74
+ remove=("headers", "footers", "quotes"),
75
+ )
76
+
77
+ pipeline = Pipeline(
78
+ [
79
+ ("vect", TfidfVectorizer()),
80
+ ("clf", ComplementNB()),
81
+ ]
82
+ )
83
+
84
+ random_search = RandomizedSearchCV(
85
+ estimator=pipeline,
86
+ param_distributions=PARAMETER_GRID,
87
+ n_iter=40,
88
+ random_state=0,
89
+ n_jobs=2,
90
+ verbose=1,
91
+ )
92
+
93
+ random_search.fit(data_train.data, data_train.target)
94
+ best_parameters = random_search.best_estimator_.get_params()
95
+
96
+ test_accuracy = random_search.score(data_test.data, data_test.target)
97
+
98
+ cv_results = pd.DataFrame(random_search.cv_results_)
99
+ cv_results = cv_results.rename(shorten_param, axis=1)
100
+
101
+ param_names = [shorten_param(name) for name in PARAMETER_GRID.keys()]
102
+ labels = {
103
+ "mean_score_time": "CV Score time (s)",
104
+ "mean_test_score": "CV score (accuracy)",
105
+ }
106
+ fig = px.scatter(
107
+ cv_results,
108
+ x="mean_score_time",
109
+ y="mean_test_score",
110
+ error_x="std_score_time",
111
+ error_y="std_test_score",
112
+ hover_data=param_names,
113
+ labels=labels,
114
+ )
115
+ fig.update_layout(
116
+ title={
117
+ "text": "trade-off between scoring time and mean test score",
118
+ "y": 0.95,
119
+ "x": 0.5,
120
+ "xanchor": "center",
121
+ "yanchor": "top",
122
+ }
123
+ )
124
+
125
+ column_results = param_names + ["mean_test_score", "mean_score_time"]
126
+
127
+ transform_funcs = dict.fromkeys(column_results, lambda x: x)
128
+ # Using a logarithmic scale for alpha
129
+ transform_funcs["alpha"] = math.log10
130
+ # L1 norms are mapped to index 1, and L2 norms to index 2
131
+ transform_funcs["norm"] = lambda x: 2 if x == "l2" else 1
132
+ # Unigrams are mapped to index 1 and bigrams to index 2
133
+ transform_funcs["ngram_range"] = lambda x: x[1]
134
+
135
+ fig2 = px.parallel_coordinates(
136
+ cv_results[column_results].apply(transform_funcs),
137
+ color="mean_test_score",
138
+ color_continuous_scale=px.colors.sequential.Viridis_r,
139
+ labels=labels,
140
+ )
141
+ fig2.update_layout(
142
+ title={
143
+ "text": "Parallel coordinates plot of text classifier pipeline",
144
+ "y": 0.99,
145
+ "x": 0.5,
146
+ "xanchor": "center",
147
+ "yanchor": "top",
148
+ }
149
+ )
150
+
151
+ return fig, fig2, best_parameters, test_accuracy
152
+
153
+
154
+ DESCRIPTION_PART1 = [
155
+ "The dataset used in this example is",
156
+ "[The 20 newsgroups text dataset](https://scikit-learn.org/stable/datasets/real_world.html#newsgroups-dataset)",
157
+ "which will be automatically downloaded, cached and reused for the document classification example.",
158
+ ]
159
+
160
+ DESCRIPTION_PART2 = [
161
+ "In this example, we tune the hyperparameters of",
162
+ "a particular classifier using a",
163
+ "[RandomizedSearchCV](https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.RandomizedSearchCV.html#sklearn.model_selection.RandomizedSearchCV).",
164
+ "For a demo on the performance of some other classifiers, see the",
165
+ "[Classification of text documents using sparse features](https://scikit-learn.org/stable/auto_examples/text/plot_document_classification_20newsgroups.html#sphx-glr-auto-examples-text-plot-document-classification-20newsgroups-py) notebook.",
166
+ ]
167
+
168
+ AUTHOR = """
169
+ Created by [@dominguesm](https://huggingface.co/dominguesm) based on [scikit-learn docs](https://scikit-learn.org/stable/auto_examples/model_selection/plot_grid_search_text_feature_extraction.html)
170
+ """
171
+
172
+
173
+ with gr.Blocks(theme=gr.themes.Soft()) as app:
174
+ with gr.Row():
175
+ with gr.Column():
176
+ gr.Markdown("# Sample pipeline for text feature extraction and evaluation")
177
+ gr.Markdown(" ".join(DESCRIPTION_PART1))
178
+ gr.Markdown(" ".join(DESCRIPTION_PART2))
179
+ gr.Markdown(AUTHOR)
180
+
181
+ with gr.Row():
182
+ with gr.Column():
183
+ gr.Markdown("""## CATEGORY SELECTION""")
184
+ drop_categories = gr.Dropdown(
185
+ CATEGORIES,
186
+ value=["alt.atheism", "talk.religion.misc"],
187
+ multiselect=True,
188
+ label="Categories",
189
+ info="Select the categories you want to train on.",
190
+ max_choices=2,
191
+ interactive=True,
192
+ )
193
+ with gr.Row():
194
+ with gr.Column():
195
+ gr.Markdown(
196
+ """
197
+ ## PARAMETERS GRID
198
+ ```python
199
+ {
200
+ 'clf__alpha': array(
201
+ [1.e-06, 1.e-05, 1.e-04,...]
202
+ ),
203
+ 'vect__max_df': (0.2, 0.4, 0.6, 0.8, 1.0),
204
+ 'vect__min_df': (1, 3, 5, 10),
205
+ 'vect__ngram_range': ((1, 1), (1, 2)),
206
+ 'vect__norm': ('l1', 'l2')
207
+ }
208
+ ```
209
+ ## MODEL PIPELINE
210
+ ```python
211
+ pipeline = Pipeline(
212
+ [
213
+ ("vect", TfidfVectorizer()),
214
+ ("clf", ComplementNB()),
215
+ ]
216
+ )
217
+ ```
218
+ """
219
+ )
220
+ with gr.Row():
221
+ with gr.Column():
222
+ gr.Markdown("""## TRAINING""")
223
+ with gr.Row():
224
+ brn_train = gr.Button("Train").style(container=False)
225
+
226
+ gr.Markdown("## RESULTS")
227
+ with gr.Row():
228
+ best_parameters = gr.Textbox(label="Best parameters")
229
+ test_accuracy = gr.Textbox(label="Test accuracy")
230
+
231
+ plot_trade = gr.Plot(label="")
232
+ plot_coordinates = gr.Plot(label="")
233
+
234
+ brn_train.click(
235
+ train_model,
236
+ [drop_categories],
237
+ [plot_trade, plot_coordinates, best_parameters, test_accuracy],
238
+ )
239
+
240
+ app.launch()
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ numpy
2
+ scikit-learn
3
+ plotly
4
+ matplotlib
5
+ pandas