Hnabil commited on
Commit
2d09772
·
1 Parent(s): d3e8b88

Add application file

Browse files
Files changed (1) hide show
  1. app.py +59 -0
app.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import partial
2
+
3
+ import numpy as np
4
+ import matplotlib.pyplot as plt
5
+
6
+ from sklearn import svm, datasets
7
+ from sklearn.model_selection import train_test_split
8
+ from sklearn.metrics import ConfusionMatrixDisplay
9
+ import gradio as gr
10
+
11
+
12
+ def train_model(normalize):
13
+ # import some data to play with
14
+ iris = datasets.load_iris()
15
+ X = iris.data
16
+ y = iris.target
17
+ class_names = iris.target_names
18
+
19
+ # Split the data into a training set and a test set
20
+ X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
21
+
22
+ # Run classifier, using a model that is too regularized (C too low) to see
23
+ # the impact on the results
24
+ classifier = svm.SVC(kernel="linear", C=0.01).fit(X_train, y_train)
25
+
26
+ np.set_printoptions(precision=2)
27
+
28
+ title = (
29
+ "Normalized confusion matrix" if normalize
30
+ else "Confusion matrix, without normalization"
31
+ )
32
+
33
+ disp = ConfusionMatrixDisplay.from_estimator(
34
+ classifier,
35
+ X_test,
36
+ y_test,
37
+ display_labels=class_names,
38
+ cmap=plt.cm.Blues,
39
+ normalize='true' if normalize else None,
40
+ )
41
+ disp.ax_.set_title(title)
42
+
43
+ return disp.figure_
44
+
45
+
46
+ title = "Confusion matrix"
47
+ description = "Example of confusion matrix usage to evaluate the quality of the output of a classifier on the iris data set"
48
+ with gr.Blocks() as demo:
49
+ gr.Markdown(f"## {title}")
50
+ gr.Markdown(description)
51
+
52
+ normalize = gr.Checkbox(label="Normalize")
53
+ plot = gr.Plot(label="Confusion matrix")
54
+
55
+ fn = partial(train_model)
56
+ normalize.change(fn=fn, inputs=[normalize], outputs=plot)
57
+
58
+
59
+ demo.launch(enable_queue=True, debug=True)