"""Gradio demo: randomized hyperparameter search for a 20 newsgroups text classifier.

Based on the scikit-learn example "Sample pipeline for text feature extraction
and evaluation". The user picks up to two newsgroup categories, a TF-IDF +
ComplementNB pipeline is tuned with RandomizedSearchCV, and the search results
are shown as two plotly figures plus the best parameters and test accuracy.
"""

import math
import json

import gradio as gr
import numpy as np
import pandas as pd
import plotly.express as px
from sklearn.datasets import fetch_20newsgroups
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.model_selection import RandomizedSearchCV
from sklearn.naive_bayes import ComplementNB
from sklearn.pipeline import Pipeline

# All 20 newsgroups topics offered in the category dropdown.
CATEGORIES = [
    "alt.atheism",
    "comp.graphics",
    "comp.os.ms-windows.misc",
    "comp.sys.ibm.pc.hardware",
    "comp.sys.mac.hardware",
    "comp.windows.x",
    "misc.forsale",
    "rec.autos",
    "rec.motorcycles",
    "rec.sport.baseball",
    "rec.sport.hockey",
    "sci.crypt",
    "sci.electronics",
    "sci.med",
    "sci.space",
    "soc.religion.christian",
    "talk.politics.guns",
    "talk.politics.mideast",
    "talk.politics.misc",
    "talk.religion.misc",
]

# Search space sampled by RandomizedSearchCV. Keys follow the
# "<pipeline step>__<parameter>" convention ("vect" / "clf" steps below).
PARAMETER_GRID = {
    "vect__max_df": (0.2, 0.4, 0.6, 0.8, 1.0),
    "vect__min_df": (1, 3, 5, 10),
    "vect__ngram_range": ((1, 1), (1, 2)),  # unigrams or bigrams
    "vect__norm": ("l1", "l2"),
    "clf__alpha": np.logspace(-6, 6, 13),
}


def shorten_param(param_name):
    """Remove components' prefixes in param_name.

    Everything up to and including the last ``"__"`` is dropped, e.g.
    ``"param_vect__max_df"`` -> ``"max_df"``. Names without ``"__"`` are
    returned unchanged.
    """
    if "__" in param_name:
        return param_name.rsplit("__", 1)[1]
    return param_name


def _build_pipeline():
    """Return an unfitted TF-IDF vectorizer + ComplementNB pipeline."""
    return Pipeline(
        [
            ("vect", TfidfVectorizer()),
            ("clf", ComplementNB()),
        ]
    )


def _load_newsgroups(subset, categories):
    """Fetch one split ("train"/"test") of 20 newsgroups for ``categories``.

    Headers, footers and quoted replies are stripped so the classifier has to
    rely on the message bodies.
    """
    return fetch_20newsgroups(
        subset=subset,
        categories=categories,
        shuffle=True,
        random_state=42,
        remove=("headers", "footers", "quotes"),
    )


def _plot_score_time_tradeoff(cv_results, param_names, labels):
    """Scatter plot of mean CV scoring time vs. mean CV score, with error bars."""
    fig = px.scatter(
        cv_results,
        x="mean_score_time",
        y="mean_test_score",
        error_x="std_score_time",
        error_y="std_test_score",
        hover_data=param_names,
        labels=labels,
    )
    fig.update_layout(
        title={
            "text": "trade-off between scoring time and mean test score",
            "y": 0.95,
            "x": 0.5,
            "xanchor": "center",
            "yanchor": "top",
        }
    )
    return fig


def _plot_parallel_coordinates(cv_results, param_names, labels):
    """Parallel-coordinates plot of every searched parameter and its CV score.

    Non-numeric parameters are mapped to numbers so they can be drawn as axes.
    """
    column_results = param_names + ["mean_test_score", "mean_score_time"]

    transform_funcs = dict.fromkeys(column_results, lambda x: x)
    # Using a logarithmic scale for alpha
    transform_funcs["alpha"] = math.log10
    # L1 norms are mapped to index 1, and L2 norms to index 2
    transform_funcs["norm"] = lambda x: 2 if x == "l2" else 1
    # Unigrams are mapped to index 1 and bigrams to index 2
    transform_funcs["ngram_range"] = lambda x: x[1]

    # Apply each function element-wise, column by column. This is explicit
    # rather than relying on DataFrame.apply(dict)'s implicit element-wise
    # fallback, which newer pandas versions deprecate.
    transformed = pd.DataFrame(
        {
            column: cv_results[column].map(func)
            for column, func in transform_funcs.items()
        }
    )

    fig = px.parallel_coordinates(
        transformed,
        color="mean_test_score",
        color_continuous_scale=px.colors.sequential.Viridis_r,
        labels=labels,
    )
    fig.update_layout(
        title={
            "text": "Parallel coordinates plot of text classifier pipeline",
            "y": 0.99,
            "x": 0.5,
            "xanchor": "center",
            "yanchor": "top",
        }
    )
    return fig


