wfh-app-v2 / app.py
yabramuvdi's picture
Update app.py
b0b80b0
raw history blame
No virus
2.46 kB
import os
import numpy as np
import torch
import gradio as gr
from transformers import TextClassificationPipeline, DistilBertTokenizer, DistilBertForSequenceClassification
# Hugging Face access token read from the environment (None when unset).
# It is handed to the flagging saver and to both from_pretrained calls below.
HF_TOKEN = os.environ.get("HF_TOKEN")
# Flagged examples are saved to the "wfh-problematic" Hugging Face dataset.
hf_saver = gr.HuggingFaceDatasetSaver(HF_TOKEN, "wfh-problematic")

# Model repository on the Hugging Face Hub; tokenizer and weights share it.
model_path = "yabramuvdi/distilbert-wfh"
tokenizer = DistilBertTokenizer.from_pretrained(model_path, use_auth_token=HF_TOKEN)
model = DistilBertForSequenceClassification.from_pretrained(
    model_path, use_auth_token=HF_TOKEN
)

# Prediction pipeline; return_all_scores=True makes it report the score of
# every label rather than only the top one.
classifier = TextClassificationPipeline(
    model=model, tokenizer=tokenizer, return_all_scores=True
)
# Static text shown on the demo page.
title = "Work From Home Predictor"
description = "Demo application that predicts the presence of work from home in any given sequence of text."
article = ""  # text rendered at the end of the app (currently empty)
# Each example is [input text, classification threshold], in the same order
# as the interface inputs (text box, then threshold slider).
examples = [
    ["This is a work from home position.", 0.9],
    ["This position does not allow remote work.", 0.5],
]
#%%
def predict_wfh(input_text, input_slider):
    """Classify *input_text* as work-from-home or not using a threshold.

    Args:
        input_text: Text to classify.
        input_slider: Decision threshold taken from the UI slider. The text
            is labelled "Work from home" only when the model's WFH score is
            strictly greater than this value.

    Returns:
        A tuple ``(label_dict, message)``. ``label_dict`` maps
        "Not work from home" and "Work from home" to 0/1 indicators, exactly
        one of which is 1. ``message`` reports the WFH score rounded to three
        decimals.
    """
    # The pipeline is called on a single text; [0] selects that text's list of
    # per-label {"label", "score"} dicts.
    scores = classifier(input_text)[0]
    # NOTE(review): assumes the entry at index 1 is the WFH class — confirm
    # against the model's id2label mapping.
    wfh_score = scores[1]["score"]
    is_wfh = 1 if wfh_score > input_slider else 0
    label_confidences = {"Not work from home": 1 - is_wfh, "Work from home": is_wfh}
    summary = f"Probability of WFH: {np.round(wfh_score, 3)}"
    return label_confidences, summary
# Output components: a Label with the thresholded decision and a Textbox with
# the WFH probability message produced by predict_wfh.
label = gr.outputs.Label(num_top_classes=1, type="confidences", label="Binary classification")
text_output = gr.outputs.Textbox(type="auto", label="Predicted probability")

app = gr.Interface(
    # Pass the function itself, not a one-element list. A list of functions is
    # the legacy multi-model comparison form, which newer Gradio releases
    # reject. A single callable is accepted by every version.
    fn=predict_wfh,
    inputs=[
        gr.inputs.Textbox(lines=10, label="Input text"),
        # Threshold slider from 0 to 1 in steps of 0.001, defaulting to 0.998.
        gr.inputs.Slider(0, 1, 0.001, label="Classification threshold", default=0.998),
    ],
    outputs=[label, text_output],
    theme="huggingface",
    title=title,
    description=description,
    article=article,
    examples=examples,
    # Users can manually flag wrong predictions ("mistake"); flagged samples
    # are handed to hf_saver.
    allow_flagging="manual",
    flagging_options=["mistake"],
    flagging_callback=hf_saver,
)

# To require a login, pass auth=(username, password) to launch(). Read the
# credentials from environment variables; never hardcode them in source.
app.launch()