File size: 5,068 Bytes
2893724
acdedb4
a806f8f
acdedb4
 
a67a49a
2893724
 
 
 
 
 
 
acdedb4
 
 
 
 
2893724
 
 
 
 
 
f1a6f2c
 
2893724
 
acdedb4
 
 
5a20abc
 
acdedb4
5a20abc
acdedb4
 
2893724
5a20abc
acdedb4
 
 
 
 
 
2893724
 
 
acdedb4
2893724
acdedb4
f1a6f2c
 
2893724
f1a6f2c
acdedb4
 
f1a6f2c
acdedb4
f1a6f2c
acdedb4
2893724
a67a49a
2893724
b0390ea
2893724
acdedb4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f1a6f2c
2893724
acdedb4
a806f8f
f1a6f2c
acdedb4
 
 
f1a6f2c
acdedb4
 
 
 
f1a6f2c
acdedb4
 
 
 
 
c4d9ff2
 
f1a6f2c
 
c4d9ff2
acdedb4
 
 
 
c4d9ff2
 
 
 
 
 
f1a6f2c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
import streamlit as st
from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoConfig
from ferret import Benchmark
from torch.nn.functional import softmax

DEFAULT_MODEL = "cardiffnlp/twitter-xlm-roberta-base-sentiment"


# allow_output_mutation=True: by default st.cache re-hashes the cached return
# value on every hit to detect mutation. For a transformer model that means
# hashing every weight tensor on each rerun, and any in-place change made by
# downstream code would trigger a mutation warning. The model is meant to be
# shared as-is across reruns, so the check is disabled.
@st.cache(allow_output_mutation=True)
def get_model(model_name):
    """Load the sequence-classification model for ``model_name``.

    The result is memoised by Streamlit, so repeated calls with the same
    ``model_name`` reuse the already-loaded model object.

    Args:
        model_name: Model identifier or path forwarded to
            ``AutoModelForSequenceClassification.from_pretrained``.

    Returns:
        The model returned by ``from_pretrained``.
    """
    return AutoModelForSequenceClassification.from_pretrained(model_name)


@st.cache()
def get_config(model_name):
    """Return the ``AutoConfig`` for ``model_name``, memoised via ``st.cache``.

    Args:
        model_name: Model identifier or path forwarded to
            ``AutoConfig.from_pretrained``.
    """
    loaded_config = AutoConfig.from_pretrained(model_name)
    return loaded_config


def get_tokenizer(tokenizer_name):
    """Return the tokenizer for ``tokenizer_name``, requesting the fast variant.

    Args:
        tokenizer_name: Identifier or path forwarded to
            ``AutoTokenizer.from_pretrained``.
    """
    load_options = {"use_fast": True}
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, **load_options)
    return tokenizer