def train_model(categories):
    """Tune and evaluate a TF-IDF + ComplementNB pipeline on 20 newsgroups.

    Runs a 40-iteration randomized search over PARAMETER_GRID on the training
    split restricted to ``categories``, then scores the search's best
    estimator on the matching test split.

    Args:
        categories: list of newsgroup names (entries of CATEGORIES) chosen in
            the UI dropdown.

    Returns:
        ``(fig, fig2, best_parameters, test_accuracy)``: the scoring-time vs.
        score scatter plot, the parallel-coordinates plot, the best searched
        hyperparameters as an indented JSON string, and the test-set score.

    Raises:
        gr.Error: if no category is selected.
    """
    # An empty selection cannot be trained on; surface a readable message in
    # the UI instead of letting the dataset loader fail.
    if not categories:
        raise gr.Error("Please select at least one category.")

    data_train = _load_newsgroups("train", categories)
    data_test = _load_newsgroups("test", categories)

    random_search = RandomizedSearchCV(
        estimator=_build_pipeline(),
        param_distributions=PARAMETER_GRID,
        n_iter=40,
        random_state=0,
        n_jobs=2,
        verbose=1,
    )
    random_search.fit(data_train.data, data_train.target)

    # Only the searched hyperparameters are reported; get_params() on the
    # whole pipeline would also dump the nested estimator objects.
    best_parameters = json.dumps(
        random_search.best_params_,
        indent=4,
        sort_keys=True,
        default=str,
    )
    test_accuracy = random_search.score(data_test.data, data_test.target)

    cv_results = pd.DataFrame(random_search.cv_results_)
    cv_results = cv_results.rename(shorten_param, axis=1)
    param_names = [shorten_param(name) for name in PARAMETER_GRID]
    labels = {
        "mean_score_time": "CV Score time (s)",
        "mean_test_score": "CV score (accuracy)",
    }

    fig = _plot_score_time_tradeoff(cv_results, param_names, labels)
    fig2 = _plot_parallel_coordinates(cv_results, param_names, labels)
    return fig, fig2, best_parameters, test_accuracy


DESCRIPTION_PART1 = [
    "The dataset used in this example is",
    "[The 20 newsgroups text dataset](https://scikit-learn.org/stable/datasets/real_world.html#newsgroups-dataset)",
    "which will be automatically downloaded, cached and reused for the document classification example.",
]

DESCRIPTION_PART2 = [
    "In this example, we tune the hyperparameters of",
    "a particular classifier using a",
    "[RandomizedSearchCV](https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.RandomizedSearchCV.html#sklearn.model_selection.RandomizedSearchCV).",
    "For a demo on the performance of some other classifiers, see the",
    "[Classification of text documents using sparse features](https://scikit-learn.org/stable/auto_examples/text/plot_document_classification_20newsgroups.html#sphx-glr-auto-examples-text-plot-document-classification-20newsgroups-py) notebook.",
]

CATEGORY_SELECTION_DESCRIPTION = [
    "The task of text classification is harder when there is significant overlap between",
    "the characteristic terms of different topics, because common terms make it difficult",
    "to distinguish one topic from another. On the other hand, when there is little overlap",
    "between the characteristic terms of different topics, the task becomes easier, as the",
    "unique terms of each topic provide a solid basis for accurately classifying each document",
    "into its respective category. Therefore, which categories you select, and how distinct",
    "their characteristic terms are, has a strong effect on the accuracy of the classifier.",
]

AUTHOR = """
Created by [@dominguesm](https://huggingface.co/dominguesm) based on [scikit-learn docs](https://scikit-learn.org/stable/auto_examples/model_selection/plot_grid_search_text_feature_extraction.html)
"""

# Static description of the search space and pipeline. Kept at column 0 so the
# Markdown is not accidentally rendered as an indented code block.
PARAMETERS_AND_PIPELINE_DESCRIPTION = """
## PARAMETERS GRID

```python
{
    'clf__alpha': array(
        [1.e-06, 1.e-05, 1.e-04,...]
    ),
    'vect__max_df': (0.2, 0.4, 0.6, 0.8, 1.0),
    'vect__min_df': (1, 3, 5, 10),
    'vect__ngram_range': ((1, 1), (1, 2)),
    'vect__norm': ('l1', 'l2')
}
```

## MODEL PIPELINE

```python
pipeline = Pipeline(
    [
        ("vect", TfidfVectorizer()),
        ("clf", ComplementNB()),
    ]
)
```
"""

with gr.Blocks(theme=gr.themes.Soft()) as app:
    with gr.Row():
        with gr.Column():
            gr.Markdown("# Sample pipeline for text feature extraction and evaluation")
            gr.Markdown(" ".join(DESCRIPTION_PART1))
            gr.Markdown(" ".join(DESCRIPTION_PART2))
            gr.Markdown(AUTHOR)

    with gr.Row():
        with gr.Column():
            gr.Markdown("""## CATEGORY SELECTION""")
            gr.Markdown(" ".join(CATEGORY_SELECTION_DESCRIPTION))
            drop_categories = gr.Dropdown(
                CATEGORIES,
                value=["alt.atheism", "talk.religion.misc"],
                multiselect=True,
                label="Categories",
                info="Select up to two categories to train the classifier on.",
                max_choices=2,
                interactive=True,
            )

    with gr.Row():
        with gr.Column():
            gr.Markdown(PARAMETERS_AND_PIPELINE_DESCRIPTION)

    with gr.Row():
        with gr.Column():
            gr.Markdown("""## TRAINING""")
            with gr.Row():
                # .style() is deprecated (removed in Gradio 4); a plain button
                # is sufficient here.
                btn_train = gr.Button("Train")

            gr.Markdown("## RESULTS")
            with gr.Row():
                best_parameters = gr.Textbox(label="Best parameters")
                test_accuracy = gr.Textbox(label="Test accuracy")
            plot_trade = gr.Plot(label="")
            plot_coordinates = gr.Plot(label="")

    btn_train.click(
        train_model,
        [drop_categories],
        [plot_trade, plot_coordinates, best_parameters, test_accuracy],
    )

app.launch()