# Author: daniel-de-leon
# Commit: bb78cda
# Message: Add inference and explanation run time
import html
import time

import shap
import streamlit as st
import streamlit.components.v1 as components
from PIL import Image
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    pipeline,
)
# Turn off Streamlit's "pyplot global use" deprecation warning.
st.set_option('deprecation.showPyplotGlobalUse', False)

# Width/height (px) of the iframe that embeds the SHAP force plot.
output_width, output_height = 800, 300
# Forwarded to shap.Explainer as rescale_to_logits.
rescale_logits = False

st.set_page_config(page_title='Text Classification with Shap')
st.title('Interpreting HF Pipeline Text Classification with Shap')

# Sidebar form for choosing the Hugging Face model. Widgets inside a
# st.form only send their values when the submit button is pressed.
_MODEL_PROMPT = (
    "Enter the name of the text classification LLM "
    "(note: model must be fine-tuned on a text classification task)"
)
_DEFAULT_MODEL = "Hate-speech-CNERG/bert-base-uncased-hatexplain"

form = st.sidebar.form("Model Selection")
form.header('Model Selection')
model_name = form.text_input(_MODEL_PROMPT, value=_DEFAULT_MODEL)
form.form_submit_button("Submit")
@st.cache_resource
def load_model(model_name):
    """Load the tokenizer and sequence-classification model for ``model_name``.

    Cached with ``st.cache_resource`` so each model is loaded once and the
    same objects are reused across Streamlit reruns. ``st.cache_data`` (used
    previously) stores the return value pickled and hands every caller a fresh
    copy, which means re-serialising the whole model on each rerun.

    Args:
        model_name: Hugging Face Hub id or local path of a model fine-tuned
            for text classification.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
    model = AutoModelForSequenceClassification.from_pretrained(model_name)
    return tokenizer, model
# Fetch the (cached) tokenizer/model pair for the selected checkpoint.
loaded = load_model(model_name)
tokenizer, model = loaded

# top_k=None asks the pipeline for a score for every label, not just the
# best one; the SHAP explainer wraps that pipeline directly.
pipeline_kwargs = dict(model=model, tokenizer=tokenizer, top_k=None)
pred = pipeline("text-classification", **pipeline_kwargs)
explainer = shap.Explainer(pred, rescale_to_logits=rescale_logits)
# Three-column layout: input text | per-label scores | timings.
col1, col2, col3 = st.columns(3)
text = col1.text_area("Enter text input", value="Classify me.")

# Blank input has nothing to classify or explain; end this script run before
# the pipeline and the SHAP explainer are called on it.
if not text.strip():
    st.info("Enter some text to classify.")
    st.stop()

# perf_counter is a monotonic, high-resolution clock intended for measuring
# durations; time.time() is wall-clock time and can jump (e.g. NTP updates).
start_time = time.perf_counter()
result = pred(text)
inference_time = time.perf_counter() - start_time
col3.write('')  # spacer to line up with the text area
col3.write(f'**Inference Time:** {inference_time: .4f} s')

# result[0] is the list of {"label", "score"} dicts for this input. Choose
# the top label by score explicitly instead of relying on output ordering.
scores = result[0]
top_pred = max(scores, key=lambda entry: entry['score'])['label']
col2.write('')  # spacer to line up with the text area
for entry in scores:
    col2.write(f'**{entry["label"]}**: {entry["score"]: .2f}')

# Explain a batch of one input; shap reports the call's duration itself.
shap_values = explainer([text])
explanation_time = shap_values.compute_time
col3.write('')
col3.write(f'**Explanation Time:** {explanation_time: .4f} s')
# Render the SHAP explanation: a bar plot of token contributions to the top
# predicted label, then the interactive text/force plot.
force_plot = shap.plots.text(shap_values, display=False)
bar_plot = shap.plots.bar(shap_values[0, :, top_pred], order=shap.Explanation.argsort.flip, show=False)

# NOTE(review): with show=False, older shap releases return None from
# plots.bar (the plot is on the current global pyplot figure), while newer
# ones return the matplotlib Axes they drew on. st.pyplot needs a Figure (or
# None to fall back to the global figure), so unwrap an Axes to its Figure.
# Confirm against the pinned shap version.
if hasattr(bar_plot, 'savefig'):
    bar_fig = bar_plot
else:
    bar_fig = getattr(bar_plot, 'figure', None)

st.markdown("""
<style>
.big-font {
font-size:35px !important;
}
</style>
""", unsafe_allow_html=True)
# The label comes from the config of whatever model the user named in the
# sidebar and this markdown is rendered as raw HTML, so escape it first.
st.markdown(f'<center><p class="big-font">Shap Bar Plot for <i>{html.escape(top_pred)}</i> Prediction</p></center>', unsafe_allow_html=True)
st.pyplot(bar_fig, clear_figure=True)
st.markdown('<center><p class="big-font">Shap Interactive Force Plot</p></center>', unsafe_allow_html=True)
# Embed the force plot's HTML in a fixed-size, scrollable component iframe.
components.html(force_plot, height=output_height, width=output_width, scrolling=True)