File size: 2,490 Bytes
b51e9e8
5212a08
 
b51e9e8
 
5212a08
 
 
f540b09
caae2c6
5212a08
 
 
 
 
 
 
 
f540b09
5212a08
 
 
 
caae2c6
5212a08
 
 
 
 
 
 
 
 
 
 
caae2c6
5212a08
 
 
 
 
 
 
caae2c6
5212a08
 
 
 
 
 
 
 
 
caae2c6
5212a08
 
 
 
 
 
 
caae2c6
5212a08
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
from functools import lru_cache

import gradio as gr
import torch
from transformers import pipeline, set_seed, AutoTokenizer
from uq import BertForUQSequenceClassification


@lru_cache(maxsize=1)
def _get_classifier():
    """Build the fine-tuned CoLA text-classification pipeline once and reuse it.

    Constructing the pipeline loads the model weights and tokenizer files,
    which is far too expensive to repeat on every Gradio request.
    """
    model_path = "tombm/bert-base-uncased-finetuned-cola"
    return pipeline("text-classification", model=model_path, tokenizer=model_path)


def predict(sentence):
    """Return the label the fine-tuned CoLA classifier assigns to *sentence*.

    Args:
        sentence: Input text to classify.

    Returns:
        The ``"label"`` field of the pipeline's top (first) prediction.
    """
    classifier = _get_classifier()
    # Seed right before inference, exactly as before, so results stay reproducible.
    set_seed(12)
    return classifier(sentence)[0]["label"]


@lru_cache(maxsize=1)
def _get_uq_model():
    """Load the tokenizer and GP-head classifier once and reuse them across calls.

    Returns:
        A ``(tokenizer, model)`` pair. ``model.return_gp_cov`` is already set to
        ``True`` so that the forward pass also returns the GP covariance.
    """
    model_path = "tombm/bert-base-uncased-finetuned-cola"
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = BertForUQSequenceClassification.from_pretrained(model_path)
    # Ask the model to return (logits, gp_cov) instead of only the logits.
    model.return_gp_cov = True
    return tokenizer, model


def uncertainty(sentence):
    """Return the variance reported by the model's GP head for *sentence*.

    Args:
        sentence: Input text to score.

    Returns:
        ``gp_cov.item()`` -- the GP covariance output converted to a Python scalar.
    """
    tokenizer, model = _get_uq_model()
    # Seed right before inference, exactly as before, so results stay reproducible.
    set_seed(12)
    test_input = tokenizer(sentence, return_tensors="pt")
    # Pure inference: don't build an autograd graph for the forward pass.
    with torch.no_grad():
        _, gp_cov = model(**test_input)
    # NOTE(review): .item() assumes gp_cov holds exactly one element for a
    # single-sentence batch -- it raises otherwise; confirm against uq.
    return gp_cov.item()


# Example sentences shown in every panel, in display order: the first is
# grammatical ("valid"), the second deliberately ungrammatical ("invalid").
EXAMPLE_SENTENCES = (
    "Good morning.",
    "This sentence is sentence, this is a correct sentence!",
)

with gr.Blocks() as demo:
    intro_str = """The *cola* dataset focuses on determining whether sentences are grammatically correct.
               Firstly, let's see how our finetuned model classifies two sentences, 
               the first of which is correct (i.e. valid) and the second is not (i.e. invalid):"""
    gr.Markdown(value=intro_str)

    # One classification panel per example sentence.
    for example in EXAMPLE_SENTENCES:
        gr.Interface(
            fn=predict,
            inputs=gr.Textbox(value=example, label="Input"),
            outputs="text",
        )

    explain_str = """As we can see, our model correctly classifies the first sentence, but misclassifies the second.
        Let's now inspect the uncertainties associated with each prediction generated by our GP head:"""
    gr.Markdown(value=explain_str)

    # One uncertainty panel per example sentence; the valid sentence is
    # expected to show low GP variance, the misclassified one high variance.
    for example in EXAMPLE_SENTENCES:
        gr.Interface(
            fn=uncertainty,
            inputs=gr.Textbox(value=example, label="Input"),
            outputs=gr.Number(label="Variance from GP head"),
        )

    final_str = """We can see here that the variance for the misclassified example is much higher than for the correctly
                classified example. This is great, as now we have some indication of when our model might be uncertain!"""
    gr.Markdown(value=final_str)

# Only start the (blocking) server when run as a script, so the module can be
# imported -- e.g. to reuse predict/uncertainty -- without side effects.
if __name__ == "__main__":
    demo.launch()