Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| from transformers import pipeline, set_seed, AutoTokenizer | |
| from uq import BertForUQSequenceClassification | |
# Lazily-built text-classification pipeline shared by every ``predict`` call.
_classifier = None


def _get_classifier():
    """Return the CoLA text-classification pipeline, building it on first use.

    Constructing a ``transformers`` pipeline loads the model and tokenizer from
    ``model_path``, which is far too expensive to repeat on every request, so
    the pipeline is created once per process and then reused.
    """
    global _classifier
    if _classifier is None:
        model_path = "tombm/bert-base-uncased-finetuned-cola"
        _classifier = pipeline(
            "text-classification", model=model_path, tokenizer=model_path
        )
    return _classifier


def predict(sentence):
    """Classify *sentence* with the fine-tuned CoLA model.

    Args:
        sentence: Input text to classify.

    Returns:
        The ``"label"`` field of the first prediction returned by the pipeline.
    """
    classifier = _get_classifier()
    # Re-seed on every call, as before, so repeated requests see the same RNG state.
    set_seed(12)
    return classifier(sentence)[0]["label"]
# Lazily-loaded tokenizer / GP-head model shared by every ``uncertainty`` call.
_uq_tokenizer = None
_uq_model = None


def _get_uq_model():
    """Return the ``(tokenizer, model)`` pair used for GP variance estimates.

    Both are loaded from ``model_path`` on first use and then reused, instead of
    being reloaded on every request.
    """
    global _uq_tokenizer, _uq_model
    if _uq_model is None:
        model_path = "tombm/bert-base-uncased-finetuned-cola"
        _uq_tokenizer = AutoTokenizer.from_pretrained(model_path)
        # Seed before constructing the model as well: if the GP head draws any
        # random parameters at construction time (``uq`` internals are not
        # visible here — NOTE(review): confirm), they become reproducible.
        set_seed(12)
        model = BertForUQSequenceClassification.from_pretrained(model_path)
        # Ask the model to return the GP covariance alongside its first output.
        model.return_gp_cov = True
        _uq_model = model  # assigned last so a failed load is retried next call
    return _uq_tokenizer, _uq_model


def uncertainty(sentence):
    """Return the variance reported by the GP head for *sentence*.

    Args:
        sentence: Input text to score.

    Returns:
        ``gp_cov.item()`` as a Python number. This assumes the covariance for a
        single input is a one-element tensor — NOTE(review): confirm against
        ``uq.BertForUQSequenceClassification``.
    """
    # torch is already a hard requirement here (``return_tensors="pt"``).
    import torch

    tokenizer, model = _get_uq_model()
    set_seed(12)
    test_input = tokenizer(sentence, return_tensors="pt")
    # Inference only: no gradients are needed to read the covariance.
    with torch.no_grad():
        _, gp_cov = model(**test_input)
    return gp_cov.item()
# Example sentences used by every demo panel, in display order: the first is
# grammatical (valid) and the second is not (invalid).
EXAMPLE_SENTENCES = (
    "Good morning.",
    "This sentence is sentence, this is a correct sentence!",
)

with gr.Blocks() as demo:
    intro_str = """The *cola* dataset focuses on determining whether sentences are grammatically correct.
Firstly, let's see how our finetuned model classifies two sentences,
the first of which is correct (i.e. valid) and the second is not (i.e. invalid):"""
    gr.Markdown(value=intro_str)

    # One classification panel per example sentence.
    for example in EXAMPLE_SENTENCES:
        gr.Interface(
            fn=predict,
            inputs=gr.Textbox(value=example, label="Input"),
            outputs="text",
        )

    explain_str = """As we can see, our model correctly classifies the first sentence, but misclassifies the second.
Let's now inspect the uncertainties associated with each prediction generated by our GP head:"""
    gr.Markdown(value=explain_str)

    # One uncertainty panel per example sentence: the valid sentence should
    # have low uncertainty and the invalid one high uncertainty.
    for example in EXAMPLE_SENTENCES:
        gr.Interface(
            fn=uncertainty,
            inputs=gr.Textbox(value=example, label="Input"),
            outputs=gr.Number(label="Variance from GP head"),
        )

    final_str = """We can see here that the variance for the misclassified example is much higher than for the correctly
classified example. This is great, as now we have some indication of when our model might be uncertain!"""
    gr.Markdown(value=final_str)

# Only start the server when run as a script, so importing this module (e.g.
# by tooling that looks up ``demo``) has no blocking side effect.
if __name__ == "__main__":
    demo.launch()