explain-wine / app.py
katossky's picture
remove typing error
1da6f58
raw history blame
No virus
2.39 kB
import gradio as gr
import datasets as ds
import pandas as pd
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from lime.lime_tabular import LimeTabularExplainer
# Pull the training split of the wine-recognition dataset straight into
# pandas and normalise its column names (strip surrounding whitespace).
wines = ds.load_dataset("katossky/wine-recognition", split='train').to_pandas()
wines.columns = wines.columns.str.strip()

# Random forest trained on every column except the 'label' target.
# The fixed seed makes the fitted model reproducible.
predictor = RandomForestClassifier(
    n_estimators=1000,
    max_depth=5,
    n_jobs=4,
    random_state=44,
)
predictor.fit(X=wines.drop(columns='label'), y=wines['label'])
def plot_explanation(instance_part_1, instance_part_2, instance_part_3, sigma):
    """Classify one wine and explain the prediction with LIME.

    The UI splits the feature vector over three one-row tables; placed side
    by side (in order) they form one complete instance in training-column
    order.

    Args:
        instance_part_1, instance_part_2, instance_part_3: one-row tables
            whose concatenated columns are the model's feature columns.
        sigma: kernel width for LIME's exponential proximity kernel.

    Returns:
        ``(confidences, figure)``: a dict mapping each class label (as str)
        to its predicted probability, and the LIME explanation figure for
        the most probable class.
    """
    # Derive features by name, exactly as the model was trained, instead of
    # assuming 'label' is the first column.
    feature_frame = wines.drop('label', axis=1)
    feature_names = feature_frame.columns.to_list()
    class_names = [str(cls) for cls in predictor.classes_]

    # Rebuild the instance from its numeric values with the training column
    # names, independent of how the UI labels or types its columns.
    instance_np = pd.concat(
        [instance_part_1, instance_part_2, instance_part_3], axis=1
    ).to_numpy(dtype=float)[0]

    def predict_fn(samples):
        # LIME passes bare numpy arrays; restore the feature names the
        # forest was fitted with.
        frame = pd.DataFrame(np.atleast_2d(samples), columns=feature_names)
        return predictor.predict_proba(frame)

    # Built per call because kernel_width depends on the user-chosen sigma.
    explainer = LimeTabularExplainer(
        training_data=feature_frame.to_numpy(),
        feature_names=feature_names,
        class_names=class_names,  # keeps the plot title consistent with the labels below
        discretize_continuous=False,
        kernel_width=sigma,
    )
    explanation = explainer.explain_instance(
        instance_np,
        predict_fn,
        top_labels=len(class_names),  # explain every class, whatever their number
        num_features=5,
    )

    probabilities = predict_fn(instance_np)[0]
    # Index into predict_proba's columns, which is also LIME's label index.
    best = int(np.argmax(probabilities))
    confidences = {
        name: float(prob) for name, prob in zip(class_names, probabilities)
    }
    return confidences, explanation.as_pyplot_figure(label=best)
# Model feature columns in training order (every column except the target).
feature_columns = wines.drop('label', axis=1).columns.to_list()

# Default kernel width 0.75 * sqrt(n_features) — the value LIME's
# LimeTabularExplainer uses when kernel_width is None; the slider spans 0..2x it.
sigma_default = 0.75 * len(feature_columns) ** 0.5
sigma = gr.Slider(0.001, 2 * sigma_default, value=sigma_default, label='σ')

# A random wine from the dataset pre-fills the input tables.
instance_complete = wines.sample(1)


def _feature_table(columns, **display_kwargs):
    """Return a fixed one-row numeric Dataframe input for ``columns``.

    The column count is derived from ``columns`` so it can never drift from
    the headers; the single row is pre-filled from ``instance_complete``.
    Extra keyword arguments (e.g. ``label``, ``show_label``) go straight to
    ``gr.Dataframe``.
    """
    return gr.Dataframe(
        headers=columns,
        row_count=(1, "fixed"),
        col_count=(len(columns), "fixed"),
        datatype="number",
        value=instance_complete[columns].values.tolist(),
        **display_kwargs,
    )


# The features are spread over three tables (presumably to keep each one
# narrow); plot_explanation receives them in this order.
instance_part_1 = _feature_table(
    feature_columns[0:5], label="Chemical properties of the wine"
)
# NOTE: the original author observed that show_label=False has no visible effect here.
instance_part_2 = _feature_table(feature_columns[5:9], label="", show_label=False)
instance_part_3 = _feature_table(feature_columns[9:], label="", show_label=False)

demo = gr.Interface(
    fn=plot_explanation,
    inputs=[instance_part_1, instance_part_2, instance_part_3, sigma],
    outputs=["label", "plot"],
)
demo.launch()