# app.py — Hugging Face Space by Zhenev
# Last commit: de9601f "Fixing the title in app.py" (file size 5.13 kB; scan: no virus)
from functools import partial

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.figure import Figure
from sklearn import datasets
from sklearn import svm, linear_model
from sklearn.metrics import auc
from sklearn.metrics import RocCurveDisplay
from sklearn.model_selection import StratifiedKFold
# Adapted from scikit-learn's "ROC with cross validation" example:
# https://scikit-learn.org/stable/auto_examples/model_selection/plot_roc_crossval.html
def _load_noisy_iris(random_state):
    """Return ``(X, y, target_names)`` for a binary, noise-padded Iris problem.

    Samples of class 2 are dropped (and its name removed from ``target_names``)
    so the task is binary. Then ``200 * n_features`` columns drawn with
    ``random_state.randn`` are appended to make the problem harder. Drawing
    advances *random_state*.
    """
    iris = datasets.load_iris()
    keep = iris.target != 2
    X, y = iris.data[keep], iris.target[keep]
    target_names = iris.target_names[:-1]
    n_samples, n_features = X.shape
    noise = random_state.randn(n_samples, 200 * n_features)
    return np.concatenate([X, noise], axis=1), y, target_names


def auc_analysis(selected_data, n_folds, cls_name):
    """Plot per-fold ROC curves and their mean for one dataset/classifier pair.

    Parameters
    ----------
    selected_data : str
        Dataset name; only ``"Iris"`` is available.
    n_folds : int
        Number of ``StratifiedKFold`` splits (coerced with ``int``).
    cls_name : str
        ``"SVC - linear kernel"`` or ``"Logistic Regression"``.

    Returns
    -------
    matplotlib.figure.Figure
        A 6x6-inch figure containing:
        - each fold's ROC curve;
        - the chance diagonal;
        - the mean of the fold TPRs interpolated on a common FPR grid;
        - a +/- 1 standard-deviation band around that mean.

    Raises
    ------
    KeyError
        If *selected_data* or *cls_name* is not one of the names above.
    """
    # One RandomState feeds both the noise columns and the SVC, so every call
    # with the same inputs produces the same figure.
    random_state = np.random.RandomState(0)
    dataset_list = {"Iris": _load_noisy_iris(random_state)}
    X, y, target_names = dataset_list[selected_data]

    classification_models = {
        "SVC - linear kernel": svm.SVC(kernel="linear", probability=True, random_state=random_state),
        "Logistic Regression": linear_model.LogisticRegression(),
    }
    classifier = classification_models[cls_name]

    cv = StratifiedKFold(n_splits=int(n_folds))

    tprs = []
    aucs = []
    mean_fpr = np.linspace(0, 1, 100)
    # Build the figure without pyplot. Figures made with plt.subplots() stay
    # registered with pyplot until explicitly closed. This callback runs on
    # every input change, so those figures would accumulate for the life of
    # the server. pyplot's global state is also not meant to be shared by the
    # threads that run the UI callbacks.
    fig = Figure(figsize=(6, 6))
    ax = fig.subplots()
    for fold, (train, test) in enumerate(cv.split(X, y)):
        classifier.fit(X[train], y[train])
        viz = RocCurveDisplay.from_estimator(
            classifier,
            X[test],
            y[test],
            name=f"ROC fold {fold}",
            alpha=0.5,
            lw=1,
            ax=ax,
        )
        # Resample this fold's curve onto the shared FPR grid so folds can be
        # averaged point-wise; pin the start at (0, 0).
        interp_tpr = np.interp(mean_fpr, viz.fpr, viz.tpr)
        interp_tpr[0] = 0.0
        tprs.append(interp_tpr)
        aucs.append(viz.roc_auc)
    ax.plot([0, 1], [0, 1], "k--", label="chance level (AUC = 0.5)")

    mean_tpr = np.mean(tprs, axis=0)
    mean_tpr[-1] = 1.0  # the mean curve must end at (1, 1)
    mean_auc = auc(mean_fpr, mean_tpr)
    std_auc = np.std(aucs)
    ax.plot(
        mean_fpr,
        mean_tpr,
        color="b",
        label=r"Mean ROC (AUC = %0.2f $\pm$ %0.2f)" % (mean_auc, std_auc),
        lw=2,
        alpha=0.8,
    )

    # +/- 1 std band, clipped to the valid TPR range [0, 1].
    std_tpr = np.std(tprs, axis=0)
    tprs_upper = np.minimum(mean_tpr + std_tpr, 1)
    tprs_lower = np.maximum(mean_tpr - std_tpr, 0)
    ax.fill_between(
        mean_fpr,
        tprs_lower,
        tprs_upper,
        color="grey",
        alpha=0.2,
        label=r"$\pm$ 1 std. dev.",
    )

    ax.set(
        xlim=[-0.05, 1.05],
        ylim=[-0.05, 1.05],
        xlabel="False Positive Rate",
        ylabel="True Positive Rate",
        title=f"Mean ROC curve with variability\n(Positive label '{target_names[1]}')",
    )
    ax.axis("square")
    ax.legend(loc="lower right")
    return fig