def body():
    """Render the *single instance* page: score, explain and evaluate one text.

    The page asks for a Hugging Face model name, a target class label and an
    input text. Once "Run" is pressed it displays, in order: the model scores
    for the text, the explanation table produced by ferret's ``Benchmark``,
    the evaluation table for those explanations, a legend for the metrics and
    a code snippet reproducing the same calls.

    The function returns early, after showing a message, when the model name
    is empty or when its configuration cannot be loaded, and silently when
    "Run" has not been pressed.
    """
    st.title("Explain individual texts")

    st.markdown(
        """
        You are working now on the *single instance* mode -- i.e., you will work and
        inspect one textual query at a time.

        Post-hoc explanation techniques disclose 🔎 the rationale behind a given prediction a model
        makes while detecting a sentiment out of a text. In a sense, they let you *poke* inside the model.

        But **who watches the watchers**? Are these explanations *accurate*? Can you *trust* them?

        Let's find out!

        Let's choose your favourite mode and let *ferret* do the rest.
        We will:

        1. download your model - if you're impatient, here it is a [cute video](https://www.youtube.com/watch?v=0Xks8t-SWHU) 🦜 for you;
        2. explain using *ferret*'s built-in methods ⚙️
        3. evaluate explanations with state-of-the-art **faithfulness metrics** 🚀

        """
    )

    col1, col2 = st.columns([3, 1])
    with col1:
        model_name = st.text_input("HF Model", DEFAULT_MODEL)

    # The configuration is needed to build the target selector, so the model
    # name has to be validated before anything else is rendered.
    if not model_name:
        st.warning("Please enter a Hugging Face model name.")
        return

    try:
        # Go through the cached helper so the configuration is not reloaded on
        # every Streamlit rerun (each widget interaction re-executes body()).
        config = get_config(model_name)
    except (OSError, ValueError) as err:
        # NOTE(review): these are the exception types transformers is expected
        # to raise for an unknown/unreachable model or an unusable config.
        st.error(f"Could not load the configuration of model '{model_name}': {err}")
        return

    with col2:
        class_labels = list(config.id2label.values())
        target = st.selectbox(
            "Target",
            options=class_labels,
            # Keep the second label as default, but stay in range for models
            # that expose a single label.
            index=min(1, len(class_labels) - 1),
            help="Class label you want to explain.",
        )

    text = st.text_input("Text", "I love your style!")

    if not st.button("Run"):
        return

    # Position of the selected label in the selector; used as target id below.
    target_idx = class_labels.index(target)

    with st.spinner("Preparing the magic. Hang in there..."):
        model = get_model(model_name)
        tokenizer = get_tokenizer(model_name)
        bench = Benchmark(model, tokenizer)

    st.markdown("### Prediction")
    scores = bench.score(text)
    scores_str = ", ".join(
        f"{config.id2label[label_id]}: {score:.2f}"
        for label_id, score in enumerate(scores)
    )
    st.text(scores_str)

    with st.spinner("Computing Explanations.."):
        explanations = bench.explain(text, target=target_idx)

    st.markdown("### Explanations")
    st.dataframe(bench.show_table(explanations))
    st.caption("Darker red (blue) means higher (lower) contribution.")

    with st.spinner("Evaluating Explanations..."):
        evaluations = bench.evaluate_explanations(
            explanations, target=target_idx, apply_style=False
        )

    st.markdown("### Faithfulness Metrics")
    st.dataframe(bench.show_evaluation_table(evaluations))
    st.caption("Darker colors mean better performance.")

    st.markdown(
        """
        **Legend**

        - **AOPC Comprehensiveness** (aopc_compr) measures *comprehensiveness*, i.e., if the explanation captures all the tokens needed to make the prediction. Higher is better.

        - **AOPC Sufficiency** (aopc_suff) measures *sufficiency*, i.e., if the relevant tokens in the explanation are sufficient to make the prediction. Lower is better.

        - **Leave-One-Out TAU Correlation** (taucorr_loo) measures the Kendall rank correlation coefficient τ between the explanation and leave-one-out importances. Closer to 1 is better.

        See the paper for details.
        """
    )

    # Longer descriptions of the metrics, kept here for a future, more
    # detailed legend:
    # - Comprehensiveness is computed as the drop in the model probability if
    #   the relevant tokens of the explanations are removed. The higher the
    #   comprehensiveness, the more faithful is the explanation.
    # - Sufficiency is computed as the drop in the model probability if only
    #   the relevant tokens of the explanations are considered. The lower the
    #   sufficiency, the more faithful is the explanation since there is less
    #   change in the model prediction.
    # - Leave-one-out importances are computed by omitting individual input
    #   tokens and measuring the variation on the model prediction. The closer
    #   the τ correlation is to 1, the more faithful is the explanation.

    st.markdown(
        """
        **In code, it would be as simple as**
        """
    )
    # User-provided values are interpolated with !r so that quotes and
    # backslashes in them still yield a valid Python snippet; the target is
    # included so the snippet reproduces the calls made on this page.
    st.code(
        f"""
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from ferret import Benchmark

model = AutoModelForSequenceClassification.from_pretrained({model_name!r})
tokenizer = AutoTokenizer.from_pretrained({model_name!r})

bench = Benchmark(model, tokenizer)
explanations = bench.explain({text!r}, target={target_idx})
evaluations = bench.evaluate_explanations(explanations, target={target_idx})
        """,
        language="python",
    )