Hellisotherpeople commited on
Commit
1dce944
1 Parent(s): 05e4c27

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +149 -0
app.py CHANGED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import pandas as pd
3
+ import shap
4
+ import streamlit as st
5
+ import streamlit.components.v1 as components
6
+ import torch
7
+ from datasets import load_dataset
8
+ from transformers import (AutoModelForCausalLM, AutoModelForQuestionAnswering,
9
+ AutoModelForSeq2SeqLM,
10
+ AutoModelForSequenceClassification, AutoTokenizer,
11
+ pipeline)
12
+
13
+ st.set_page_config(page_title="HF-SHAP")
14
+ st.title("HF-SHAP: A front end for SHAP")
15
+ st.caption("By Allen Roush")
16
+ st.caption("github: https://github.com/Hellisotherpeople")
17
+ st.caption("Linkedin: https://www.linkedin.com/in/allen-roush-27721011b/")
18
+ st.title("SHAP (SHapley Additive exPlanations)")
19
+ st.image("https://shap.readthedocs.io/en/latest/_images/shap_header.png", width = 700)
20
+ st.caption("By Lundberg, Scott M and Lee, Su-In")
21
+ st.caption("Full Citation: https://raw.githubusercontent.com/slundberg/shap/master/docs/references/shap_nips.bib")
22
+ st.caption("See on github:: https://github.com/slundberg/shap")
23
+
24
+
25
+ form = st.sidebar.form("Main Settings")
26
+
27
+ form.header("Main Settings")
28
+
29
+
30
+
31
+ task_done = form.selectbox("Which NLP task do you want to solve?", ["Text Generation", "Sentiment Analysis", "Translation", "Summarization"])
32
+
33
+
34
+
35
+
36
+
37
+ custom_doc = form.checkbox("Use a document from an existing dataset?", value = False)
38
+ if custom_doc:
39
+ dataset_name = form.text_area("Enter the name of the huggingface Dataset to do analysis of:", value = "Hellisotherpeople/DebateSum")
40
+ dataset_name_2 = form.text_area("Enter the name of the config for the dataset if it has one", value = "")
41
+ split_name = form.text_area("Enter the name of the split of the dataset that you want to use", value = "train")
42
+ number_of_records = form.number_input("Enter the number of documents that you want to analyze from the dataset", value = 200)
43
+ column_name = form.text_area("Enter the name of the column that we are doing analysis on (the X value)", value = "Full-Document")
44
+ index_to_analyze_start = form.number_input("Enter the index start of the document that you want to analyze of the dataset", value = 1)
45
+ index_to_analyze_end = form.number_input("Enter the index end of the document that you want to analyze of the dataset", value = 2)
46
+ form.caption("Multiple documents may not work on certain tasks")
47
+ else:
48
+ doc = st.text_area("Enter a custom document", value = "This is an example custom document")
49
+
50
+
51
+
52
+ if task_done == "Text Generation":
53
+ model_name = form.text_area("Enter the name of the pre-trained model from transformers that we are using for Text Generation", value = "gpt2")
54
+ form.caption("This will download a new model, so it may take awhile or even break if the model is too large")
55
+ decoder = form.checkbox("Is this a decoder model?", value = True)
56
+ form.caption("This should be true for models like GPT-2, and false for models like BERT")
57
+ max_length = form.number_input("What's the max length of the text?", value = 50)
58
+ min_length = form.number_input("What's the min length of the text?", value = 20, max_value = max_length)
59
+ penalize_repetion = form.number_input("How strongly do we want to penalize repetition in the text generation?", value = 2)
60
+ sample = form.checkbox("Shall we use top-k and top-p decoding?", value = True)
61
+ form.caption("Setting this to false makes it greedy")
62
+ if sample:
63
+ top_k = form.number_input("What value of K should we use for Top-K sampling? Set to zero to disable", value = 50)
64
+ form.caption("In Top-K sampling, the K most likely next words are filtered and the probability mass is redistributed among only those K next words. ")
65
+ top_p = form.number_input("What value of P should we use for Top-p sampling? Set to zero to disable", value = 0.95, max_value = 1.0, min_value = 0.0)
66
+ form.caption("Top-p sampling chooses from the smallest possible set of words whose cumulative probability exceeds the probability p. The probability mass is then redistributed among this set of words.")
67
+ temperature = form.number_input("How spicy/interesting do we want our models output to be", value = 1.05, min_value = 0.0)
68
+ form.caption("Setting this higher decreases the likelihood of high probability words and increases the likelihood of low probability (and presumably more interesting) words")
69
+ form.caption("For more details on what these settings mean, see here: https://huggingface.co/blog/how-to-generate")
70
+
71
+
72
+ elif task_done == "Sentiment Analysis":
73
+ model_name = form.text_area("Enter the name of the pre-trained model from transformers that we are using for Sentiment Analysis", value = "nateraw/bert-base-uncased-emotion")
74
+ rescale_logits = form.checkbox("Do we rescale the probabilities in terms of log odds?", value = False)
75
+ elif task_done == "Translation":
76
+ model_name = form.text_area("Enter the name of the pre-trained model from transformers that we are using for Translation", value = "Helsinki-NLP/opus-mt-en-es")
77
+ elif task_done == "Summarization":
78
+ model_name = form.text_area("Enter the name of the pre-trained model from transformers that we are using for Translation", value = "sshleifer/distilbart-xsum-12-6")
79
+ else:
80
+ model_name = form.text_area("Enter the name of the pre-trained model from transformers that we are using for Question Answering", value = "deepset/roberta-base-squad2")
81
+
82
+ form.header("Model Explanation Display Settings")
83
+ output_width = form.number_input("Enter the number of pixels for width of model explanation html display", value = 800)
84
+ output_height = form.number_input("Enter the number of pixels for height of model explanation html display", value = 2000)
85
+ form.form_submit_button("Submit")
86
+
87
+ @st.cache
88
+ def load_and_process_data(path, name, streaming, split_name, number_of_records):
89
+ dataset = load_dataset(path = path, name = name, streaming=streaming)
90
+ #return list(dataset)
91
+ dataset_head = dataset[split_name].take(number_of_records)
92
+ df = pd.DataFrame.from_dict(dataset_head)
93
+ return df[column_name]
94
+
95
+
96
+
97
+ @st.cache(allow_output_mutation=True)
98
+ def load_model(model_name):
99
+ tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
100
+ if task_done == "Text Generation":
101
+ model = AutoModelForCausalLM.from_pretrained(model_name)
102
+ model.config.is_decoder=decoder
103
+ if sample == True:
104
+ model.config.task_specific_params["text-generation"] = {"do_sample": sample, "max_length": max_length, "min_length": min_length, "temperature": temperature, "top_k": top_k, "top_p" : top_p, "no_repeat_ngram_size": penalize_repetion}
105
+ else:
106
+ model.config.task_specific_params["text-generation"] = {"do_sample": sample, "max_length": max_length, "min_length": min_length, "no_repeat_ngram_size": penalize_repetion}
107
+
108
+ elif task_done == "Sentiment Analysis":
109
+ model = AutoModelForSequenceClassification.from_pretrained(model_name)
110
+ elif task_done == "Translation":
111
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
112
+ elif task_done == "Summarization":
113
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
114
+ elif task_done == "Question Answering":
115
+ #TODO: This one is going to be harder...
116
+ # https://shap.readthedocs.io/en/latest/example_notebooks/text_examples/question_answering/Explaining%20a%20Question%20Answering%20Transformers%20Model.html
117
+ model = AutoModelForQuestionAnswering.from_pretrained(model_name)
118
+
119
+ return tokenizer, model
120
+
121
+ tokenizer, model = load_model(model_name)
122
+
123
+
124
+
125
+
126
+
127
+
128
+ if custom_doc:
129
+ df = load_and_process_data(dataset_name, dataset_name_2, True, split_name, number_of_records)
130
+ doc = list(df[index_to_analyze_start:index_to_analyze_end])
131
+ st.write(doc)
132
+
133
+ if task_done == "Sentiment Analysis":
134
+ pred = pipeline("text-classification", model=model, tokenizer=tokenizer, return_all_scores=True)
135
+ explainer = shap.Explainer(pred, rescale_to_logits = rescale_logits)
136
+ else:
137
+ explainer = shap.Explainer(model, tokenizer)
138
+
139
+ if custom_doc:
140
+ shap_values = explainer(doc)
141
+ else:
142
+ shap_values = explainer([doc])
143
+
144
+
145
+ ###REMEMBER YOU FIXED THE CODE IN SHAP, YOU NEED TO FORK IT
146
+
147
+ the_plot = shap.plots.text(shap_values, display = False)
148
+ components.html(the_plot, height = output_width, width = output_height)
149
+ st.caption("Scroll to see the full output!")