# Build the Demo
def iter_grid(n_rows, n_cols):
    """Step through the cells of an ``n_rows`` x ``n_cols`` Gradio layout.

    The generator pauses once per cell while that cell's ``gr.Column`` is open
    inside its row's ``gr.Row``. Components the caller creates during that
    pause are placed in that cell. Running the generator to exhaustion closes
    every context it opened.
    """
    def _cells_of_one_row():
        # Open one row and pause once inside each of its columns.
        with gr.Row():
            for _col in range(n_cols):
                with gr.Column():
                    yield

    for _row in range(n_rows):
        yield from _cells_of_one_row()
# Classifier names offered in the UI; each is passed to auc_analysis as cls_name.
input_models = ["SVC - linear kernel", "Logistic Regression"]
title = "🔬 Receiver Operating Characteristic (ROC) with Cross Validation"
with gr.Blocks(title=title) as demo:
    gr.Markdown(f"## {title}")
    gr.Markdown(
        "This app demonstrates Receiver Operating Characteristic (ROC) metric estimate variability using "
        "cross-validation. It shows the response of ROC and of its variance to different datasets, created from "
        "K-fold cross-validation. "
        "See the [source](https://scikit-learn.org/stable/auto_examples/model_selection/plot_roc_crossval.html)"
        " for more details.")
    gr.Markdown(f'Available classification models: {", ".join(input_models)}.')
    with gr.Row():
        with gr.Column():
            input_data = gr.Radio(
                choices=["Iris"],
                value="Iris",
                label="Dataset",
                info="Available datasets"
            )
        with gr.Column():
            n_folds = gr.Radio(
                [3, 4, 5, 6, 7, 8, 9], value=4, label="Folds", info="Number of cross-validation splits"
            )
    # Two plots per row, ceil(len(input_models) / 2) rows.
    n_rows = (len(input_models) + 1) // 2
    for counter, _ in enumerate(iter_grid(n_rows, 2)):
        if counter >= len(input_models):
            # Surplus cell (odd model count). Skip it, but keep iterating
            # rather than `break`, so the generator runs to completion and
            # closes its gr.Row / gr.Column contexts in order.
            continue
        input_model = input_models[counter]
        plot = gr.Plot(label=input_model)
        fn = partial(auc_analysis, cls_name=input_model)
        inputs = [input_data, n_folds]
        input_data.change(fn=fn, inputs=inputs, outputs=plot)
        n_folds.change(fn=fn, inputs=inputs, outputs=plot)
        # .change only fires on user edits; also render once on page load so
        # the plots are not blank initially.
        demo.load(fn=fn, inputs=inputs, outputs=plot)

if __name__ == "__main__":
    demo.launch()