|
import ast
import json
import math
import os
import subprocess
import sys

import gradio as gr
import numpy as np
import pandas as pd
import plotly.express as px
from sklearn.datasets import fetch_20newsgroups
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.model_selection import RandomizedSearchCV
from sklearn.naive_bayes import ComplementNB
from sklearn.pipeline import Pipeline
|
|
|
# Pin Gradio to 3.26.0, presumably because the UI below relies on 3.x-only APIs
# such as ``Button.style``.
# NOTE(review): this runs after ``import gradio`` above, so it cannot replace the
# module already loaded into this process; it only corrects the environment for
# later launches. Pinning ``gradio==3.26.0`` in requirements.txt is the robust fix.
_GRADIO_VERSION = "3.26.0"
if gr.__version__ != _GRADIO_VERSION:
    # Run pip through the current interpreter (no shell) so the packages land in
    # the same environment. Failures are ignored, as with the former os.system calls.
    subprocess.run(
        [sys.executable, "-m", "pip", "uninstall", "-y", "gradio"],
        check=False,
    )
    subprocess.run(
        [sys.executable, "-m", "pip", "install", f"gradio=={_GRADIO_VERSION}"],
        check=False,
    )
|
|
|
|
|
# All twenty 20-newsgroups topic names, in alphabetical order; these are the
# choices offered by the category dropdown.
CATEGORIES = """
    alt.atheism
    comp.graphics
    comp.os.ms-windows.misc
    comp.sys.ibm.pc.hardware
    comp.sys.mac.hardware
    comp.windows.x
    misc.forsale
    rec.autos
    rec.motorcycles
    rec.sport.baseball
    rec.sport.hockey
    sci.crypt
    sci.electronics
    sci.med
    sci.space
    soc.religion.christian
    talk.politics.guns
    talk.politics.mideast
    talk.politics.misc
    talk.religion.misc
""".split()
|
|
|
|
|
def shorten_param(param_name):
    """Strip pipeline-step prefixes from a parameter or column name.

    Everything up to and including the last ``"__"`` is dropped, so
    ``"param_vect__max_df"`` becomes ``"max_df"``. Names that contain no
    ``"__"`` are returned unchanged.
    """
    # rpartition yields ("", "", name) when the separator is absent.
    _, _, short_name = param_name.rpartition("__")
    return short_name
|
|
|
|
|
# Axis labels shared by both result figures produced by train_model.
_PLOT_LABELS = {
    "mean_score_time": "CV Score time (s)",
    "mean_test_score": "CV score (accuracy)",
}


def _parse_literal_list(text):
    """Parse a comma-separated string of Python literals, e.g. ``"0.2, 1.0"``.

    ``ast.literal_eval`` only accepts literals. ``eval`` would execute arbitrary
    code typed into the public web form. Each item is stripped first, because
    ``literal_eval`` rejects leading whitespace before Python 3.10. Blank items
    (for example from a trailing comma) are skipped.
    """
    items = (item.strip() for item in text.split(","))
    return [ast.literal_eval(item) for item in items if item]


def _parse_ngram_ranges(text):
    """Parse ngram ranges such as ``"(1, 1), (1, 2)"`` into ``[(1, 1), (1, 2)]``.

    A single range like ``"(1, 2)"`` is wrapped in a list. Without that, the
    search would sample the bare integers 1 and 2 as ngram ranges.
    """
    value = ast.literal_eval(text.strip())
    if (
        isinstance(value, tuple)
        and len(value) == 2
        and all(isinstance(bound, int) for bound in value)
    ):
        return [value]
    return [tuple(ngram_range) for ngram_range in value]


def _build_parameters_grid(vect__max_df, vect__min_df, vect__ngram_range, vect__norm):
    """Turn the raw UI strings into a RandomizedSearchCV parameter distribution."""
    return {
        "vect__max_df": _parse_literal_list(vect__max_df),
        "vect__min_df": _parse_literal_list(vect__min_df),
        "vect__ngram_range": _parse_ngram_ranges(vect__ngram_range),
        "vect__norm": [norm.strip() for norm in vect__norm.split(",") if norm.strip()],
        # Fixed in this demo: 13 log-spaced values from 1e-6 to 1e6.
        "clf__alpha": np.logspace(-6, 6, 13),
    }


def _load_newsgroups(subset, categories):
    """Fetch one 20 newsgroups split with headers, footers and quotes removed."""
    return fetch_20newsgroups(
        subset=subset,
        categories=categories,
        shuffle=True,
        random_state=42,
        remove=("headers", "footers", "quotes"),
    )


def _plot_score_time_tradeoff(cv_results, param_names):
    """Scatter mean CV scoring time against mean CV accuracy, with error bars."""
    fig = px.scatter(
        cv_results,
        x="mean_score_time",
        y="mean_test_score",
        error_x="std_score_time",
        error_y="std_test_score",
        hover_data=param_names,
        labels=_PLOT_LABELS,
    )
    fig.update_layout(
        title={
            "text": "trade-off between scoring time and mean test score",
            "y": 0.95,
            "x": 0.5,
            "xanchor": "center",
            "yanchor": "top",
        }
    )
    return fig


