import time

import gradio as gr
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

from sklearn import linear_model
from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split
from sklearn.utils._testing import ignore_warnings
from sklearn.exceptions import ConvergenceWarning
from sklearn.utils import shuffle


def load_mnist(classes, n_samples):
    """Load MNIST, keep only the requested classes, shuffle, and return n_samples."""
    mnist = fetch_openml("mnist_784", version=1, as_frame=False, parser="pandas")

    # Keep only the digits selected in the UI.
    mask = np.isin(mnist.target, classes)

    X, y = shuffle(mnist.data[mask], mnist.target[mask], random_state=42)
    X, y = X[:n_samples], y[:n_samples]
    return X, y

@ignore_warnings(category=ConvergenceWarning)
def fit_and_score(estimator, max_iter, X_train, X_test, y_train, y_test):
    """Fit the estimator on the train set and score it on both sets."""
    estimator.set_params(max_iter=max_iter, random_state=0)

    start = time.time()
    estimator.fit(X_train, y_train)
    fit_time = time.time() - start

    n_iter = estimator.n_iter_
    train_score = estimator.score(X_train, y_train)
    test_score = estimator.score(X_test, y_test)

    return fit_time, n_iter, train_score, test_score

def plot(classes, max_iterations, num_samples, n_iter_no_change, validation_fraction, tol):
    if len(classes) < 2:
        raise gr.Error("Please select at least two classes (digits) to train on.")

    # Cast the UI inputs to the numeric types the estimators expect.
    max_iterations = int(max_iterations)
    num_samples = int(num_samples)
    n_iter_no_change = int(n_iter_no_change)
    validation_fraction = float(validation_fraction)
    tol = float(tol)

    # Three SGD classifiers that differ only in how training is stopped: the default
    # behaviour, stopping on the training loss (tol=0.1), and early stopping on a
    # held-out validation score.
    estimator_dict = {
        "No stopping criterion": linear_model.SGDClassifier(n_iter_no_change=n_iter_no_change),
        "Training loss": linear_model.SGDClassifier(
            early_stopping=False, n_iter_no_change=n_iter_no_change, tol=0.1
        ),
        "Validation score": linear_model.SGDClassifier(
            early_stopping=True,
            n_iter_no_change=n_iter_no_change,
            tol=tol,
            validation_fraction=validation_fraction,
        ),
    }

    # Load the selected digits and split them into train/test halves.
    X, y = load_mnist(classes, n_samples=num_samples)
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.5, random_state=0)

    # Refit each estimator for every max_iter value and record fit time and scores.
    results = []
    for estimator_name, estimator in estimator_dict.items():
        for max_iter in range(1, max_iterations):
            fit_time, n_iter, train_score, test_score = fit_and_score(
                estimator, max_iter, X_train, X_test, y_train, y_test
            )
            results.append(
                (estimator_name, max_iter, fit_time, n_iter, train_score, test_score)
            )

    columns = [
        "Stopping criterion",
        "max_iter",
        "Fit time (sec)",
        "n_iter_",
        "Train score",
        "Test score",
    ]
    results_df = pd.DataFrame(results, columns=columns)

    lines = "Stopping criterion"
    x_axis = "max_iter"
    styles = ["-.", "--", "-"]

    # First figure: train and test scores as a function of max_iter.
    fig1, axes1 = plt.subplots(nrows=1, ncols=2, sharey=True, figsize=(12, 4))
    for ax, y_axis in zip(axes1, ["Train score", "Test score"]):
        for style, (criterion, group_df) in zip(styles, results_df.groupby(lines)):
            group_df.plot(x=x_axis, y=y_axis, label=criterion, ax=ax, style=style)
        ax.set_title(y_axis)
        ax.legend(title=lines)
    fig1.tight_layout()

    # Second figure: number of SGD iterations actually run and fit time.
    fig2, axes2 = plt.subplots(nrows=1, ncols=2, figsize=(12, 4))
    for ax, y_axis in zip(axes2, ["n_iter_", "Fit time (sec)"]):
        for style, (criterion, group_df) in zip(styles, results_df.groupby(lines)):
            group_df.plot(x=x_axis, y=y_axis, label=criterion, ax=ax, style=style)
        ax.set_title(y_axis)
        ax.legend(title=lines)
    fig2.tight_layout()

    return fig1, fig2

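# Header text shown at the top of the demo. The wording below is a placeholder
# summary of what the app does; the UI references an ``info`` string that is not
# defined elsewhere in this script.
info = (
    "## Early stopping of Stochastic Gradient Descent\n\n"
    "Train `SGDClassifier` on a subset of MNIST with three different stopping "
    "strategies (no stopping criterion, stopping on the training loss, and early "
    "stopping on a held-out validation score) and compare their train/test scores, "
    "number of iterations and fit times."
)
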
with gr.Blocks() as demo:
    gr.Markdown(info)
    with gr.Row():
        with gr.Column():
            classes = gr.CheckboxGroup(["0", "1", "2", "3", "4", "5", "6", "7", "8", "9"], value=["0", "8"], label="Classes", info="Digits to include in training; for fast and stable training, select only two classes")
            max_iterations = gr.Slider(label="Max Number of Iterations", value=50, minimum=5, maximum=50, step=1, info="Maximum number of SGD iterations (epochs) to sweep over")
            num_samples = gr.Slider(label="Number of Samples", value=10000, minimum=1000, maximum=70000, step=100, info="Number of MNIST samples to include in the training")
            n_iter_no_change = gr.Slider(label="Number of Iterations with No Change", value=3, minimum=1, maximum=10, step=1, info="Number of iterations with no score improvement of at least tol to wait before stopping")
            validation_fraction = gr.Slider(label="Validation Fraction", value=0.2, minimum=0.05, maximum=0.9, step=0.01, info="Fraction of the training data to be used for validation")
            tol = gr.Textbox(label="Tolerance (tol)", value="0.0001", info="Minimum score improvement that counts as an improvement for the stopping criteria")
    out1 = gr.Plot()
    out2 = gr.Plot()

    btn = gr.Button("Run")
    btn.click(fn=plot, inputs=[classes, max_iterations, num_samples, n_iter_no_change, validation_fraction, tol], outputs=[out1, out2])

demo.launch()