dominguesm committed on
Commit
3a4d722
1 Parent(s): 30ae784

Add parameter grid config

Browse files
Files changed (1) hide show
  1. app.py +85 -52
app.py CHANGED
@@ -34,15 +34,6 @@ CATEGORIES = [
34
  ]
35
 
36
 
37
- PARAMETER_GRID = {
38
- "vect__max_df": (0.2, 0.4, 0.6, 0.8, 1.0),
39
- "vect__min_df": (1, 3, 5, 10),
40
- "vect__ngram_range": ((1, 1), (1, 2)), # unigrams or bigrams
41
- "vect__norm": ("l1", "l2"),
42
- "clf__alpha": np.logspace(-6, 6, 13),
43
- }
44
-
45
-
46
  def shorten_param(param_name):
47
  """Remove components' prefixes in param_name."""
48
  if "__" in param_name:
@@ -50,7 +41,7 @@ def shorten_param(param_name):
50
  return param_name
51
 
52
 
53
- def train_model(categories):
54
  pipeline = Pipeline(
55
  [
56
  ("vect", TfidfVectorizer()),
@@ -58,6 +49,16 @@ def train_model(categories):
58
  ]
59
  )
60
 
 
 
 
 
 
 
 
 
 
 
61
  data_train = fetch_20newsgroups(
62
  subset="train",
63
  categories=categories,
@@ -83,7 +84,7 @@ def train_model(categories):
83
 
84
  random_search = RandomizedSearchCV(
85
  estimator=pipeline,
86
- param_distributions=PARAMETER_GRID,
87
  n_iter=40,
88
  random_state=0,
89
  n_jobs=2,
@@ -103,7 +104,7 @@ def train_model(categories):
103
  cv_results = pd.DataFrame(random_search.cv_results_)
104
  cv_results = cv_results.rename(shorten_param, axis=1)
105
 
106
- param_names = [shorten_param(name) for name in PARAMETER_GRID.keys()]
107
  labels = {
108
  "mean_score_time": "CV Score time (s)",
109
  "mean_test_score": "CV score (accuracy)",
@@ -156,28 +157,10 @@ def train_model(categories):
156
  return fig, fig2, best_parameters, test_accuracy
157
 
158
 
159
- DESCRIPTION_PART1 = [
160
- "The dataset used in this example is",
161
- "[The 20 newsgroups text dataset](https://scikit-learn.org/stable/datasets/real_world.html#newsgroups-dataset)",
162
- "which will be automatically downloaded, cached and reused for the document classification example.",
163
- ]
164
 
165
- DESCRIPTION_PART2 = [
166
- "In this example, we tune the hyperparameters of",
167
- "a particular classifier using a",
168
- "[RandomizedSearchCV](https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.RandomizedSearchCV.html#sklearn.model_selection.RandomizedSearchCV).",
169
- "For a demo on the performance of some other classifiers, see the",
170
- "[Classification of text documents using sparse features](https://scikit-learn.org/stable/auto_examples/text/plot_document_classification_20newsgroups.html#sphx-glr-auto-examples-text-plot-document-classification-20newsgroups-py) notebook.",
171
- ]
172
-
173
- CATEGORY_SELECTION_DESCRIPTION = [
174
- "The task of text classification is easier when there is little overlap between the characteristic terms ",
175
- "of different topics. This is because the presence of common terms can make it difficult to distinguish between ",
176
- "different topics. On the other hand, when there is little overlap between the characteristic terms of different ",
177
- "topics, the task of text classification becomes easier, as the unique terms of each topic provide a solid basis ",
178
- "for accurately classifying the document into its respective category. Therefore, careful selection of characteristic",
179
- " terms for each topic is crucial to ensure accuracy in text classification."
180
- ]
181
 
182
  AUTHOR = """
183
  Created by [@dominguesm](https://huggingface.co/dominguesm) based on [scikit-learn docs](https://scikit-learn.org/stable/auto_examples/model_selection/plot_grid_search_text_feature_extraction.html)
@@ -188,14 +171,14 @@ with gr.Blocks(theme=gr.themes.Soft()) as app:
188
  with gr.Row():
189
  with gr.Column():
190
  gr.Markdown("# Sample pipeline for text feature extraction and evaluation")
191
- gr.Markdown(" ".join(DESCRIPTION_PART1))
192
- gr.Markdown(" ".join(DESCRIPTION_PART2))
193
  gr.Markdown(AUTHOR)
194
 
195
  with gr.Row():
196
  with gr.Column():
197
  gr.Markdown("""## CATEGORY SELECTION""")
198
- gr.Markdown("".join(CATEGORY_SELECTION_DESCRIPTION))
199
  drop_categories = gr.Dropdown(
200
  CATEGORIES,
201
  value=["alt.atheism", "talk.religion.misc"],
@@ -207,20 +190,70 @@ with gr.Blocks(theme=gr.themes.Soft()) as app:
207
  )
208
  with gr.Row():
209
  with gr.Column():
210
- gr.Markdown(
211
- """
212
- ## PARAMETERS GRID
213
- ```python
214
- {
215
- 'clf__alpha': array(
216
- [1.e-06, 1.e-05, 1.e-04,...]
217
- ),
218
- 'vect__max_df': (0.2, 0.4, 0.6, 0.8, 1.0),
219
- 'vect__min_df': (1, 3, 5, 10),
220
- 'vect__ngram_range': ((1, 1), (1, 2)),
221
- 'vect__norm': ('l1', 'l2')
222
- }
223
- ```
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
224
  ## MODEL PIPELINE
225
  ```python
226
  pipeline = Pipeline(
@@ -231,7 +264,7 @@ with gr.Blocks(theme=gr.themes.Soft()) as app:
231
  )
232
  ```
233
  """
234
- )
235
  with gr.Row():
236
  with gr.Column():
237
  gr.Markdown("""## TRAINING""")
@@ -248,7 +281,7 @@ with gr.Blocks(theme=gr.themes.Soft()) as app:
248
 
249
  brn_train.click(
250
  train_model,
251
- [drop_categories],
252
  [plot_trade, plot_coordinates, best_parameters, test_accuracy],
253
  )
254
 
 
34
  ]
35
 
36
 
 
 
 
 
 
 
 
 
 
37
  def shorten_param(param_name):
38
  """Remove components' prefixes in param_name."""
39
  if "__" in param_name:
 
41
  return param_name
42
 
43
 
44
+ def train_model(categories, vect__max_df, vect__min_df, vect__ngram_range, vect__norm):
45
  pipeline = Pipeline(
46
  [
47
  ("vect", TfidfVectorizer()),
 
49
  ]
50
  )
51
 
52
+ parameters_grid = {
53
+ "vect__max_df": [eval(value) for value in vect__max_df.split(",")],
54
+ "vect__min_df": [eval(value) for value in vect__min_df.split(",")],
55
+ "vect__ngram_range": eval(vect__ngram_range), # unigrams or bigrams
56
+ "vect__norm": [value.strip() for value in vect__norm.split(",")],
57
+ "clf__alpha": np.logspace(-6, 6, 13),
58
+ }
59
+
60
+ print(parameters_grid)
61
+
62
  data_train = fetch_20newsgroups(
63
  subset="train",
64
  categories=categories,
 
84
 
85
  random_search = RandomizedSearchCV(
86
  estimator=pipeline,
87
+ param_distributions=parameters_grid,
88
  n_iter=40,
89
  random_state=0,
90
  n_jobs=2,
 
104
  cv_results = pd.DataFrame(random_search.cv_results_)
105
  cv_results = cv_results.rename(shorten_param, axis=1)
106
 
107
+ param_names = [shorten_param(name) for name in parameters_grid.keys()]
108
  labels = {
109
  "mean_score_time": "CV Score time (s)",
110
  "mean_test_score": "CV score (accuracy)",
 
157
  return fig, fig2, best_parameters, test_accuracy
158
 
159
 
160
+ def load_description(name):
161
+ with open(f"./descriptions/{name}.md", "r") as f:
162
+ return f.read()
 
 
163
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
164
 
165
  AUTHOR = """
166
  Created by [@dominguesm](https://huggingface.co/dominguesm) based on [scikit-learn docs](https://scikit-learn.org/stable/auto_examples/model_selection/plot_grid_search_text_feature_extraction.html)
 
171
  with gr.Row():
172
  with gr.Column():
173
  gr.Markdown("# Sample pipeline for text feature extraction and evaluation")
174
+ gr.Markdown(load_description("description_part1"))
175
+ gr.Markdown(load_description("description_part2"))
176
  gr.Markdown(AUTHOR)
177
 
178
  with gr.Row():
179
  with gr.Column():
180
  gr.Markdown("""## CATEGORY SELECTION""")
181
+ gr.Markdown(load_description("description_category_selection"))
182
  drop_categories = gr.Dropdown(
183
  CATEGORIES,
184
  value=["alt.atheism", "talk.religion.misc"],
 
190
  )
191
  with gr.Row():
192
  with gr.Column():
193
+ gr.Markdown("""## PARAMETERS GRID""")
194
+ gr.Markdown(load_description("description_parameter_grid"))
195
+ with gr.Column():
196
+ gr.Markdown("""### Classifier Alpha""")
197
+ gr.Markdown(load_description("parameter_grid/alpha"))
198
+
199
+ clf__alpha = gr.Textbox(
200
+ label="clf__alpha",
201
+ value="1.e-06, 1.e-05, 1.e-04",
202
+ info="Due to practical considerations, this parameter was kept constant.",
203
+ interactive=False,
204
+ )
205
+
206
+ with gr.Column():
207
+ gr.Markdown("""### Vectorizer max_df""")
208
+ gr.Markdown(load_description("parameter_grid/max_df"))
209
+
210
+ vect__max_df = gr.Textbox(
211
+ label="vect__max_df",
212
+ value="0.2, 0.4, 0.6, 0.8, 1.0",
213
+ info="Values ranging from 0 to 1.0, separated by a comma.",
214
+ interactive=True,
215
+ )
216
+
217
+ with gr.Column():
218
+ gr.Markdown("""### Vectorizer min_df""")
219
+ gr.Markdown(load_description("parameter_grid/min_df"))
220
+
221
+ vect__min_df = gr.Textbox(
222
+ label="vect__min_df",
223
+ value="1, 3, 5, 10",
224
+ info="Values ranging from 0 to 1.0, separated by a comma, or integers separated by a comma. If float, the parameter represents a proportion of documents, integer absolute counts.",
225
+ interactive=True,
226
+ )
227
+ with gr.Column():
228
+ gr.Markdown("""### Vectorizer ngram_range""")
229
+ gr.Markdown(load_description("parameter_grid/ngram_range"))
230
+
231
+ vect__ngram_range = gr.Textbox(
232
+ label="vect__ngram_range",
233
+ value="(1, 1), (1, 2)",
234
+ info="""Tuples of integer values separated by a comma. For example an ``ngram_range`` of ``(1, 1)`` means only unigrams, ``(1, 2)`` means unigrams and bigrams, and ``(2, 2)`` means only bigrams.""",
235
+ interactive=True,
236
+ )
237
+ with gr.Column():
238
+ gr.Markdown("""### Vectorizer norm""")
239
+ gr.Markdown(load_description("parameter_grid/norm"))
240
+ gr.Markdown(
241
+ """- 'l2': Sum of squares of vector elements is 1. The cosine
242
+ similarity between two vectors is their dot product when l2 norm has
243
+ been applied.
244
+ - 'l1': Sum of absolute values of vector elements is 1."""
245
+ )
246
+
247
+ vect__norm = gr.Textbox(
248
+ label="vect__norm",
249
+ value="l1, l2",
250
+ info="'l1' or 'l2', separated by a comma",
251
+ interactive=True,
252
+ )
253
+
254
+ with gr.Row():
255
+ gr.Markdown(
256
+ """
257
  ## MODEL PIPELINE
258
  ```python
259
  pipeline = Pipeline(
 
264
  )
265
  ```
266
  """
267
+ )
268
  with gr.Row():
269
  with gr.Column():
270
  gr.Markdown("""## TRAINING""")
 
281
 
282
  brn_train.click(
283
  train_model,
284
+ [drop_categories, vect__max_df, vect__min_df, vect__ngram_range, vect__norm],
285
  [plot_trade, plot_coordinates, best_parameters, test_accuracy],
286
  )
287