Commit
·
3a4d722
1
Parent(s):
30ae784
Add parameter grid config
Browse files
app.py
CHANGED
|
@@ -34,15 +34,6 @@ CATEGORIES = [
|
|
| 34 |
]
|
| 35 |
|
| 36 |
|
| 37 |
-
PARAMETER_GRID = {
|
| 38 |
-
"vect__max_df": (0.2, 0.4, 0.6, 0.8, 1.0),
|
| 39 |
-
"vect__min_df": (1, 3, 5, 10),
|
| 40 |
-
"vect__ngram_range": ((1, 1), (1, 2)), # unigrams or bigrams
|
| 41 |
-
"vect__norm": ("l1", "l2"),
|
| 42 |
-
"clf__alpha": np.logspace(-6, 6, 13),
|
| 43 |
-
}
|
| 44 |
-
|
| 45 |
-
|
| 46 |
def shorten_param(param_name):
|
| 47 |
"""Remove components' prefixes in param_name."""
|
| 48 |
if "__" in param_name:
|
|
@@ -50,7 +41,7 @@ def shorten_param(param_name):
|
|
| 50 |
return param_name
|
| 51 |
|
| 52 |
|
| 53 |
-
def train_model(categories):
|
| 54 |
pipeline = Pipeline(
|
| 55 |
[
|
| 56 |
("vect", TfidfVectorizer()),
|
|
@@ -58,6 +49,16 @@ def train_model(categories):
|
|
| 58 |
]
|
| 59 |
)
|
| 60 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
data_train = fetch_20newsgroups(
|
| 62 |
subset="train",
|
| 63 |
categories=categories,
|
|
@@ -83,7 +84,7 @@ def train_model(categories):
|
|
| 83 |
|
| 84 |
random_search = RandomizedSearchCV(
|
| 85 |
estimator=pipeline,
|
| 86 |
-
param_distributions=
|
| 87 |
n_iter=40,
|
| 88 |
random_state=0,
|
| 89 |
n_jobs=2,
|
|
@@ -103,7 +104,7 @@ def train_model(categories):
|
|
| 103 |
cv_results = pd.DataFrame(random_search.cv_results_)
|
| 104 |
cv_results = cv_results.rename(shorten_param, axis=1)
|
| 105 |
|
| 106 |
-
param_names = [shorten_param(name) for name in
|
| 107 |
labels = {
|
| 108 |
"mean_score_time": "CV Score time (s)",
|
| 109 |
"mean_test_score": "CV score (accuracy)",
|
|
@@ -156,28 +157,10 @@ def train_model(categories):
|
|
| 156 |
return fig, fig2, best_parameters, test_accuracy
|
| 157 |
|
| 158 |
|
| 159 |
-
|
| 160 |
-
"
|
| 161 |
-
|
| 162 |
-
"which will be automatically downloaded, cached and reused for the document classification example.",
|
| 163 |
-
]
|
| 164 |
|
| 165 |
-
DESCRIPTION_PART2 = [
|
| 166 |
-
"In this example, we tune the hyperparameters of",
|
| 167 |
-
"a particular classifier using a",
|
| 168 |
-
"[RandomizedSearchCV](https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.RandomizedSearchCV.html#sklearn.model_selection.RandomizedSearchCV).",
|
| 169 |
-
"For a demo on the performance of some other classifiers, see the",
|
| 170 |
-
"[Classification of text documents using sparse features](https://scikit-learn.org/stable/auto_examples/text/plot_document_classification_20newsgroups.html#sphx-glr-auto-examples-text-plot-document-classification-20newsgroups-py) notebook.",
|
| 171 |
-
]
|
| 172 |
-
|
| 173 |
-
CATEGORY_SELECTION_DESCRIPTION = [
|
| 174 |
-
"The task of text classification is easier when there is little overlap between the characteristic terms ",
|
| 175 |
-
"of different topics. This is because the presence of common terms can make it difficult to distinguish between ",
|
| 176 |
-
"different topics. On the other hand, when there is little overlap between the characteristic terms of different ",
|
| 177 |
-
"topics, the task of text classification becomes easier, as the unique terms of each topic provide a solid basis ",
|
| 178 |
-
"for accurately classifying the document into its respective category. Therefore, careful selection of characteristic",
|
| 179 |
-
" terms for each topic is crucial to ensure accuracy in text classification."
|
| 180 |
-
]
|
| 181 |
|
| 182 |
AUTHOR = """
|
| 183 |
Created by [@dominguesm](https://huggingface.co/dominguesm) based on [scikit-learn docs](https://scikit-learn.org/stable/auto_examples/model_selection/plot_grid_search_text_feature_extraction.html)
|
|
@@ -188,14 +171,14 @@ with gr.Blocks(theme=gr.themes.Soft()) as app:
|
|
| 188 |
with gr.Row():
|
| 189 |
with gr.Column():
|
| 190 |
gr.Markdown("# Sample pipeline for text feature extraction and evaluation")
|
| 191 |
-
gr.Markdown("
|
| 192 |
-
gr.Markdown("
|
| 193 |
gr.Markdown(AUTHOR)
|
| 194 |
|
| 195 |
with gr.Row():
|
| 196 |
with gr.Column():
|
| 197 |
gr.Markdown("""## CATEGORY SELECTION""")
|
| 198 |
-
gr.Markdown(""
|
| 199 |
drop_categories = gr.Dropdown(
|
| 200 |
CATEGORIES,
|
| 201 |
value=["alt.atheism", "talk.religion.misc"],
|
|
@@ -207,20 +190,70 @@ with gr.Blocks(theme=gr.themes.Soft()) as app:
|
|
| 207 |
)
|
| 208 |
with gr.Row():
|
| 209 |
with gr.Column():
|
| 210 |
-
gr.Markdown(
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 224 |
## MODEL PIPELINE
|
| 225 |
```python
|
| 226 |
pipeline = Pipeline(
|
|
@@ -231,7 +264,7 @@ with gr.Blocks(theme=gr.themes.Soft()) as app:
|
|
| 231 |
)
|
| 232 |
```
|
| 233 |
"""
|
| 234 |
-
|
| 235 |
with gr.Row():
|
| 236 |
with gr.Column():
|
| 237 |
gr.Markdown("""## TRAINING""")
|
|
@@ -248,7 +281,7 @@ with gr.Blocks(theme=gr.themes.Soft()) as app:
|
|
| 248 |
|
| 249 |
brn_train.click(
|
| 250 |
train_model,
|
| 251 |
-
[drop_categories],
|
| 252 |
[plot_trade, plot_coordinates, best_parameters, test_accuracy],
|
| 253 |
)
|
| 254 |
|
|
|
|
| 34 |
]
|
| 35 |
|
| 36 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
def shorten_param(param_name):
|
| 38 |
"""Remove components' prefixes in param_name."""
|
| 39 |
if "__" in param_name:
|
|
|
|
| 41 |
return param_name
|
| 42 |
|
| 43 |
|
| 44 |
+
def train_model(categories, vect__max_df, vect__min_df, vect__ngram_range, vect__norm):
|
| 45 |
pipeline = Pipeline(
|
| 46 |
[
|
| 47 |
("vect", TfidfVectorizer()),
|
|
|
|
| 49 |
]
|
| 50 |
)
|
| 51 |
|
| 52 |
+
# Build the hyper-parameter search space from the comma-separated strings
# typed into the UI textboxes. That text is user-controlled, so it is parsed
# with ast.literal_eval (Python literals only) rather than eval, which would
# execute any expression a visitor submits.
import ast

# "(1, 1), (1, 2)" parses to a tuple of tuples. A lone "(1, 2)" parses to a
# flat tuple of ints, which the search would read as the two candidates 1 and
# 2; wrap it so it stays a single (min_n, max_n) candidate.
ngram_ranges = ast.literal_eval(vect__ngram_range.strip())
if (
    isinstance(ngram_ranges, tuple)
    and ngram_ranges
    and all(isinstance(bound, int) for bound in ngram_ranges)
):
    ngram_ranges = [ngram_ranges]

parameters_grid = {
    # Each entry is stripped first: literal_eval rejects leading whitespace
    # on older Python versions, whereas eval silently tolerated it.
    "vect__max_df": [
        ast.literal_eval(value.strip()) for value in vect__max_df.split(",")
    ],
    "vect__min_df": [
        ast.literal_eval(value.strip()) for value in vect__min_df.split(",")
    ],
    "vect__ngram_range": ngram_ranges,  # unigrams or bigrams
    "vect__norm": [value.strip() for value in vect__norm.split(",")],
    # Fixed grid; not taken from the UI.
    "clf__alpha": np.logspace(-6, 6, 13),
}
|
| 59 |
+
|
| 60 |
+
print(parameters_grid)
|
| 61 |
+
|
| 62 |
data_train = fetch_20newsgroups(
|
| 63 |
subset="train",
|
| 64 |
categories=categories,
|
|
|
|
| 84 |
|
| 85 |
random_search = RandomizedSearchCV(
|
| 86 |
estimator=pipeline,
|
| 87 |
+
param_distributions=parameters_grid,
|
| 88 |
n_iter=40,
|
| 89 |
random_state=0,
|
| 90 |
n_jobs=2,
|
|
|
|
| 104 |
cv_results = pd.DataFrame(random_search.cv_results_)
|
| 105 |
cv_results = cv_results.rename(shorten_param, axis=1)
|
| 106 |
|
| 107 |
+
param_names = [shorten_param(name) for name in parameters_grid.keys()]
|
| 108 |
labels = {
|
| 109 |
"mean_score_time": "CV Score time (s)",
|
| 110 |
"mean_test_score": "CV score (accuracy)",
|
|
|
|
| 157 |
return fig, fig2, best_parameters, test_accuracy
|
| 158 |
|
| 159 |
|
| 160 |
+
def load_description(name):
    """Return the Markdown text stored in ``./descriptions/<name>.md``.

    Args:
        name: File name without the ``.md`` suffix. It may contain a
            sub-directory component, e.g. ``"parameter_grid/alpha"``.

    Returns:
        The complete file contents as a string.

    Raises:
        OSError: If the file cannot be opened (for example, it is missing).
    """
    # The path is resolved against the current working directory. Decode
    # explicitly as UTF-8 so the result does not depend on the platform's
    # default locale encoding.
    with open(f"./descriptions/{name}.md", "r", encoding="utf-8") as f:
        return f.read()
|
|
|
|
|
|
|
| 163 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 164 |
|
| 165 |
AUTHOR = """
|
| 166 |
Created by [@dominguesm](https://huggingface.co/dominguesm) based on [scikit-learn docs](https://scikit-learn.org/stable/auto_examples/model_selection/plot_grid_search_text_feature_extraction.html)
|
|
|
|
| 171 |
with gr.Row():
|
| 172 |
with gr.Column():
|
| 173 |
gr.Markdown("# Sample pipeline for text feature extraction and evaluation")
|
| 174 |
+
gr.Markdown(load_description("description_part1"))
|
| 175 |
+
gr.Markdown(load_description("description_part2"))
|
| 176 |
gr.Markdown(AUTHOR)
|
| 177 |
|
| 178 |
with gr.Row():
|
| 179 |
with gr.Column():
|
| 180 |
gr.Markdown("""## CATEGORY SELECTION""")
|
| 181 |
+
gr.Markdown(load_description("description_category_selection"))
|
| 182 |
drop_categories = gr.Dropdown(
|
| 183 |
CATEGORIES,
|
| 184 |
value=["alt.atheism", "talk.religion.misc"],
|
|
|
|
| 190 |
)
|
| 191 |
with gr.Row():
|
| 192 |
with gr.Column():
|
| 193 |
+
gr.Markdown("""## PARAMETERS GRID""")
|
| 194 |
+
gr.Markdown(load_description("description_parameter_grid"))
|
| 195 |
+
with gr.Column():
|
| 196 |
+
gr.Markdown("""### Classifier Alpha""")
|
| 197 |
+
gr.Markdown(load_description("parameter_grid/alpha"))
|
| 198 |
+
|
| 199 |
+
clf__alpha = gr.Textbox(
|
| 200 |
+
label="clf__alpha",
|
| 201 |
+
value="1.e-06, 1.e-05, 1.e-04",
|
| 202 |
+
info="Due to practical considerations, this parameter was kept constant.",
|
| 203 |
+
interactive=False,
|
| 204 |
+
)
|
| 205 |
+
|
| 206 |
+
with gr.Column():
|
| 207 |
+
gr.Markdown("""### Vectorizer max_df""")
|
| 208 |
+
gr.Markdown(load_description("parameter_grid/max_df"))
|
| 209 |
+
|
| 210 |
+
vect__max_df = gr.Textbox(
|
| 211 |
+
label="vect__max_df",
|
| 212 |
+
value="0.2, 0.4, 0.6, 0.8, 1.0",
|
| 213 |
+
info="Values ranging from 0 to 1.0, separated by a comma.",
|
| 214 |
+
interactive=True,
|
| 215 |
+
)
|
| 216 |
+
|
| 217 |
+
with gr.Column():
|
| 218 |
+
gr.Markdown("""### Vectorizer min_df""")
|
| 219 |
+
gr.Markdown(load_description("parameter_grid/min_df"))
|
| 220 |
+
|
| 221 |
+
vect__min_df = gr.Textbox(
|
| 222 |
+
label="vect__min_df",
|
| 223 |
+
value="1, 3, 5, 10",
|
| 224 |
+
info="Values ranging from 0 to 1.0, separated by a comma, or integers separated by a comma. If float, the parameter represents a proportion of documents, integer absolute counts.",
|
| 225 |
+
interactive=True,
|
| 226 |
+
)
|
| 227 |
+
with gr.Column():
|
| 228 |
+
gr.Markdown("""### Vectorizer ngram_range""")
|
| 229 |
+
gr.Markdown(load_description("parameter_grid/ngram_range"))
|
| 230 |
+
|
| 231 |
+
vect__ngram_range = gr.Textbox(
|
| 232 |
+
label="vect__ngram_range",
|
| 233 |
+
value="(1, 1), (1, 2)",
|
| 234 |
+
info="""Tuples of integer values separated by a comma. For example an ``ngram_range`` of ``(1, 1)`` means only unigrams, ``(1, 2)`` means unigrams and bigrams, and ``(2, 2)`` means only bigrams.""",
|
| 235 |
+
interactive=True,
|
| 236 |
+
)
|
| 237 |
+
with gr.Column():
|
| 238 |
+
gr.Markdown("""### Vectorizer norm""")
|
| 239 |
+
gr.Markdown(load_description("parameter_grid/norm"))
|
| 240 |
+
gr.Markdown(
|
| 241 |
+
"""- 'l2': Sum of squares of vector elements is 1. The cosine
|
| 242 |
+
similarity between two vectors is their dot product when l2 norm has
|
| 243 |
+
been applied.
|
| 244 |
+
- 'l1': Sum of absolute values of vector elements is 1."""
|
| 245 |
+
)
|
| 246 |
+
|
| 247 |
+
vect__norm = gr.Textbox(
|
| 248 |
+
label="vect__norm",
|
| 249 |
+
value="l1, l2",
|
| 250 |
+
info="'l1' or 'l2', separated by a comma",
|
| 251 |
+
interactive=True,
|
| 252 |
+
)
|
| 253 |
+
|
| 254 |
+
with gr.Row():
|
| 255 |
+
gr.Markdown(
|
| 256 |
+
"""
|
| 257 |
## MODEL PIPELINE
|
| 258 |
```python
|
| 259 |
pipeline = Pipeline(
|
|
|
|
| 264 |
)
|
| 265 |
```
|
| 266 |
"""
|
| 267 |
+
)
|
| 268 |
with gr.Row():
|
| 269 |
with gr.Column():
|
| 270 |
gr.Markdown("""## TRAINING""")
|
|
|
|
| 281 |
|
| 282 |
brn_train.click(
|
| 283 |
train_model,
|
| 284 |
+
[drop_categories, vect__max_df, vect__min_df, vect__ngram_range, vect__norm],
|
| 285 |
[plot_trade, plot_coordinates, best_parameters, test_accuracy],
|
| 286 |
)
|
| 287 |
|