daniel-de-leon committed on
Commit
87f0205
1 Parent(s): 2b512a1
Files changed (3) hide show
  1. Dockerfile +1 -1
  2. app.py +62 -41
  3. requirements.txt +4 -1
Dockerfile CHANGED
@@ -1,4 +1,4 @@
1
- FROM python:3.8.9
2
 
3
  WORKDIR /app
4
 
 
1
+ FROM python:3.9
2
 
3
  WORKDIR /app
4
 
app.py CHANGED
@@ -1,42 +1,63 @@
1
  import streamlit as st
2
- import pandas as pd
3
- import numpy as np
4
-
5
- st.title('Uber pickups in NYC')
6
-
7
- DATE_COLUMN = 'date/time'
8
- DATA_URL = ('https://s3-us-west-2.amazonaws.com/'
9
- 'streamlit-demo-data/uber-raw-data-sep14.csv.gz')
10
-
11
- @st.cache_resource
12
- def load_data(nrows):
13
- data = pd.read_csv(DATA_URL, nrows=nrows)
14
- lowercase = lambda x: str(x).lower()
15
- data.rename(lowercase, axis='columns', inplace=True)
16
- data[DATE_COLUMN] = pd.to_datetime(data[DATE_COLUMN])
17
- return data
18
-
19
- data_load_state = st.text('Loading data...')
20
- data = load_data(10000)
21
- data_load_state.text("Done! (using st.cache)")
22
-
23
- if st.checkbox('Show raw data'):
24
- st.subheader('Raw data')
25
- st.write(data)
26
-
27
- st.subheader('Number of pickups by hour')
28
- hist_values = np.histogram(data[DATE_COLUMN].dt.hour, bins=24, range=(0,24))[0]
29
- st.bar_chart(hist_values)
30
-
31
- # Some number in the range 0-23
32
- hour_to_filter = st.slider('hour', 0, 23, 17)
33
- filtered_data = data[data[DATE_COLUMN].dt.hour == hour_to_filter]
34
-
35
- st.subheader('Map of all pickups at %s:00' % hour_to_filter)
36
- st.map(filtered_data)
37
-
38
- uploaded_file = st.file_uploader("Choose a file")
39
- if uploaded_file is not None:
40
- st.write(uploaded_file.name)
41
- bytes_data = uploaded_file.getvalue()
42
- st.write(len(bytes_data), "bytes")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
+ import streamlit.components.v1 as components
3
+ from transformers import (AutoModelForSequenceClassification, AutoTokenizer,
4
+ pipeline)
5
+ import shap
6
+ from PIL import Image
7
+
8
+ st.set_option('deprecation.showPyplotGlobalUse', False)
9
+ output_width = 800
10
+ output_height = 300
11
+ rescale_logits = False
12
+
13
+
14
+
15
+ st.set_page_config(page_title='Text Classification with Shap')
16
+ logo = Image.open('Intel-logo.png')
17
+ st.sidebar.image(logo)
18
+ st.title('Interpreting HF Pipeline Text Classification with Shap')
19
+
20
+ form = st.sidebar.form("Model Selection")
21
+ form.header('Model Selection')
22
+
23
+ model_name = form.text_input("Enter the name of the text classification LLM (note: model must be fine-tuned on a text classification task)", value = "Hate-speech-CNERG/bert-base-uncased-hatexplain")
24
+ form.form_submit_button("Submit")
25
+
26
+
27
+ @st.cache_data()
28
+ def load_model(model_name):
29
+ tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
30
+ model = AutoModelForSequenceClassification.from_pretrained(model_name)
31
+
32
+ return tokenizer, model
33
+
34
+ tokenizer, model = load_model(model_name)
35
+ pred = pipeline("text-classification", model=model, tokenizer=tokenizer, top_k=None)
36
+ explainer = shap.Explainer(pred, rescale_to_logits = rescale_logits)
37
+
38
+ col1, col2 = st.columns(2)
39
+ text = col1.text_area("Enter text input", value = "Classify me.")
40
+
41
+ result = pred(text)
42
+ top_pred = result[0][0]['label']
43
+ col2.write('')
44
+ for label in result[0]:
45
+ col2.write(f'**{label["label"]}**: {label["score"]: .2f}')
46
+
47
+ shap_values = explainer([text])
48
+
49
+ force_plot = shap.plots.text(shap_values, display=False)
50
+ bar_plot = shap.plots.bar(shap_values[0, :, top_pred], order=shap.Explanation.argsort.flip, show=False)
51
+
52
+ st.markdown("""
53
+ <style>
54
+ .big-font {
55
+ font-size:35px !important;
56
+ }
57
+ </style>
58
+ """, unsafe_allow_html=True)
59
+ st.markdown(f'<center><p class="big-font">Shap Bar Plot for <i>{top_pred}</i> Prediction</p></center>', unsafe_allow_html=True)
60
+ st.pyplot(bar_plot, clear_figure=True)
61
+
62
+ st.markdown('<center><p class="big-font">Shap Interactive Force Plot</p></center>', unsafe_allow_html=True)
63
+ components.html(force_plot, height=output_height, width=output_width, scrolling=True)
requirements.txt CHANGED
@@ -1,3 +1,6 @@
1
  streamlit
 
 
 
 
2
  numpy
3
- pandas
 
1
  streamlit
2
+ transformers
3
+ shap
4
+ torch
5
+ matplotlib
6
  numpy