# explain-wine / app.py
# Author: katossky
# Commit 1da6f58: "remove typing error"
import gradio as gr
import datasets as ds
import pandas as pd
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from lime.lime_tabular import LimeTabularExplainer
# Pull the training split of the wine-recognition dataset from the
# Hugging Face Hub and work with it as a pandas DataFrame.
wines = ds.load_dataset("katossky/wine-recognition", split="train").to_pandas()
# Strip leading/trailing whitespace from every column name.
wines.columns = wines.columns.str.strip()

# Random forest over every column except the target; the fixed random_state
# makes the fitted forest reproducible between runs.
predictor = RandomForestClassifier(
    n_estimators=1000,
    max_depth=5,
    n_jobs=4,
    random_state=44,
)
predictor.fit(wines.drop(columns="label"), wines["label"])
def plot_explanation(instance_part_1, instance_part_2, instance_part_3, sigma):
    """Predict the class of one wine and explain the prediction with LIME.

    Args:
        instance_part_1, instance_part_2, instance_part_3: single-row tables
            from the UI. Placed side by side, they must contain every feature
            column of ``wines`` (matched by column name).
        sigma: kernel width handed to ``LimeTabularExplainer``.

    Returns:
        A ``(confidences, figure)`` pair:
        - a dict mapping each class label (as a string) to its predicted
          probability;
        - LIME's bar chart for the most probable class.
    """
    features = wines.drop('label', axis=1)
    feature_names = features.columns.to_list()

    # Re-assemble the full row in training column order, because both the
    # forest and LIME depend on that order. Force a float dtype in case the
    # tables come back with non-numeric dtypes.
    instance_pd = pd.concat(
        [instance_part_1, instance_part_2, instance_part_3], axis=1
    )[feature_names].astype(float)
    instance_np = instance_pd.to_numpy()[0]

    def predict_proba(rows):
        # LIME passes bare arrays. Wrap them so the forest, which was fitted
        # on a DataFrame, sees the feature names it was trained with.
        return predictor.predict_proba(pd.DataFrame(rows, columns=feature_names))

    explainer = LimeTabularExplainer(
        training_data=features,
        feature_names=feature_names,
        discretize_continuous=False,
        kernel_width=sigma,
    )
    explanation = explainer.explain_instance(
        instance_np,
        predict_proba,
        top_labels=len(predictor.classes_),  # explain every class
        num_features=5,
    )

    predictions = predictor.predict_proba(instance_pd)[0]
    # LIME indexes explanations by predict_proba column, so pass the column
    # index of the most probable class rather than the class value itself.
    label = int(np.argmax(predictions))
    # Key confidences by the classifier's actual class values instead of
    # assuming exactly three classes numbered 0..2.
    confidences = {
        str(cls): float(p) for cls, p in zip(predictor.classes_, predictions)
    }
    return confidences, explanation.as_pyplot_figure(label=label)
# Feature columns in dataset order: every column except the target.
_feature_columns = wines.drop(columns='label').columns.to_list()

# Default LIME kernel width: 0.75 * sqrt(number of features). This matches
# LIME's documented default when kernel_width is None.
sigma_default = 0.75 * len(_feature_columns) ** 0.5
sigma = gr.Slider(0.001, 2 * sigma_default, value=sigma_default, label='σ')

# A random wine from the dataset pre-fills the input tables.
instance_complete = wines.sample(1)


def _feature_table(columns, **display_options):
    """Return a one-row, fixed-size numeric ``gr.Dataframe`` for ``columns``.

    The table is pre-filled with the matching values of ``instance_complete``.
    ``display_options`` (e.g. ``label``, ``show_label``) are forwarded
    unchanged to ``gr.Dataframe``.
    """
    return gr.Dataframe(
        headers=columns,
        row_count=(1, "fixed"),
        col_count=(len(columns), "fixed"),
        datatype="number",
        value=instance_complete[columns].values.tolist(),
        **display_options,
    )


# The features are spread over three tables (5, 4, then the rest) to keep the
# layout readable. Side by side, they form the full feature row.
instance_part_1 = _feature_table(
    _feature_columns[:5], label="Chemical properties of the wine"
)
# NOTE(author): show_label=False was reported not to hide the label here.
instance_part_2 = _feature_table(_feature_columns[5:9], label="", show_label=False)
instance_part_3 = _feature_table(_feature_columns[9:], label="", show_label=False)

demo = gr.Interface(
    fn=plot_explanation,
    inputs=[instance_part_1, instance_part_2, instance_part_3, sigma],
    outputs=["label", "plot"],
)
demo.launch()