# HF-SHAP / app.py
# Author: Hellisotherpeople
# Commit: 05559fa ("Update app.py")
import subprocess
import sys
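## subprocess and sys are imported to support installing a package from disk, because the normal requirements.txt path for installing a package from disk doesn't work on HF Spaces. Thank you to Omar Sanseviero for the help!
## NOTE(review): the install code this originally referred to ("Lines 1-8") is not present in this file, so these two imports look unused here — confirm before removing.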
import numpy as np
import pandas as pd
import shap
import streamlit as st
import streamlit.components.v1 as components
from datasets import load_dataset
from transformers import (AutoModelForCausalLM, AutoModelForQuestionAnswering,
AutoModelForSeq2SeqLM,
AutoModelForSequenceClassification, AutoTokenizer,
pipeline)
# --- Page header: app title, author links, and SHAP attribution ---
# set_page_config is issued before any other st.* call in this file.
st.set_page_config(page_title="HF-SHAP")
st.title("HF-SHAP: A front end for SHAP")
st.caption("By Allen Roush")
st.caption("github: https://github.com/Hellisotherpeople")
st.caption("Linkedin: https://www.linkedin.com/in/allen-roush-27721011b/")

# Credit for the underlying SHAP library.
st.title("SHAP (SHapley Additive exPlanations)")
st.image("https://shap.readthedocs.io/en/latest/_images/shap_header.png", width=700)
st.caption("By Lundberg, Scott M and Lee, Su-In")
st.caption("Slightly modified by Allen Roush to fix a bug with text plotting not working outside of Jupyter Notebooks")
st.caption("Full Citation: https://raw.githubusercontent.com/slundberg/shap/master/docs/references/shap_nips.bib")
st.caption("See on github: https://github.com/slundberg/shap")
st.caption("More details of how SHAP works: https://christophm.github.io/interpretable-ml-book/shap.html")
# --- Sidebar settings form ---
# Every setting is a widget on one sidebar form ending in a single "Submit"
# button. The variables assigned here are read as module-level globals by
# load_model, load_and_process_data and the run section, so their names are
# part of the file's contract. Task- or option-specific variables (e.g.
# rescale_logits, top_k, temperature) are only assigned when that task/option
# is selected, and are only read under the same condition.
form = st.sidebar.form("Main Settings")
form.header("Main Settings")
task_done = form.selectbox(
    "Which NLP task do you want to solve?",
    ["Text Generation", "Sentiment Analysis", "Translation", "Summarization"],
)

custom_doc = form.checkbox("Use a document from an existing dataset?", value=False)
if custom_doc:
    # Which Hugging Face dataset (and which slice of it) to explain.
    dataset_name = form.text_area("Enter the name of the huggingface Dataset to do analysis of:", value="Hellisotherpeople/DebateSum")
    # Config name; left empty when the dataset has no configs.
    dataset_name_2 = form.text_area("Enter the name of the config for the dataset if it has one", value="")
    split_name = form.text_area("Enter the name of the split of the dataset that you want to use", value="train")
    number_of_records = form.number_input("Enter the number of documents that you want to analyze from the dataset", value=200)
    column_name = form.text_area("Enter the name of the column that we are doing analysis on (the X value)", value="Full-Document")
    # The run section explains the loaded records in the half-open range [start, end).
    index_to_analyze_start = form.number_input("Enter the index start of the document that you want to analyze of the dataset", value=1)
    index_to_analyze_end = form.number_input("Enter the index end of the document that you want to analyze of the dataset", value=2)
    form.caption("Multiple documents may not work on certain tasks")
else:
    # The free-text document box is placed on the main page (st), not in the sidebar form.
    doc = st.text_area("Enter a custom document", value="This is an example custom document")

if task_done == "Text Generation":
    model_name = form.text_area("Enter the name of the pre-trained model from transformers that we are using for Text Generation", value="gpt2")
    form.caption("This will download a new model, so it may take awhile or even break if the model is too large")
    decoder = form.checkbox("Is this a decoder model?", value=True)
    form.caption("This should be true for models like GPT-2, and false for models like BERT")
    max_length = form.number_input("What's the max length of the text?", value=50)
    # Clamp the default so a max_length below 20 does not make value > max_value
    # (which Streamlit rejects with an exception).
    min_length = form.number_input("What's the min length of the text?", value=min(20, max_length), max_value=max_length)
    # load_model passes this value to generation as no_repeat_ngram_size.
    penalize_repetion = form.number_input("How strongly do we want to penalize repetition in the text generation?", value=2, min_value=0)
    form.caption("This is used as no_repeat_ngram_size: any n-gram of this many tokens can occur at most once in the generated text (0 disables it)")
    sample = form.checkbox("Shall we use top-k and top-p decoding?", value=True)
    form.caption("Setting this to false makes it greedy")
    if sample:
        # Sampling-only knobs; they are left undefined for greedy decoding.
        top_k = form.number_input("What value of K should we use for Top-K sampling? Set to zero to disable", value=50)
        form.caption("In Top-K sampling, the K most likely next words are filtered and the probability mass is redistributed among only those K next words. ")
        top_p = form.number_input("What value of P should we use for Top-p sampling? Set to zero to disable", value=0.95, max_value=1.0, min_value=0.0)
        form.caption("Top-p sampling chooses from the smallest possible set of words whose cumulative probability exceeds the probability p. The probability mass is then redistributed among this set of words.")
        temperature = form.number_input("How spicy/interesting do we want our models output to be", value=1.05, min_value=0.0)
        form.caption("Setting this higher decreases the likelihood of high probability words and increases the likelihood of low probability (and presumably more interesting) words")
        form.caption("For more details on what these settings mean, see here: https://huggingface.co/blog/how-to-generate")
elif task_done == "Sentiment Analysis":
    model_name = form.text_area("Enter the name of the pre-trained model from transformers that we are using for Sentiment Analysis", value="nateraw/bert-base-uncased-emotion")
    rescale_logits = form.checkbox("Do we rescale the probabilities in terms of log odds?", value=False)
elif task_done == "Translation":
    model_name = form.text_area("Enter the name of the pre-trained model from transformers that we are using for Translation", value="Helsinki-NLP/opus-mt-en-es")
elif task_done == "Summarization":
    model_name = form.text_area("Enter the name of the pre-trained model from transformers that we are using for Summarization", value="sshleifer/distilbart-xsum-12-1")
else:
    # NOTE(review): currently unreachable — "Question Answering" is not one of the
    # selectbox options above; kept for the unfinished QA support in load_model.
    model_name = form.text_area("Enter the name of the pre-trained model from transformers that we are using for Question Answering", value="deepset/roberta-base-squad2")

form.header("Model Explanation Display Settings")
output_width = form.number_input("Enter the number of pixels for width of model explanation html display", value=800)
output_height = form.number_input("Enter the number of pixels for height of model explanation html display", value=1000)
form.form_submit_button("Submit")
@st.cache
def load_and_process_data(path, name, streaming, split_name, number_of_records):
    """Load the first records of a Hugging Face dataset split and return one column.

    Args:
        path: Dataset name or path, forwarded to ``datasets.load_dataset``.
        name: Dataset config name. An empty string (the sidebar's default,
            meaning "no config") is passed to ``load_dataset`` as ``None``.
        streaming: Forwarded to ``load_dataset``. The only caller in this
            file passes ``True``.
        split_name: Name of the split to read, e.g. ``"train"``.
        number_of_records: How many records to take from the start of the split.

    Returns:
        A pandas Series holding the column named by the module-level
        ``column_name`` setting for the taken records.

    Raises:
        KeyError: if ``split_name`` or ``column_name`` does not exist.
    """
    # An empty config string is not a config name; None lets load_dataset use its default.
    dataset = load_dataset(path=path, name=name or None, streaming=streaming)
    dataset_head = dataset[split_name].take(number_of_records)
    # take() returns an iterable of row records; materialize it explicitly
    # instead of handing a non-dict to DataFrame.from_dict.
    df = pd.DataFrame(list(dataset_head))
    return df[column_name]
@st.cache(allow_output_mutation=True)
def load_model(model_name):
    """Load the tokenizer and the task-appropriate model for ``model_name``.

    The model class is chosen from the module-level ``task_done`` setting. For
    "Text Generation", the sidebar's generation settings (``decoder``,
    ``max_length``, ``min_length``, ``penalize_repetion``, ``sample`` and, when
    sampling, ``top_k``/``top_p``/``temperature``) are written onto the model
    config.

    Args:
        model_name: Hugging Face model id or local path.

    Returns:
        A ``(tokenizer, model)`` tuple.

    Raises:
        ValueError: if ``task_done`` is not a supported task.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
    if task_done == "Text Generation":
        model = AutoModelForCausalLM.from_pretrained(model_name)
        model.config.is_decoder = decoder
        generation_params = {
            "do_sample": sample,
            "max_length": max_length,
            "min_length": min_length,
            # The sidebar's "penalize repetition" value is applied as the size of
            # n-grams that may not repeat.
            "no_repeat_ngram_size": penalize_repetion,
        }
        if sample:
            # These settings are only defined by the sidebar when sampling is enabled.
            generation_params.update(temperature=temperature, top_k=top_k, top_p=top_p)
        # Configs without task-specific defaults carry None here; create the
        # mapping before writing into it.
        if model.config.task_specific_params is None:
            model.config.task_specific_params = {}
        model.config.task_specific_params["text-generation"] = generation_params
    elif task_done == "Sentiment Analysis":
        model = AutoModelForSequenceClassification.from_pretrained(model_name)
    elif task_done in ("Translation", "Summarization"):
        # Both tasks use an encoder-decoder (seq2seq) model.
        model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    elif task_done == "Question Answering":
        #TODO: This one is going to be harder...
        # https://shap.readthedocs.io/en/latest/example_notebooks/text_examples/question_answering/Explaining%20a%20Question%20Answering%20Transformers%20Model.html
        model = AutoModelForQuestionAnswering.from_pretrained(model_name)
    else:
        # Fail loudly instead of hitting an UnboundLocalError on `model` below.
        raise ValueError(f"Unsupported task: {task_done!r}")
    return tokenizer, model
# --- Run: load the model, gather the document(s), explain them with SHAP, render ---
tokenizer, model = load_model(model_name)

if custom_doc:
    df = load_and_process_data(dataset_name, dataset_name_2, True, split_name, number_of_records)
    # Explicitly positional, half-open [start, end) slice of the loaded records.
    doc = list(df.iloc[index_to_analyze_start:index_to_analyze_end])
    if not doc:
        # An empty batch would only fail later, deep inside SHAP.
        st.error("No documents selected: the index end must be greater than the index start and within the number of documents loaded.")
        st.stop()
    st.write(doc)

if task_done == "Sentiment Analysis":
    # Wrap the classifier in a pipeline that returns the scores of every label.
    pred = pipeline("text-classification", model=model, tokenizer=tokenizer, return_all_scores=True)
    explainer = shap.Explainer(pred, rescale_to_logits=rescale_logits)
else:
    explainer = shap.Explainer(model, tokenizer)

# The explainer is called on a list of texts: the dataset slice already is one,
# while a single custom document is wrapped in a one-element list.
shap_values = explainer(doc if custom_doc else [doc])
the_plot = shap.plots.text(shap_values, display=False)
st.caption("The plot is interactive! Try Hovering over or clicking on the input or output text")
components.html(the_plot, height=output_height, width=output_width, scrolling=True)