import os

import numpy as np
import gradio as gr
from transformers import (
    TextClassificationPipeline,
    DistilBertTokenizer,
    DistilBertForSequenceClassification,
)

# HuggingFace dataset where flagged examples are saved
HF_TOKEN = os.getenv("HF_TOKEN")
hf_saver = gr.HuggingFaceDatasetSaver(HF_TOKEN, "wfh-problematic")

# model path on the Hugging Face Hub
model_path = "yabramuvdi/distilbert-wfh"
tokenizer = DistilBertTokenizer.from_pretrained(model_path, use_auth_token=HF_TOKEN)
model = DistilBertForSequenceClassification.from_pretrained(model_path, use_auth_token=HF_TOKEN)

# create a pipeline for predictions
classifier = TextClassificationPipeline(model=model, tokenizer=tokenizer, return_all_scores=True)

# basic elements of the page
title = "Work From Home Predictor"
description = "Demo application that predicts the presence of work from home in any given sequence of text."
article = ""  # text at the end of the app

examples = [
    ["This is a work from home position.", 0.9],
    ["This position does not allow remote work.", 0.5],
]


def predict_wfh(input_text, input_slider):
    # get the scores for all labels from the model
    predictions = classifier(input_text)[0]
    # probability assigned to the WFH label (second label of the model)
    prob_wfh = predictions[1]["score"]
    # use the selected threshold to classify as WFH
    if prob_wfh > input_slider:
        wfh, no_wfh = 1, 0
    else:
        wfh, no_wfh = 0, 1

    return {"Not work from home": no_wfh, "Work from home": wfh}, f"Probability of WFH: {np.round(prob_wfh, 3)}"


label = gr.outputs.Label(num_top_classes=1, type="confidences", label="Binary classification")
text_output = gr.outputs.Textbox(type="auto", label="Predicted probability")

app = gr.Interface(
    fn=predict_wfh,
    inputs=[
        gr.inputs.Textbox(lines=10, label="Input text"),
        gr.inputs.Slider(0, 1, 0.001, label="Classification threshold", default=0.998),
    ],
    outputs=[label, text_output],
    theme="huggingface",
    title=title,
    description=description,
    article=article,
    examples=examples,
    allow_flagging="manual",
    flagging_options=["mistake"],
    flagging_callback=hf_saver,
)

#app.launch(auth=("yabra", "wfh123"), auth_message="Authentication Problem")
app.launch()
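
# Optional sanity check (a minimal sketch, kept commented out so it does not run
# in the deployed Space): predict_wfh can be called directly to verify the model
# and threshold behaviour before serving the interface. The sample text below is
# only an illustrative input, not part of the original examples.
#
#   scores, prob_text = predict_wfh("This is a fully remote position.", 0.998)
#   print(scores)      # e.g. {"Not work from home": 0, "Work from home": 1}
#   print(prob_text)   # e.g. "Probability of WFH: 0.999"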