def _plot_parallel_coordinates(cv_results, param_names):
    """Draw one parallel-coordinates line per sampled candidate, colored by score."""
    column_results = param_names + ["mean_test_score", "mean_score_time"]

    # Parallel coordinates need numeric axes. Alpha is put on a log10 scale,
    # "l2"/other norms map to 2/1, and an ngram range is reduced to its upper bound.
    transform_funcs = dict.fromkeys(column_results, lambda x: x)
    transform_funcs["alpha"] = math.log10
    transform_funcs["norm"] = lambda x: 2 if x == "l2" else 1
    transform_funcs["ngram_range"] = lambda x: x[1]

    fig = px.parallel_coordinates(
        # transform (not apply) is the documented way to map a dict of
        # per-column functions element-wise.
        cv_results[column_results].transform(transform_funcs),
        color="mean_test_score",
        color_continuous_scale=px.colors.sequential.Viridis_r,
        labels=_PLOT_LABELS,
    )
    fig.update_layout(
        title={
            "text": "Parallel coordinates plot of text classifier pipeline",
            "y": 0.99,
            "x": 0.5,
            "xanchor": "center",
            "yanchor": "top",
        }
    )
    return fig


def train_model(categories, vect__max_df, vect__min_df, vect__ngram_range, vect__norm):
    """Tune a TF-IDF + ComplementNB classifier on 20 newsgroups and report results.

    Args:
        categories: Newsgroup names to train and test on (from the dropdown).
        vect__max_df: Comma-separated ``max_df`` candidates, e.g. ``"0.2, 0.4"``.
        vect__min_df: Comma-separated ``min_df`` candidates (floats or ints).
        vect__ngram_range: Tuples such as ``"(1, 1), (1, 2)"``.
        vect__norm: Comma-separated norms, e.g. ``"l1, l2"``.

    Returns:
        ``(tradeoff_fig, parallel_coordinates_fig, best_parameters_json,
        test_accuracy)``.

    Raises:
        gr.Error: If the category selection is empty or a parameter string
            cannot be parsed.
    """
    # An empty selection would leave no documents to train on. None is still
    # passed through unchanged to fetch_20newsgroups.
    if categories is not None and not categories:
        raise gr.Error("Please select at least one category.")

    try:
        parameters_grid = _build_parameters_grid(
            vect__max_df, vect__min_df, vect__ngram_range, vect__norm
        )
    except (ValueError, SyntaxError, TypeError) as err:
        raise gr.Error(f"Could not parse the parameters grid: {err}") from err

    print(parameters_grid)

    data_train = _load_newsgroups("train", categories)
    data_test = _load_newsgroups("test", categories)

    pipeline = Pipeline(
        [
            ("vect", TfidfVectorizer()),
            ("clf", ComplementNB()),
        ]
    )
    random_search = RandomizedSearchCV(
        estimator=pipeline,
        param_distributions=parameters_grid,
        n_iter=40,
        random_state=0,
        n_jobs=2,
        verbose=1,
    )
    random_search.fit(data_train.data, data_train.target)

    best_parameters = json.dumps(
        random_search.best_estimator_.get_params(),
        indent=4,
        sort_keys=True,
        default=str,  # estimator objects and other non-JSON values go through str()
    )
    test_accuracy = random_search.score(data_test.data, data_test.target)

    # Drop "param_vect__"/"param_clf__" prefixes so columns read e.g. "max_df".
    cv_results = pd.DataFrame(random_search.cv_results_).rename(shorten_param, axis=1)
    param_names = [shorten_param(name) for name in parameters_grid]

    fig = _plot_score_time_tradeoff(cv_results, param_names)
    fig2 = _plot_parallel_coordinates(cv_results, param_names)
    return fig, fig2, best_parameters, test_accuracy
|
|
|
|
|
def load_description(name, directory="./descriptions"):
    """Return the text of the Markdown file ``<directory>/<name>.md``.

    Args:
        name: File stem relative to ``directory``. It may contain
            sub-directories, e.g. ``"parameter_grid/alpha"``.
        directory: Folder holding the description files. The default,
            ``./descriptions``, is resolved against the current working directory.

    Returns:
        The file contents as a string.

    Raises:
        FileNotFoundError: If the description file does not exist.
    """
    path = os.path.join(directory, f"{name}.md")
    # Decode as UTF-8 explicitly so non-ASCII Markdown reads the same regardless
    # of the platform's locale default encoding.
    with open(path, "r", encoding="utf-8") as f:
        return f.read()
