# Create the explainer and SHAP values, then serve the "Boba Leaderboard"
# Streamlit page: an interactive leaderboard table plus a SHAP force plot for
# the currently selected member.
import pickle

import joblib
import pandas as pd
import shap
import streamlit as st
import streamlit.components.v1 as components
import xgboost
from st_aggrid import AgGrid, GridOptionsBuilder
from st_aggrid.shared import GridUpdateMode


def st_shap(plot, height=None):
    """Embed a SHAP plot in the Streamlit page as raw HTML.

    Args:
        plot: A SHAP plot object exposing ``.html()`` (e.g. the result of
            ``shap.force_plot``).
        height: Optional height forwarded to ``components.html``.
    """
    # The SHAP JavaScript bundle must precede the plot markup so the
    # visualisation can render inside the embedded HTML component.
    shap_html = f"{shap.getjs()}{plot.html()}"
    components.html(shap_html, height=height)


def scale_and_standardize(df):
    """Standardize ``df`` with a ``StandardScaler`` fitted on ``df`` itself.

    Args:
        df: The feature table to standardize.

    Returns:
        The scaler's transformed output for ``df``.
    """
    # Imported locally, as in the original script, so scikit-learn is only
    # required when this function is actually called.
    from sklearn.preprocessing import StandardScaler

    return StandardScaler().fit_transform(df)


def aggrid_interactive_table(df: pd.DataFrame):
    """Create an st-aggrid interactive table based on a dataframe.

    Args:
        df (pd.DataFrame): Source dataframe.

    Returns:
        The value returned by ``AgGrid``; the caller reads the selected row
        from its ``"selected_rows"`` entry.
    """
    options = GridOptionsBuilder.from_dataframe(
        df, enableRowGroup=True, enableValue=True, enablePivot=True
    )
    options.configure_side_bar()
    options.configure_selection("single")

    selection = AgGrid(
        df,
        enable_enterprise_modules=True,
        gridOptions=options.build(),
        theme="light",
        update_mode=GridUpdateMode.MODEL_CHANGED,
        allow_unsafe_jscode=True,
    )
    return selection


st.title("Boba Leaderboard")

# Load the per-member feature matrix.
# NOTE(security): pickle can execute arbitrary code on load; only use files
# produced by a trusted pipeline.
with open('df_M_character.pickle', 'rb') as feature_file:
    df_M_character = pickle.load(feature_file)

# NOTE: the explainer and SHAP values are recomputed below on every run; an
# earlier revision loaded them from 'explainer.pickle' / 'shap_values.pickle'.
df_M_character_scale = scale_and_standardize(df_M_character)

# Map each member name (the feature matrix index) to its row position.
author_idx_lookup = {
    name: idx for idx, name in enumerate(df_M_character.index)
}

ppl_labels = joblib.load('ppl_labels.joblib')

# Train the XGBoost model on the standardized features.
X = pd.DataFrame(df_M_character_scale, columns=df_M_character.columns)
y = ppl_labels
bst = xgboost.train(
    {"learning_rate": 0.01},
    xgboost.DMatrix(X, label=y),
    num_boost_round=100,
)

# Explain the model's predictions using SHAP values.
explainer = shap.TreeExplainer(bst)
shap_values = explainer.shap_values(X)

df_leaderboard = pd.read_csv("leaderboard.csv").rename(
    columns={'Unnamed: 0': 'Rank', 'author': 'member'}
)

selection = aggrid_interactive_table(df=df_leaderboard)

if selection:
    try:
        auth = selection["selected_rows"][0]['member']
    except (KeyError, IndexError, TypeError):
        # No usable row is selected (e.g. on first render): fall back to the
        # first member of the leaderboard.
        auth = df_leaderboard['member'].iloc[0]

    st.write(auth)

    author_idx = author_idx_lookup.get(auth)
    if author_idx is None:
        # The member appears in the leaderboard but not in the feature
        # matrix; indexing with None would produce a wrong or failing plot.
        st.warning(f"No SHAP explanation available for {auth!r}.")
    else:
        # Visualize the selected member's prediction explanation with the
        # default colors.
        st_shap(
            shap.force_plot(
                explainer.expected_value,
                shap_values[author_idx, :],
                X.iloc[author_idx, :],
            )
        )