|
from functools import partial |
|
|
|
import numpy as np |
|
import matplotlib.pyplot as plt |
|
|
|
from sklearn import svm, datasets |
|
from sklearn.model_selection import train_test_split |
|
from sklearn.metrics import ConfusionMatrixDisplay |
|
import gradio as gr |
|
|
|
|
|
def train_model(normalize):
    """Fit a linear SVM on the iris data set and plot its test-set confusion matrix.

    The train/test split uses a fixed ``random_state``, so every call trains the
    same classifier on the same data; only the rendering (raw counts vs.
    row-normalized values) depends on *normalize*.

    Args:
        normalize: Truthy to normalize each row of the matrix over the true
            labels (``normalize="true"``); falsy to show raw counts.

    Returns:
        matplotlib.figure.Figure: the figure containing the confusion-matrix plot.
    """
    # Build the figure directly rather than through pyplot. pyplot keeps a
    # global reference to every figure it creates, so one figure per request
    # would never be released in a long-running server, and pyplot's global
    # state is shared across the server's worker threads.
    from matplotlib.figure import Figure

    iris = datasets.load_iris()
    X, y = iris.data, iris.target
    class_names = iris.target_names

    # Fixed random_state keeps the split, and therefore the plot, reproducible.
    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

    # Linear SVM with a small C, i.e. strong regularization.
    classifier = svm.SVC(kernel="linear", C=0.01).fit(X_train, y_train)

    title = (
        "Normalized confusion matrix"
        if normalize
        else "Confusion matrix, without normalization"
    )

    fig = Figure()
    ax = fig.subplots()
    disp = ConfusionMatrixDisplay.from_estimator(
        classifier,
        X_test,
        y_test,
        display_labels=class_names,
        cmap=plt.cm.Blues,
        normalize="true" if normalize else None,
        ax=ax,  # draw on our standalone figure instead of a new pyplot one
    )
    disp.ax_.set_title(title)

    return disp.figure_
|
|
|
|
|
title = "Confusion matrix"

description = "Example of confusion matrix usage to evaluate the quality of the output of a classifier on the iris data set"

with gr.Blocks() as demo:
    gr.Markdown(f"## {title}")
    gr.Markdown(description)

    normalize = gr.Checkbox(label="Normalize")
    plot = gr.Plot(label="Confusion matrix")

    # Redraw whenever the checkbox is toggled, and draw once when the page
    # loads so the plot area is not empty before the first interaction.
    normalize.change(fn=train_model, inputs=[normalize], outputs=plot)
    demo.load(fn=train_model, inputs=[normalize], outputs=plot)

# `launch(enable_queue=...)` is deprecated in Gradio 3.x and removed in 4.x;
# enabling the request queue via `Blocks.queue()` works on both.
demo.queue().launch(debug=True)
|
|