g8a9 committed
Commit acdedb4
Parent(s): 43277af

add evaluation and prediction

Files changed (2):
  1. requirements.txt +1 -1
  2. single.py +79 -17
requirements.txt CHANGED
@@ -1,2 +1,2 @@
 transformers==4.20.1
-ferret-xai>=0.1.0
+ferret-xai>=0.2.0
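
The version bump tracks the new Benchmark calls introduced in single.py below (bench.score, bench.evaluate_explanations, bench.show_evaluation_table), which presumably require ferret-xai 0.2.0. As a sanity check, here is a minimal sketch of the same flow outside Streamlit, using only calls that appear in the diff that follows; the input text is a made-up example.

    from transformers import AutoModelForSequenceClassification, AutoTokenizer
    from ferret import Benchmark

    model_name = "distilbert-base-uncased-finetuned-sst-2-english"
    model = AutoModelForSequenceClassification.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)

    bench = Benchmark(model, tokenizer)

    text = "Great movie for a great evening!"  # made-up example input
    scores = bench.score(text)  # per-class scores, shown under "Prediction" in the app
    explanations = bench.explain(text, target=1)  # target: positional index of the class
    evaluations = bench.evaluate_explanations(explanations, target=1, apply_style=False)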
single.py CHANGED
@@ -1,6 +1,9 @@
 import streamlit as st
-from transformers import AutoModelForSequenceClassification, AutoTokenizer
+from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoConfig
 from ferret import Benchmark
+from torch.nn.functional import softmax
+
+DEFAULT_MODEL = "distilbert-base-uncased-finetuned-sst-2-english"
 
 
 @st.cache()
@@ -8,40 +11,99 @@ def get_model(model_name):
     return AutoModelForSequenceClassification.from_pretrained(model_name)
 
 
+@st.cache()
+def get_config(model_name):
+    return AutoConfig.from_pretrained(model_name)
+
+
 def get_tokenizer(tokenizer_name):
     return AutoTokenizer.from_pretrained(tokenizer_name, use_fast=True)
 
 
 def body():
 
-    st.title("Evaluate using *ferret* !")
-
     st.markdown(
         """
-
-        ### 👋 Hi!
+        # Welcome to the *ferret* showcase
+
+        You are now working in *single instance* mode -- i.e., you will work on and
+        inspect one textual query at a time.
+
+        ## Sentiment Analysis
+
+        Post-hoc explanation techniques disclose the rationale behind the prediction a model
+        makes when detecting the sentiment of a text. In a sense, they let you *poke* inside the model.
+
+        But **who watches the watchers**?
+
+        Let's find out!
 
-        Insert down below your text, choose a model and fire up ferret. We will use
-        *ferret* to:
-        1. produce explanations with all supported methods
-        2. evaluate explanations on state-of-the-art **faithfulness metrics**.
+        Choose your favourite sentiment classification model and let ferret do the rest.
+        We will:
+
+        1. download your model - if you're impatient, here is a [cute video](https://www.youtube.com/watch?v=0Xks8t-SWHU) 🦜 for you;
+        2. explain using *ferret*'s built-in methods ⚙️;
+        3. evaluate explanations with state-of-the-art **faithfulness metrics** 🚀
+
        """
     )
 
-    col1, col2 = st.columns([1, 1])
+    col1, col2 = st.columns([3, 1])
     with col1:
-        model_name = st.text_input("HF Model", "g8a9/bert-base-cased_ami18")
+        model_name = st.text_input("HF Model", DEFAULT_MODEL)
     with col2:
-        tokenizer_name = st.text_input("HF Tokenizer", "bert-base-cased")
+        target = st.selectbox(
+            "Target",
+            options=range(5),
+            index=1,
+            help="Positional index of your target class.",
+        )
 
     text = st.text_input("Text")
 
     compute = st.button("Compute")
 
-    if compute and model_name and tokenizer_name:
-        model = get_model(model_name)
-        tokenizer = get_tokenizer(tokenizer_name)
+    if compute and model_name:
+
+        with st.spinner("Preparing the magic. Hang in there..."):
+            model = get_model(model_name)
+            tokenizer = get_tokenizer(model_name)
+            config = get_config(model_name)
+            bench = Benchmark(model, tokenizer)
+
+        st.markdown("### Prediction")
+        scores = bench.score(text)
+        scores_str = ", ".join(
+            [f"{config.id2label[l]}: {s:.2f}" for l, s in enumerate(scores)]
+        )
+        st.text(scores_str)
+
+        with st.spinner("Computing Explanations..."):
+            explanations = bench.explain(text, target=target)
 
-        bench = Benchmark(model, tokenizer)
-        explanations = bench.explain(text)
+        st.markdown("### Explanations")
         st.dataframe(bench.show_table(explanations))
+
+        with st.spinner("Evaluating Explanations..."):
+            evaluations = bench.evaluate_explanations(
+                explanations, target=target, apply_style=False
+            )
+
+        st.markdown("### Faithfulness Metrics")
+        st.dataframe(bench.show_evaluation_table(evaluations))
+
+        st.markdown(
+            """
+            **Legend**
+
+            - **AOPC Comprehensiveness** (aopc_compr) measures *comprehensiveness*, i.e., whether the
              explanation captures the tokens the model actually relied on: we remove the top-rated
              tokens and track how much the predicted probability drops (higher is better);
+            - **AOPC Sufficiency** (aopc_suff) measures *sufficiency*, i.e., whether the top-rated
              tokens alone are enough to sustain the prediction (lower is better);
+            - **Leave-One-Out TAU Correlation** (taucorr_loo) measures the Kendall rank correlation
              between the explanation and leave-one-out token importances (closer to 1 is better).
+
+            See the paper for details.
+            """
+        )
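
For reference, the two AOPC metrics in the legend reduce to simple averages over perturbed predictions. A rough, illustrative sketch with made-up probabilities (not ferret's internal code):

    # Made-up probabilities for a single example; no model is run here.
    full_prob = 0.95                   # p(y|x) on the full text
    without_topk = [0.60, 0.40, 0.20]  # p(y|x) after removing the top-1..3 rated tokens
    only_topk = [0.90, 0.92, 0.94]     # p(y|x) when keeping only the top-1..3 rated tokens

    # Comprehensiveness: probability lost when important tokens are removed (higher = more faithful).
    aopc_compr = full_prob - sum(without_topk) / len(without_topk)  # 0.95 - 0.40 = 0.55
    # Sufficiency: probability lost when keeping only the important tokens (lower = more faithful).
    aopc_suff = full_prob - sum(only_topk) / len(only_topk)         # 0.95 - 0.92 = 0.03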