# gp-uq-tester / app.py
# Author: tombm
# Commit: f540b09 ("clean up app.py")
from functools import lru_cache

import gradio as gr
from transformers import pipeline, set_seed, AutoTokenizer

from uq import BertForUQSequenceClassification
@lru_cache(maxsize=1)
def _load_classifier(model_path):
    """Build and cache the text-classification pipeline for *model_path*.

    Caching avoids re-downloading / re-instantiating the model on every
    call to ``predict`` (the original rebuilt the pipeline per request).
    """
    return pipeline("text-classification", model=model_path, tokenizer=model_path)


def predict(sentence):
    """Classify *sentence* with the CoLA-finetuned BERT model.

    Parameters
    ----------
    sentence : str
        Input sentence to judge for grammatical acceptability.

    Returns
    -------
    str
        The predicted label string (presumably "valid"/"invalid" per the
        demo text below — confirm against the model's label config).
    """
    model_path = "tombm/bert-base-uncased-finetuned-cola"
    classifier = _load_classifier(model_path)
    # Fixed seed so repeated calls on the same input are reproducible.
    set_seed(12)
    label = classifier(sentence)[0]["label"]
    return label
@lru_cache(maxsize=1)
def _load_uq_model(model_path):
    """Load and cache the tokenizer and GP-head model for *model_path*.

    The original reloaded both from the hub on every call; loading once
    and memoizing makes repeated uncertainty queries cheap.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = BertForUQSequenceClassification.from_pretrained(model_path)
    # Ask the model to also return the GP covariance on forward passes.
    model.return_gp_cov = True
    return tokenizer, model


def uncertainty(sentence):
    """Return the GP-head predictive variance for *sentence*.

    Parameters
    ----------
    sentence : str
        Input sentence to score.

    Returns
    -------
    float
        Scalar variance produced by the model's GP head (higher means
        the model is less certain about its prediction).
    """
    model_path = "tombm/bert-base-uncased-finetuned-cola"
    tokenizer, model = _load_uq_model(model_path)
    # Fixed seed so repeated calls on the same input are reproducible.
    set_seed(12)
    encoded = tokenizer(sentence, return_tensors="pt")
    _, gp_cov = model(**encoded)
    return gp_cov.item()
def _classification_demo(example):
    """Add a classification demo panel pre-filled with *example*."""
    gr.Interface(
        fn=predict,
        inputs=gr.Textbox(value=example, label="Input"),
        outputs="text",
    )


def _uncertainty_demo(example):
    """Add a GP-variance demo panel pre-filled with *example*."""
    gr.Interface(
        fn=uncertainty,
        inputs=gr.Textbox(value=example, label="Input"),
        outputs=gr.Number(label="Variance from GP head"),
    )


# The two example sentences used throughout the walkthrough.
_VALID_EXAMPLE = "Good morning."
_INVALID_EXAMPLE = "This sentence is sentence, this is a correct sentence!"

with gr.Blocks() as demo:
    gr.Markdown(
        value="""The *cola* dataset focuses on determining whether sentences are grammatically correct.
Firstly, let's see how our finetuned model classifies two sentences,
the first of which is correct (i.e. valid) and the second is not (i.e. invalid):"""
    )
    _classification_demo(_VALID_EXAMPLE)
    _classification_demo(_INVALID_EXAMPLE)
    gr.Markdown(
        value="""As we can see, our model correctly classifies the first sentence, but misclassifies the second.
Let's now inspect the uncertainties associated with each prediction generated by our GP head:"""
    )
    _uncertainty_demo(_VALID_EXAMPLE)  # should have low uncertainty
    _uncertainty_demo(_INVALID_EXAMPLE)  # should have high uncertainty
    gr.Markdown(
        value="""We can see here that the variance for the misclassified example is much higher than for the correctly
classified example. This is great, as now we have some indication of when our model might be uncertain!"""
    )
demo.launch()