"""Gradio demo for a BERT model fine-tuned on the CoLA dataset.

The page first shows the classifier's predicted label for two example
sentences (one grammatical, one not), then the variance reported by the
model's GP head for the same two sentences.
"""

import functools

import gradio as gr
import torch
from transformers import pipeline, set_seed, AutoTokenizer

from uq import BertForUQSequenceClassification

# Checkpoint used by both the plain classifier and the GP-head model.
MODEL_PATH = "tombm/bert-base-uncased-finetuned-cola"

# Seed passed to transformers.set_seed before each inference call.
SEED = 12

# Example inputs shown in the demo: the first is grammatical, the second is not.
VALID_EXAMPLE = "Good morning."
INVALID_EXAMPLE = "This sentence is sentence, this is a correct sentence!"


@functools.lru_cache(maxsize=1)
def _load_classifier():
    """Build the text-classification pipeline once and reuse it for every request."""
    return pipeline("text-classification", model=MODEL_PATH, tokenizer=MODEL_PATH)


@functools.lru_cache(maxsize=1)
def _load_uq_model():
    """Load the tokenizer and GP-head model once and reuse them for every request.

    Returns:
        A ``(tokenizer, model)`` tuple; the model has ``return_gp_cov`` set so
        its forward pass also returns the GP covariance.
    """
    # Seed before loading too. NOTE(review): this only matters if some head
    # parameters are randomly initialised rather than restored from the
    # checkpoint -- not visible from here.
    set_seed(SEED)
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
    model = BertForUQSequenceClassification.from_pretrained(MODEL_PATH)
    model.return_gp_cov = True
    return tokenizer, model


def predict(sentence):
    """Return the classifier's predicted label for ``sentence``.

    Args:
        sentence: Text to classify.

    Returns:
        The ``"label"`` field of the pipeline's first (top) prediction.
    """
    classifier = _load_classifier()
    set_seed(SEED)
    return classifier(sentence)[0]["label"]


def uncertainty(sentence):
    """Return the GP-head variance for ``sentence`` as a Python number.

    Args:
        sentence: Text to score.

    Returns:
        ``gp_cov.item()`` from the model's forward pass. Assumes a single
        sentence yields a one-element covariance tensor; ``.item()`` raises
        otherwise.
    """
    tokenizer, model = _load_uq_model()
    set_seed(SEED)
    test_input = tokenizer(sentence, return_tensors="pt")
    # Inference only: no need to build an autograd graph.
    with torch.no_grad():
        _, gp_cov = model(**test_input)
    return gp_cov.item()


with gr.Blocks() as demo:
    intro_str = """The *cola* dataset focuses on determining whether sentences are grammatically correct. Firstly, let's see how our finetuned model classifies two sentences, the first of which is correct (i.e. valid) and the second is not (i.e. invalid):"""
    gr.Markdown(value=intro_str)
    gr.Interface(
        fn=predict,
        inputs=gr.Textbox(value=VALID_EXAMPLE, label="Input"),
        outputs="text",
    )
    gr.Interface(
        fn=predict,
        inputs=gr.Textbox(
            value=INVALID_EXAMPLE,
            label="Input",
        ),
        outputs="text",
    )
    explain_str = """As we can see, our model correctly classifies the first sentence, but misclassifies the second. 
Let's now inspect the uncertainties associated with each prediction generated by our GP head:"""
    gr.Markdown(value=explain_str)
    gr.Interface(
        fn=uncertainty,
        inputs=gr.Textbox(value=VALID_EXAMPLE, label="Input"),
        outputs=gr.Number(label="Variance from GP head"),
    )  # should have low uncertainty
    gr.Interface(
        fn=uncertainty,
        inputs=gr.Textbox(
            value=INVALID_EXAMPLE,
            label="Input",
        ),
        outputs=gr.Number(label="Variance from GP head"),
    )  # should have high uncertainty
    final_str = """We can see here that the variance for the misclassified example is much higher than for the correctly classified example. This is great, as now we have some indication of when our model might be uncertain!"""
    gr.Markdown(value=final_str)

if __name__ == "__main__":
    demo.launch()