# Boba / app.py
# asharma567
# asda
# d7e7aed
# raw
# history blame contribute delete
# No virus
# 2.78 kB
#create the explainer and shap values
import pandas as pd
import streamlit as st
from st_aggrid import AgGrid, GridOptionsBuilder
from st_aggrid.shared import GridUpdateMode
import shap
import streamlit as st
import streamlit.components.v1 as components
import pickle
import xgboost
import joblib
def st_shap(plot, height=None):
    """Embed a SHAP plot in the Streamlit page as a raw HTML component.

    Args:
        plot: Plot object exposing an ``html()`` method (the script passes
            the result of ``shap.force_plot``).
        height: Forwarded unchanged to ``components.html``.
    """
    # The output of shap.getjs() goes in <head>; the plot markup goes in <body>.
    head = f"<head>{shap.getjs()}</head>"
    body = f"<body>{plot.html()}</body>"
    components.html(head + body, height=height)
def scale_and_standardize(df):
    """Standardize ``df`` with a freshly fitted scikit-learn ``StandardScaler``.

    Args:
        df: Data accepted by ``StandardScaler.fit`` (the script passes a
            pandas DataFrame).

    Returns:
        Whatever ``StandardScaler.transform`` returns for ``df``. The scaler
        is fitted on ``df`` itself and is not kept afterwards.
    """
    # Function-scope import: scikit-learn is only loaded when this is called.
    from sklearn.preprocessing import StandardScaler

    # fit() returns the fitted scaler, so the transform can be chained on it.
    return StandardScaler().fit(df).transform(df)
def aggrid_interactive_table(df: pd.DataFrame):
    """Render ``df`` as an interactive st-aggrid table with single-row selection.

    Args:
        df (pd.DataFrame): Source dataframe to display.

    Returns:
        The object returned by ``AgGrid``; the caller reads the current
        selection from its ``"selected_rows"`` entry.
    """
    builder = GridOptionsBuilder.from_dataframe(
        df,
        enableRowGroup=True,
        enableValue=True,
        enablePivot=True,
    )
    builder.configure_side_bar()
    # Only one row may be selected at a time.
    builder.configure_selection("single")
    grid_options = builder.build()

    return AgGrid(
        df,
        enable_enterprise_modules=True,
        gridOptions=grid_options,
        theme="light",
        update_mode=GridUpdateMode.MODEL_CHANGED,
        allow_unsafe_jscode=True,
    )
st.title("Boba Leaderboard")

# Load the feature matrix. Its index labels are used below as the names that
# leaderboard members are looked up by.
# NOTE(review): pickle can execute arbitrary code on load; only ship trusted files.
with open('df_M_character.pickle', 'rb') as feature_file:
    # Context manager closes the handle (the original leaked it).
    df_M_character = pickle.load(feature_file)
# explainer = pickle.load(open('explainer.pickle', 'rb'))
# shap_values = pickle.load(open('shap_values.pickle', 'rb'))
df_M_character_scale = scale_and_standardize(df_M_character)

# Map each index label to its row position so a selected name can be turned
# into a positional index into the scaled features / SHAP values.
author_idx_lookup = {name: idx for idx, name in enumerate(df_M_character.index)}

# Training labels for the model; assumed to align row-for-row with
# df_M_character — TODO confirm.
ppl_labels = joblib.load('ppl_labels.joblib')
# Train an XGBoost model on the standardized features.
# X keeps the original column names so SHAP plots are labelled by feature.
X = pd.DataFrame(df_M_character_scale, columns=df_M_character.columns)
y = ppl_labels
bst = xgboost.train(
    {"learning_rate": 0.01},
    xgboost.DMatrix(X, label=y),
    num_boost_round=100,
)

# Explain the model's predictions using SHAP values (one row per row of X).
explainer = shap.TreeExplainer(bst)
shap_values = explainer.shap_values(X)
# shap.summary_plot(shap_values, X)
# X is reused unchanged for the force plot below; the original rebuilt an
# identical DataFrame here, which was redundant.
# Load the leaderboard and give the unnamed index column / author column
# their display names.
df_leaderboard = pd.read_csv("leaderboard.csv")
df_leaderboard.rename(columns = {'Unnamed: 0':'Rank', 'author':'member'}, inplace=True)
selection = aggrid_interactive_table(df=df_leaderboard)

if selection:
    try:
        auth = selection["selected_rows"][0]['member']
    except (IndexError, KeyError, TypeError):
        # Nothing usable is selected (empty list, None, or a container that
        # cannot be indexed with 0): fall back to the first leaderboard row.
        auth = df_leaderboard['member'].iloc[0]
    st.write(auth)

    author_idx = author_idx_lookup.get(auth)
    if author_idx is None:
        # The member has no row in the feature matrix. Indexing with None
        # would silently add an axis to shap_values and break X.iloc, so
        # report it instead of plotting.
        st.warning(f"No SHAP explanation available for {auth!r}.")
    else:
        # Visualize the selected member's prediction explanation with default colors.
        st_shap(
            shap.force_plot(
                explainer.expected_value,
                shap_values[author_idx, :],
                X.iloc[author_idx, :],
            )
        )