# Spaces: Runtime error
import torch | |
import gradio as gr | |
from transformers import TextClassificationPipeline, DistilBertTokenizer, DistilBertForSequenceClassification | |
# Fine-tuned work-from-home classifier hosted on the Hugging Face Hub.
model_path = "yabramuvdi/distilbert-wfh"

# Classifier weights come from the fine-tuned repo; the tokenizer is taken
# from the base DistilBERT checkpoint.
model = DistilBertForSequenceClassification.from_pretrained(model_path)
tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")

# Prediction pipeline configured to report the score of every label,
# not only the top-scoring one.
classifier = TextClassificationPipeline(
    model=model,
    tokenizer=tokenizer,
    return_all_scores=True,
)
# Static page content shown by the Gradio interface.
title = "Work From Home Predictor"
description = "Demo application that predicts the presence of work from home in any given sequence of text."
article = ""  # text rendered at the end of the app

# Each example supplies one value per input component: (input text, threshold).
examples = [
    ["This is a work from home position", 0.998],
    ["This position does not allow working from home", 0.998],
]
#%% | |
def predict_wfh(input_text, input_slider):
    """Classify *input_text* as work-from-home (WFH) or not.

    Args:
        input_text: Text to classify.
        input_slider: Decision threshold in [0, 1]. The text is labelled WFH
            only when the model's WFH probability is strictly greater than it.

    Returns:
        A ``(confidences, message)`` tuple: a one-hot dict for the Gradio
        ``Label`` output, and a string reporting the WFH probability rounded
        to three decimals.
    """
    scores = classifier(input_text)
    # With ``return_all_scores=True`` a single string yields a nested list
    # (one inner list per input); other pipeline settings may return a flat
    # list. Accept both shapes.
    if scores and isinstance(scores[0], list):
        scores = scores[0]

    # Look the WFH score up by the name of label id 1 instead of by list
    # position, so the result does not depend on how the scores are ordered.
    wfh_label = classifier.model.config.id2label[1]
    prob_wfh = next(s["score"] for s in scores if s["label"] == wfh_label)

    is_wfh = prob_wfh > input_slider
    # Built-in round(): numpy is not imported in this module, so the former
    # np.round call raised NameError.
    return (
        {"Not work from home": int(not is_wfh), "Work from home": int(is_wfh)},
        f"Probability of WFH: {round(prob_wfh, 3)}",
    )
# Output components: a one-hot label plus a text readout of the WFH
# probability. Current Gradio exposes components at the top level
# (gr.Label / gr.Textbox); the gr.inputs / gr.outputs namespaces used
# previously were deprecated in Gradio 3 and removed in Gradio 4.
label = gr.Label(num_top_classes=1, label="Binary classification")
text_output = gr.Textbox(label="Predicted probability")

app = gr.Interface(
    # Gradio 3+ expects a single callable here, not a list of functions.
    fn=predict_wfh,
    inputs=[
        gr.Textbox(lines=10, label="Input text"),
        # Use keywords: in the current API the third positional argument of
        # gr.Slider is the initial value, not the step size.
        gr.Slider(minimum=0, maximum=1, step=0.001, label="Classification threshold"),
    ],
    outputs=[label, text_output],
    # NOTE: the legacy theme="huggingface" string was dropped; newer Gradio
    # releases do not ship it, so the default theme is used.
    title=title,
    description=description,
    article=article,
    examples=examples,
    allow_flagging="manual",
    flagging_options=["mistake", "borderline"],
)
app.launch()