# Hugging Face Space app (Space status: Running)
import os

import numpy as np
import torch
import gradio as gr
from transformers import TextClassificationPipeline, DistilBertTokenizer, DistilBertForSequenceClassification

# Access token read from the environment (os.getenv returns None when unset).
HF_TOKEN = os.getenv('HF_TOKEN')
# Hugging Face dataset ("wfh-problematic") that stores the examples users flag.
hf_saver = gr.HuggingFaceDatasetSaver(HF_TOKEN, "wfh-problematic")

# Model repository on the Hugging Face Hub.
model_path = "yabramuvdi/distilbert-wfh"
tokenizer = DistilBertTokenizer.from_pretrained(model_path, use_auth_token=HF_TOKEN)
model = DistilBertForSequenceClassification.from_pretrained(model_path, use_auth_token=HF_TOKEN)

# Prediction pipeline that returns a score for every label, not only the top one.
# predict_wfh reads the WFH score by label index, so label order matters here.
# NOTE(review): newer transformers releases deprecate return_all_scores in favour
# of top_k=None, but top_k=None sorts scores by value (breaking index-based
# access) — keep this flag unless the caller is updated too.
classifier = TextClassificationPipeline(model=model,
                                        tokenizer=tokenizer,
                                        return_all_scores=True)
# Basic text elements of the demo page.
title = "Work From Home Predictor"
description = "Demo application that predicts the presence of work from home in any given sequence of text."
article = ""  # text shown at the end of the app (currently empty)
# Example rows shown under the form: [input text, classification threshold],
# in the same order as the interface inputs.
examples = [
    ["This is a work from home position.", 0.9],
    ["This position does not allow remote work.", 0.5],
]
#%% | |
def predict_wfh(input_text, input_slider):
    """Classify *input_text* as work from home (WFH) or not.

    Args:
        input_text: Text to classify.
        input_slider: Decision threshold. The text is labelled WFH only when
            the WFH probability is strictly greater than this value.

    Returns:
        A tuple ``(labels, summary)``. ``labels`` maps "Not work from home"
        and "Work from home" to 0 or 1 (exactly one of them is 1).
        ``summary`` is a string with the WFH probability rounded to three
        decimals.
    """
    # The pipeline returns one list of per-label scores per input; take the
    # list for our single input.
    predictions = classifier(input_text)[0]
    # The score at label index 1 is used as the WFH probability.
    # NOTE(review): assumes label id 1 is the WFH class in the model config — confirm.
    prob_wfh = predictions[1]["score"]
    is_wfh = prob_wfh > input_slider
    labels = {"Not work from home": int(not is_wfh), "Work from home": int(is_wfh)}
    return labels, f"Probability of WFH: {np.round(prob_wfh, 3)}"
# Output components: the binary decision as a label, plus a text box with the
# predicted probability.
label = gr.outputs.Label(num_top_classes=1, type="confidences", label="Binary classification")
text_output = gr.outputs.Textbox(type="auto", label="Predicted probability")

app = gr.Interface(
    # Pass the function itself. A list of functions is the legacy
    # multi-model form, and later Gradio versions reject it.
    fn=predict_wfh,
    inputs=[
        gr.inputs.Textbox(lines=10, label="Input text"),
        gr.inputs.Slider(0, 1, 0.001, label="Classification threshold", default=0.998),
    ],
    outputs=[label, text_output],
    theme="huggingface",
    title=title,
    description=description,
    article=article,
    examples=examples,
    # Users flag outputs by hand with the single option "mistake"; flagged
    # samples are handed to hf_saver.
    allow_flagging="manual",
    flagging_options=["mistake"],
    flagging_callback=hf_saver,
)
# To put the demo behind a login, pass auth=(user, password) to launch().
# Read the credentials from environment variables/secrets instead of
# hard-coding them in the source.
app.launch()