"""Streamlit app that explains Hugging Face text-classification predictions with SHAP.

The sidebar selects a sequence-classification model by name; the main page
takes a text input, shows the per-label scores and timings, a SHAP bar plot
for the top-scoring label, and SHAP's interactive text plot.
"""
import time

import shap
import streamlit as st
import streamlit.components.v1 as components
from PIL import Image  # noqa: F401  (kept from the original script; currently unused)
from transformers import (AutoModelForSequenceClassification, AutoTokenizer,
                          pipeline)

# Best-effort: silences the global-pyplot deprecation warning on older Streamlit
# releases. Newer releases reject this option, so a failure here is ignored on
# purpose -- the bar plot below passes an explicit Figure whenever it can.
try:
    st.set_option('deprecation.showPyplotGlobalUse', False)
except Exception:  # noqa: BLE001 - deliberate best-effort config tweak
    pass

# Pixel size of the embedded interactive SHAP text plot.
output_width = 800
output_height = 300
# Forwarded to shap.Explainer(rescale_to_logits=...).
rescale_logits = False

st.set_page_config(page_title='Text Classification with Shap')
st.title('Interpreting HF Pipeline Text Classification with Shap')

# Sidebar form: the model name is only applied when "Submit" is pressed.
form = st.sidebar.form("Model Selection")
form.header('Model Selection')
model_name = form.text_input(
    "Enter the name of the text classification LLM (note: model must be fine-tuned on a text classification task)",
    value="Hate-speech-CNERG/bert-base-uncased-hatexplain",
)
form.form_submit_button("Submit")


@st.cache_resource
def load_model(model_name):
    """Load the tokenizer and sequence-classification model for ``model_name``.

    Cached with ``st.cache_resource`` so the same objects are reused across
    reruns instead of being serialized and copied as ``st.cache_data`` would.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
    model = AutoModelForSequenceClassification.from_pretrained(model_name)
    return tokenizer, model


try:
    tokenizer, model = load_model(model_name)
except (OSError, ValueError) as err:
    # Unknown/invalid model names surface here; report instead of a traceback.
    st.error(f"Could not load model '{model_name}': {err}")
    st.stop()

# top_k=None asks the pipeline for a score for every label, not just the best one.
pred = pipeline("text-classification", model=model, tokenizer=tokenizer, top_k=None)
explainer = shap.Explainer(pred, rescale_to_logits=rescale_logits)

col1, col2, col3 = st.columns(3)
text = col1.text_area("Enter text input", value="Classify me.")
if not text.strip():
    col1.warning("Please enter some text to classify.")
    st.stop()

start_time = time.perf_counter()
result = pred(text)
inference_time = time.perf_counter() - start_time
col3.write('')
col3.write(f'**Inference Time:** {inference_time: .4f}')

# For a single string the pipeline is expected to return [[{label, score}, ...]];
# also accept a flat [{label, score}, ...] (shape varies by transformers version -- verify).
label_scores = result[0] if isinstance(result[0], list) else result
# Pick the best label explicitly rather than relying on the pipeline's ordering.
top_pred = max(label_scores, key=lambda entry: entry['score'])['label']

col2.write('')
for label in label_scores:
    col2.write(f'**{label["label"]}**: {label["score"]: .2f}')

shap_values = explainer([text])
explanation_time = shap_values.compute_time
col3.write('')
col3.write(f'**Explanation Time:** {explanation_time: .4f}')

force_plot = shap.plots.text(shap_values, display=False)
bar_ax = shap.plots.bar(shap_values[0, :, top_pred],
                        order=shap.Explanation.argsort.flip, show=False)
# Depending on the shap version, show=False returns the Axes or None; pass the
# owning Figure when available so st.pyplot renders exactly this plot.
bar_fig = getattr(bar_ax, 'figure', None)

st.markdown(f'#### Shap Bar Plot for {top_pred} Prediction')
st.pyplot(bar_fig, clear_figure=True)

st.markdown('#### Shap Interactive Force Plot')
components.html(force_plot, height=output_height, width=output_width, scrolling=True)