|
|
|
|
|
# Markdown credit line rendered in the app's header row via gr.Markdown(AUTHOR).
AUTHOR = """
Created by [@dominguesm](https://huggingface.co/dominguesm) based on [scikit-learn docs](https://scikit-learn.org/stable/auto_examples/model_selection/plot_grid_search_text_feature_extraction.html)
"""
|
|
|
|
|
# Page layout: introduction, category picker, parameter-grid inputs, a summary of
# the model pipeline, and a Train button whose results fill the area below it.
with gr.Blocks(theme=gr.themes.Soft()) as app:
    # Header: title, long-form description files and credits.
    with gr.Row():
        with gr.Column():
            gr.Markdown("# Sample pipeline for text feature extraction and evaluation")
            gr.Markdown(load_description("description_part1"))
            gr.Markdown(load_description("description_part2"))
            gr.Markdown(AUTHOR)

    # Which newsgroups to train on (limited to two by max_choices).
    with gr.Row():
        with gr.Column():
            gr.Markdown("## CATEGORY SELECTION")
            gr.Markdown(load_description("description_category_selection"))
            drop_categories = gr.Dropdown(
                CATEGORIES,
                value=["alt.atheism", "talk.religion.misc"],
                multiselect=True,
                label="Categories",
                info="Please select up to two categories that you want to receive training on.",
                max_choices=2,
                interactive=True,
            )

    # Free-text candidate values for the randomized search, plus their docs.
    with gr.Row():
        with gr.Tab("PARAMETERS GRID"):
            gr.Markdown(load_description("description_parameter_grid"))
            with gr.Row():
                with gr.Column():
                    # Display only: this textbox is not an input of train_model.
                    # It lists the np.logspace(-6, 6, 13) alphas that train_model
                    # always searches (previously only the first three were shown).
                    clf__alpha = gr.Textbox(
                        label="Classifier Alpha (clf__alpha)",
                        value=(
                            "1.e-06, 1.e-05, 1.e-04, 1.e-03, 1.e-02, 1.e-01, 1.e+00, "
                            "1.e+01, 1.e+02, 1.e+03, 1.e+04, 1.e+05, 1.e+06"
                        ),
                        info="Due to practical considerations, this parameter was kept constant.",
                        interactive=False,
                    )
                    vect__max_df = gr.Textbox(
                        label="Vectorizer max_df (vect__max_df)",
                        value="0.2, 0.4, 0.6, 0.8, 1.0",
                        info="Values ranging from 0 to 1.0, separated by a comma.",
                        interactive=True,
                    )
                    vect__min_df = gr.Textbox(
                        label="Vectorizer min_df (vect__min_df)",
                        value="1, 3, 5, 10",
                        info="Values ranging from 0 to 1.0, separated by a comma, or integers separated by a comma. If float, the parameter represents a proportion of documents, integer absolute counts.",
                        interactive=True,
                    )
                with gr.Column():
                    vect__ngram_range = gr.Textbox(
                        label="Vectorizer ngram_range (vect__ngram_range)",
                        value="(1, 1), (1, 2)",
                        info="""Tuples of integer values separated by a comma. For example an `ngram_range` of `(1, 1)` means only unigrams, `(1, 2)` means unigrams and bigrams, and `(2, 2)` means only bigrams.""",
                        interactive=True,
                    )
                    vect__norm = gr.Textbox(
                        label="Vectorizer norm (vect__norm)",
                        value="l1, l2",
                        info="'l1' or 'l2', separated by a comma",
                        interactive=True,
                    )

        with gr.Tab("DESCRIPTION OF PARAMETERS"):
            gr.Markdown("### Classifier Alpha")
            gr.Markdown(load_description("parameter_grid/alpha"))
            gr.Markdown("### Vectorizer max_df")
            gr.Markdown(load_description("parameter_grid/max_df"))
            gr.Markdown("### Vectorizer min_df")
            gr.Markdown(load_description("parameter_grid/min_df"))
            gr.Markdown("### Vectorizer ngram_range")
            gr.Markdown(load_description("parameter_grid/ngram_range"))
            gr.Markdown("### Vectorizer norm")
            gr.Markdown(load_description("parameter_grid/norm"))

    # Static summary of the estimator that train_model tunes.
    with gr.Row():
        gr.Markdown(
            """
            ## MODEL PIPELINE
            ```python
            pipeline = Pipeline(
                [
                    ("vect", TfidfVectorizer()),
                    ("clf", ComplementNB()),
                ]
            )
            ```
            """
        )

    # Training trigger and result widgets.
    with gr.Row():
        with gr.Column():
            gr.Markdown("## TRAINING")
            with gr.Row():
                btn_train = gr.Button("Train").style(container=False)

            gr.Markdown("## RESULTS")
            with gr.Row():
                best_parameters = gr.Textbox(label="Best parameters")
                test_accuracy = gr.Textbox(label="Test accuracy")

            plot_trade = gr.Plot(label="")
            plot_coordinates = gr.Plot(label="")

    # Input order must match train_model's positional parameters. Output order
    # must match its return tuple (fig, fig2, best_parameters, test_accuracy).
    btn_train.click(
        train_model,
        inputs=[
            drop_categories,
            vect__max_df,
            vect__min_df,
            vect__ngram_range,
            vect__norm,
        ],
        outputs=[plot_trade, plot_coordinates, best_parameters, test_accuracy],
    )

app.launch()
|
|