"""Gradio demo: LIME explanations for a random-forest wine classifier.

At import time the script loads the ``katossky/wine-recognition`` dataset,
fits a ``RandomForestClassifier`` on every column except ``label``, builds a
Gradio interface around :func:`plot_explanation`, and launches it.
"""

import datasets as ds
import gradio as gr
import numpy as np
import pandas as pd
from lime.lime_tabular import LimeTabularExplainer
from sklearn.ensemble import RandomForestClassifier

wines = ds.load_dataset("katossky/wine-recognition", split='train')
wines = wines.to_pandas()
wines.columns = wines.columns.str.strip()

# Feature matrix and names, computed once instead of on every request.
# Deriving the names from the matrix itself keeps them aligned with the
# training data wherever the 'label' column sits in the dataset.
features = wines.drop('label', axis=1)
feature_names = features.columns.to_list()

predictor = RandomForestClassifier(
    n_estimators=1000,
    max_depth=5,
    n_jobs=4,
    random_state=44,  # for reproducibility
)
predictor.fit(features, wines['label'])


def plot_explanation(instance_part_1, instance_part_2, instance_part_3, sigma):
    """Classify one wine and explain the predicted class with LIME.

    Args:
        instance_part_1: One-row DataFrame holding the first group of
            feature values (as produced by the first ``gr.Dataframe``).
        instance_part_2: One-row DataFrame holding the second group.
        instance_part_3: One-row DataFrame holding the third group.
        sigma: Kernel width handed to ``LimeTabularExplainer``.

    Returns:
        A ``(confidences, figure)`` tuple: ``confidences`` maps each class
        (as a string) to its predicted probability, and ``figure`` is the
        matplotlib figure returned by ``Explanation.as_pyplot_figure`` for
        the most probable class.
    """
    # Re-assemble the full feature row from the three UI tables.
    instance_pd = pd.concat(
        [instance_part_1, instance_part_2, instance_part_3], axis=1
    )
    instance_np = instance_pd.to_numpy().squeeze()

    # Predict first: the explanation must target the predicted class.
    predictions = predictor.predict_proba(instance_pd)[0]
    label = int(np.argmax(predictions))  # column index into predict_proba
    confidences = {
        str(cls): float(proba)
        for cls, proba in zip(predictor.classes_, predictions)
    }

    # The explainer is rebuilt per call because the kernel width is a
    # user-controlled input.
    explainer = LimeTabularExplainer(
        training_data=features.to_numpy(),
        feature_names=feature_names,
        discretize_continuous=False,
        kernel_width=sigma,
    )
    # Without `labels`, LIME only explains label 1, and asking the figure
    # for any other label raises KeyError. Explain the predicted one.
    explanation = explainer.explain_instance(
        instance_np,
        predictor.predict_proba,
        labels=(label,),
        num_features=5,
    )

    return confidences, explanation.as_pyplot_figure(label=label)


# 0.75 * sqrt(number of features): the formula LIME documents as its default
# kernel width. The slider ranges from almost zero to twice that value.
sigma_default = 0.75 * (wines.shape[1] - 1) ** 0.5
sigma = gr.Slider(0.001, 2 * sigma_default, value=sigma_default, label='σ')

# A random wine pre-fills the input tables. The features are split over
# three one-row tables of 5, 4 and 4 columns.
instance_complete = wines.sample(1)

instance_part_1 = gr.Dataframe(
    label="Chemical properties of the wine",
    headers=feature_names[:5],
    row_count=(1, "fixed"),
    col_count=(5, "fixed"),
    datatype="number",
    value=instance_complete[feature_names[:5]].values.tolist(),
)

instance_part_2 = gr.Dataframe(
    label="",
    show_label=False,  # does not work
    headers=feature_names[5:9],
    row_count=(1, "fixed"),
    col_count=(4, "fixed"),
    datatype="number",
    value=instance_complete[feature_names[5:9]].values.tolist(),
)

instance_part_3 = gr.Dataframe(
    label="",
    show_label=False,  # does not work
    headers=feature_names[9:],
    row_count=(1, "fixed"),
    col_count=(4, "fixed"),
    datatype="number",
    value=instance_complete[feature_names[9:]].values.tolist(),
)

demo = gr.Interface(
    fn=plot_explanation,
    inputs=[instance_part_1, instance_part_2, instance_part_3, sigma],
    outputs=["label", "plot"],
)

demo.launch()