wfh-app-v2 / app.py
yabramuvdi's picture
Update app.py
f19f070
import torch
import gradio as gr
from transformers import TextClassificationPipeline, DistilBertTokenizer, DistilBertForSequenceClassification
# Hugging Face Hub repository holding the fine-tuned classification weights.
model_path = "yabramuvdi/distilbert-wfh"

# The tokenizer is taken from the base DistilBERT checkpoint, while the
# classification model itself is loaded from model_path.
tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
model = DistilBertForSequenceClassification.from_pretrained(model_path)

# Prediction pipeline shared by every request. return_all_scores=True is
# requested because predict_wfh indexes into the per-class scores.
classifier = TextClassificationPipeline(
    model=model,
    tokenizer=tokenizer,
    return_all_scores=True,
)
# Static text displayed on the demo page.
title = "Work From Home Predictor"
description = "Demo application that predicts the presence of work from home in any given sequence of text."
article = ""  # text at the end of the app

# Pre-filled example rows: each row holds an input text followed by a
# classification threshold.
examples = [
    ["This is a work from home position", 0.998],
    ["This position does not allow working from home", 0.998],
]
#%%
def predict_wfh(input_text: str, input_slider: float):
    """Classify a text as work-from-home or not using a probability threshold.

    Args:
        input_text: Text handed to the module-level ``classifier`` pipeline.
        input_slider: Classification threshold. The text is labelled as
            work from home only when the predicted probability is strictly
            greater than this value.

    Returns:
        A ``(labels, message)`` tuple. ``labels`` maps "Not work from home"
        and "Work from home" to 0 or 1, with exactly one of them set to 1.
        ``message`` is a string reporting the predicted probability rounded
        to three decimals.
    """
    # The pipeline result is wrapped in a list; take the entry for this
    # single input, which is indexed per class below.
    predictions = classifier(input_text)[0]

    # Index 1 is read as the work-from-home class.
    # NOTE(review): assumes the model's label order is [not WFH, WFH] - confirm.
    prob_wfh = predictions[1]["score"]

    # Hard 0/1 decision against the user-selected threshold.
    wfh = int(prob_wfh > input_slider)
    no_wfh = 1 - wfh

    # Built-in round() is used because numpy is not among this file's
    # imports, so an np.round call here would raise NameError.
    return (
        {"Not work from home": no_wfh, "Work from home": wfh},
        f"Probability of WFH: {round(prob_wfh, 3)}",
    )
# Output components, in the order predict_wfh returns its two values:
# the 0/1 label dictionary first, then the probability message.
label = gr.outputs.Label(num_top_classes=1, type="confidences", label="Binary classification")
text_output = gr.outputs.Textbox(type="auto", label="Predicted probability")

app = gr.Interface(
    # Pass the callable itself: wrapping a single function in a list adds
    # nothing, and a list for `fn` is not accepted by every Gradio release.
    fn=predict_wfh,
    # Input components, in the order of predict_wfh's parameters
    # (text first, then threshold).
    inputs=[
        gr.inputs.Textbox(lines=10, label="Input text"),
        gr.inputs.Slider(0, 1, 0.001, label="Classification threshold"),
    ],
    outputs=[label, text_output],
    theme="huggingface",
    title=title,
    description=description,
    article=article,
    examples=examples,
    # Users can flag a prediction by hand and tag it with one of these options.
    allow_flagging="manual",
    flagging_options=["mistake", "borderline"],
)

# Start serving the demo as soon as the script runs.
app.launch()