Spaces:
Sleeping
Sleeping
Hellisotherpeople
commited on
Commit
•
1dce944
1
Parent(s):
05e4c27
Update app.py
Browse files
app.py
CHANGED
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import pandas as pd
|
3 |
+
import shap
|
4 |
+
import streamlit as st
|
5 |
+
import streamlit.components.v1 as components
|
6 |
+
import torch
|
7 |
+
from datasets import load_dataset
|
8 |
+
from transformers import (AutoModelForCausalLM, AutoModelForQuestionAnswering,
|
9 |
+
AutoModelForSeq2SeqLM,
|
10 |
+
AutoModelForSequenceClassification, AutoTokenizer,
|
11 |
+
pipeline)
|
12 |
+
|
13 |
+
st.set_page_config(page_title="HF-SHAP")
|
14 |
+
st.title("HF-SHAP: A front end for SHAP")
|
15 |
+
st.caption("By Allen Roush")
|
16 |
+
st.caption("github: https://github.com/Hellisotherpeople")
|
17 |
+
st.caption("Linkedin: https://www.linkedin.com/in/allen-roush-27721011b/")
|
18 |
+
st.title("SHAP (SHapley Additive exPlanations)")
|
19 |
+
st.image("https://shap.readthedocs.io/en/latest/_images/shap_header.png", width = 700)
|
20 |
+
st.caption("By Lundberg, Scott M and Lee, Su-In")
|
21 |
+
st.caption("Full Citation: https://raw.githubusercontent.com/slundberg/shap/master/docs/references/shap_nips.bib")
|
22 |
+
st.caption("See on github:: https://github.com/slundberg/shap")
|
23 |
+
|
24 |
+
|
25 |
+
form = st.sidebar.form("Main Settings")
|
26 |
+
|
27 |
+
form.header("Main Settings")
|
28 |
+
|
29 |
+
|
30 |
+
|
31 |
+
task_done = form.selectbox("Which NLP task do you want to solve?", ["Text Generation", "Sentiment Analysis", "Translation", "Summarization"])
|
32 |
+
|
33 |
+
|
34 |
+
|
35 |
+
|
36 |
+
|
37 |
+
custom_doc = form.checkbox("Use a document from an existing dataset?", value = False)
|
38 |
+
if custom_doc:
|
39 |
+
dataset_name = form.text_area("Enter the name of the huggingface Dataset to do analysis of:", value = "Hellisotherpeople/DebateSum")
|
40 |
+
dataset_name_2 = form.text_area("Enter the name of the config for the dataset if it has one", value = "")
|
41 |
+
split_name = form.text_area("Enter the name of the split of the dataset that you want to use", value = "train")
|
42 |
+
number_of_records = form.number_input("Enter the number of documents that you want to analyze from the dataset", value = 200)
|
43 |
+
column_name = form.text_area("Enter the name of the column that we are doing analysis on (the X value)", value = "Full-Document")
|
44 |
+
index_to_analyze_start = form.number_input("Enter the index start of the document that you want to analyze of the dataset", value = 1)
|
45 |
+
index_to_analyze_end = form.number_input("Enter the index end of the document that you want to analyze of the dataset", value = 2)
|
46 |
+
form.caption("Multiple documents may not work on certain tasks")
|
47 |
+
else:
|
48 |
+
doc = st.text_area("Enter a custom document", value = "This is an example custom document")
|
49 |
+
|
50 |
+
|
51 |
+
|
52 |
+
if task_done == "Text Generation":
|
53 |
+
model_name = form.text_area("Enter the name of the pre-trained model from transformers that we are using for Text Generation", value = "gpt2")
|
54 |
+
form.caption("This will download a new model, so it may take awhile or even break if the model is too large")
|
55 |
+
decoder = form.checkbox("Is this a decoder model?", value = True)
|
56 |
+
form.caption("This should be true for models like GPT-2, and false for models like BERT")
|
57 |
+
max_length = form.number_input("What's the max length of the text?", value = 50)
|
58 |
+
min_length = form.number_input("What's the min length of the text?", value = 20, max_value = max_length)
|
59 |
+
penalize_repetion = form.number_input("How strongly do we want to penalize repetition in the text generation?", value = 2)
|
60 |
+
sample = form.checkbox("Shall we use top-k and top-p decoding?", value = True)
|
61 |
+
form.caption("Setting this to false makes it greedy")
|
62 |
+
if sample:
|
63 |
+
top_k = form.number_input("What value of K should we use for Top-K sampling? Set to zero to disable", value = 50)
|
64 |
+
form.caption("In Top-K sampling, the K most likely next words are filtered and the probability mass is redistributed among only those K next words. ")
|
65 |
+
top_p = form.number_input("What value of P should we use for Top-p sampling? Set to zero to disable", value = 0.95, max_value = 1.0, min_value = 0.0)
|
66 |
+
form.caption("Top-p sampling chooses from the smallest possible set of words whose cumulative probability exceeds the probability p. The probability mass is then redistributed among this set of words.")
|
67 |
+
temperature = form.number_input("How spicy/interesting do we want our models output to be", value = 1.05, min_value = 0.0)
|
68 |
+
form.caption("Setting this higher decreases the likelihood of high probability words and increases the likelihood of low probability (and presumably more interesting) words")
|
69 |
+
form.caption("For more details on what these settings mean, see here: https://huggingface.co/blog/how-to-generate")
|
70 |
+
|
71 |
+
|
72 |
+
elif task_done == "Sentiment Analysis":
|
73 |
+
model_name = form.text_area("Enter the name of the pre-trained model from transformers that we are using for Sentiment Analysis", value = "nateraw/bert-base-uncased-emotion")
|
74 |
+
rescale_logits = form.checkbox("Do we rescale the probabilities in terms of log odds?", value = False)
|
75 |
+
elif task_done == "Translation":
|
76 |
+
model_name = form.text_area("Enter the name of the pre-trained model from transformers that we are using for Translation", value = "Helsinki-NLP/opus-mt-en-es")
|
77 |
+
elif task_done == "Summarization":
|
78 |
+
model_name = form.text_area("Enter the name of the pre-trained model from transformers that we are using for Translation", value = "sshleifer/distilbart-xsum-12-6")
|
79 |
+
else:
|
80 |
+
model_name = form.text_area("Enter the name of the pre-trained model from transformers that we are using for Question Answering", value = "deepset/roberta-base-squad2")
|
81 |
+
|
82 |
+
form.header("Model Explanation Display Settings")
|
83 |
+
output_width = form.number_input("Enter the number of pixels for width of model explanation html display", value = 800)
|
84 |
+
output_height = form.number_input("Enter the number of pixels for height of model explanation html display", value = 2000)
|
85 |
+
form.form_submit_button("Submit")
|
86 |
+
|
87 |
+
@st.cache
|
88 |
+
def load_and_process_data(path, name, streaming, split_name, number_of_records):
|
89 |
+
dataset = load_dataset(path = path, name = name, streaming=streaming)
|
90 |
+
#return list(dataset)
|
91 |
+
dataset_head = dataset[split_name].take(number_of_records)
|
92 |
+
df = pd.DataFrame.from_dict(dataset_head)
|
93 |
+
return df[column_name]
|
94 |
+
|
95 |
+
|
96 |
+
|
97 |
+
@st.cache(allow_output_mutation=True)
|
98 |
+
def load_model(model_name):
|
99 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
|
100 |
+
if task_done == "Text Generation":
|
101 |
+
model = AutoModelForCausalLM.from_pretrained(model_name)
|
102 |
+
model.config.is_decoder=decoder
|
103 |
+
if sample == True:
|
104 |
+
model.config.task_specific_params["text-generation"] = {"do_sample": sample, "max_length": max_length, "min_length": min_length, "temperature": temperature, "top_k": top_k, "top_p" : top_p, "no_repeat_ngram_size": penalize_repetion}
|
105 |
+
else:
|
106 |
+
model.config.task_specific_params["text-generation"] = {"do_sample": sample, "max_length": max_length, "min_length": min_length, "no_repeat_ngram_size": penalize_repetion}
|
107 |
+
|
108 |
+
elif task_done == "Sentiment Analysis":
|
109 |
+
model = AutoModelForSequenceClassification.from_pretrained(model_name)
|
110 |
+
elif task_done == "Translation":
|
111 |
+
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
|
112 |
+
elif task_done == "Summarization":
|
113 |
+
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
|
114 |
+
elif task_done == "Question Answering":
|
115 |
+
#TODO: This one is going to be harder...
|
116 |
+
# https://shap.readthedocs.io/en/latest/example_notebooks/text_examples/question_answering/Explaining%20a%20Question%20Answering%20Transformers%20Model.html
|
117 |
+
model = AutoModelForQuestionAnswering.from_pretrained(model_name)
|
118 |
+
|
119 |
+
return tokenizer, model
|
120 |
+
|
121 |
+
tokenizer, model = load_model(model_name)
|
122 |
+
|
123 |
+
|
124 |
+
|
125 |
+
|
126 |
+
|
127 |
+
|
128 |
+
if custom_doc:
|
129 |
+
df = load_and_process_data(dataset_name, dataset_name_2, True, split_name, number_of_records)
|
130 |
+
doc = list(df[index_to_analyze_start:index_to_analyze_end])
|
131 |
+
st.write(doc)
|
132 |
+
|
133 |
+
if task_done == "Sentiment Analysis":
|
134 |
+
pred = pipeline("text-classification", model=model, tokenizer=tokenizer, return_all_scores=True)
|
135 |
+
explainer = shap.Explainer(pred, rescale_to_logits = rescale_logits)
|
136 |
+
else:
|
137 |
+
explainer = shap.Explainer(model, tokenizer)
|
138 |
+
|
139 |
+
if custom_doc:
|
140 |
+
shap_values = explainer(doc)
|
141 |
+
else:
|
142 |
+
shap_values = explainer([doc])
|
143 |
+
|
144 |
+
|
145 |
+
###REMEMBER YOU FIXED THE CODE IN SHAP, YOU NEED TO FORK IT
|
146 |
+
|
147 |
+
the_plot = shap.plots.text(shap_values, display = False)
|
148 |
+
components.html(the_plot, height = output_width, width = output_height)
|
149 |
+
st.caption("Scroll to see the full output!")
|