katossky commited on
Commit
65e0688
1 Parent(s): d675b9b

create app.py

Browse files
Files changed (1) hide show
  1. app.py +80 -0
app.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import datasets as ds
3
+ import pandas as pd
4
+ import numpy as np
5
+ from sklearn.ensemble import RandomForestClassifier
6
+ from lime.lime_tabular import LimeTabularExplainer
7
+
8
+ wines = ds.load_dataset("katossky/wine-recognition", split='train')
9
+ wines = wines.to_pandas()
10
+ wines.columns = wines.columns.str.strip()
11
+
12
+ predictor = RandomForestClassifier(
13
+ n_estimators=1000, max_depth=5, n_jobs=4,
14
+ random_state=44 # for reproducibility
15
+ )
16
+
17
+ predictor.fit( wines.drop('label', axis=1), wines['label'] )
18
+
19
+ def plot_explanation(instance_part_1, instance_part_2, instance_part_3, sigma):
20
+ instance_pd = pd.concat([instance_part_1, instance_part_2, instance_part_3], axis=1)
21
+ instance_np = instance_pd.to_numpy().squeeze()
22
+ explainer = lime.lime_tabular.LimeTabularExplainer(
23
+ training_data = wines.drop('label', axis=1), #.to_numpy(),
24
+ feature_names = wines.columns[1:].to_list(),
25
+ discretize_continuous = False, kernel_width=sigma
26
+ )
27
+ explanation = explainer.explain_instance(
28
+ instance_np,
29
+ predictor.predict_proba, #,
30
+ top_labels=3,
31
+ num_features=5
32
+ )
33
+ predictions = predictor.predict_proba(instance_pd)[0]
34
+ label = np.argmax(predictions)
35
+ confidences = {i: predictions[i] for i in range(3)}
36
+ return (
37
+ confidences,
38
+ explanation.as_pyplot_figure(label=label)
39
+ )
40
+
41
+ sigma_default = 0.75*(wines.shape[1]-1)**0.5
42
+ sigma = gr.Slider(0.001, 2*sigma_default, value=sigma_default, label='σ')
43
+
44
+ instance_complete = wines.sample(1)
45
+
46
+ instance_part_1 = gr.Dataframe(
47
+ label = "Chemical properties of the wine",
48
+ headers = wines.columns[1:6].to_list(),
49
+ row_count = (1,"fixed"),
50
+ col_count = (5, "fixed"),
51
+ datatype = "number",
52
+ value = instance_complete.iloc[:,1:6].values.tolist()
53
+ )
54
+
55
+ instance_part_2 = gr.Dataframe(
56
+ label = "",
57
+ show_label = False, # does not work
58
+ headers = wines.columns[6:10].to_list(),
59
+ row_count = (1,"fixed"),
60
+ col_count = (4, "fixed"),
61
+ datatype = "number",
62
+ value = instance_complete.iloc[:,6:10].values.tolist()
63
+ )
64
+
65
+ instance_part_3 = gr.Dataframe(
66
+ label = "",
67
+ show_label = False, # does not work
68
+ headers = wines.columns[10:].to_list(),
69
+ row_count = (1,"fixed"),
70
+ col_count = (4, "fixed"),
71
+ datatype = "number",
72
+ value = instance_complete.iloc[:,10:].values.tolist()
73
+ )
74
+
75
+ demo = gr.Interface(
76
+ fn = plot_explanation,
77
+ inputs = [instance_part_1, instance_part_2, instance_part_3, sigma],
78
+ outputs = ["label", "plot"]
79
+ )
80
+ demo.launch()