mlkorra commited on
Commit
0677d87
·
1 Parent(s): bcd6a70

update app

Browse files
Files changed (1) hide show
  1. app.py +11 -4
app.py CHANGED
@@ -6,6 +6,14 @@ import streamlit as st
6
  from transformers import AutoModelForSequenceClassification,AutoTokenizer,pipeline
7
  from transformers_interpret import SequenceClassificationExplainer
8
 
 
 
 
 
 
 
 
 
9
  @st.cache
10
  def load_model(text):
11
 
@@ -20,7 +28,7 @@ def load_model(text):
20
 
21
  results = nlp(text)
22
 
23
- return results,model,tokenizer
24
  #MASK_TOKEN = tokenizer.mask_token
25
  #masked_text = masked_text.replace("<mask>", MASK_TOKEN)
26
  #result_sentence = nlp(masked_text)
@@ -58,10 +66,9 @@ def app():
58
 
59
  if st.button('Visualize attributions'):
60
  with st.spinner("Visualizing .....") :
 
61
 
62
- cls_explainer = SequenceClassificationExplainer(model,tokenizer)
63
- word_attributions = cls_explainer(masked_text)
64
- st.write(cls_explainer.visualize())
65
 
66
  if __name__ == "__main__":
67
  app()
 
6
  from transformers import AutoModelForSequenceClassification,AutoTokenizer,pipeline
7
  from transformers_interpret import SequenceClassificationExplainer
8
 
9
+ @st.cache
10
+ def visualize(text):
11
+
12
+ cls_explainer = SequenceClassificationExplainer(model,tokenizer)
13
+ word_attributions = cls_explainer(masked_text)
14
+ cls_explainer.visualize('visualize.html')
15
+
16
+
17
  @st.cache
18
  def load_model(text):
19
 
 
28
 
29
  results = nlp(text)
30
 
31
+ return results
32
  #MASK_TOKEN = tokenizer.mask_token
33
  #masked_text = masked_text.replace("<mask>", MASK_TOKEN)
34
  #result_sentence = nlp(masked_text)
 
66
 
67
  if st.button('Visualize attributions'):
68
  with st.spinner("Visualizing .....") :
69
+ visualize(masked_text)
70
 
71
+
 
 
72
 
73
  if __name__ == "__main__":
74
  app()