import torch
import gradio as gr
from transformers import (
    DistilBertForSequenceClassification,
    DistilBertTokenizer,
    TextClassificationPipeline,
)

# model path on the Hugging Face hub
model_path = "yabramuvdi/distilbert-wfh"
tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
model = DistilBertForSequenceClassification.from_pretrained(model_path)

# create a pipeline for predictions; return_all_scores=True asks the pipeline
# for the score of every label rather than only the best one
classifier = TextClassificationPipeline(
    model=model, tokenizer=tokenizer, return_all_scores=True
)

# basic elements of page
title = "Work From Home Predictor"
description = "Demo application that predicts the pressence of work from home in any given sequence of text."
article = ""  # text at the end of the app
# each example row is [input text, classification threshold]
examples = [
    ["This is a work from home position", 0.998],
    ["This position does not allow working from home", 0.998],
]


#%%
def predict_wfh(input_text: str, input_slider: float):
    """Classify a text as work-from-home (WFH) or not using a threshold.

    Args:
        input_text: Text passed to the module-level ``classifier`` pipeline.
        input_slider: Decision threshold; the text is labelled WFH only when
            the predicted WFH score is strictly greater than this value.

    Returns:
        A 2-tuple ``(confidences, message)``: ``confidences`` maps the two
        class names to a hard 0/1 decision, and ``message`` is a string
        reporting the WFH score rounded to three decimals.
    """
    # get scores from model (first element = scores for the single input)
    predictions = classifier(input_text)[0]

    # NOTE(review): assumes index 1 is the "work from home" label of the
    # fine-tuned model -- confirm against the model's id2label mapping.
    prob_wfh = predictions[1]["score"]

    # use selected threshold to classify as WFH
    wfh = int(prob_wfh > input_slider)
    no_wfh = 1 - wfh

    # builtin round() is used on purpose: numpy is not imported in this file
    return (
        {"Not work from home": no_wfh, "Work from home": wfh},
        f"Probability of WFH: {round(prob_wfh, 3)}",
    )


label = gr.outputs.Label(
    num_top_classes=1, type="confidences", label="Binary classification"
)
text_output = gr.outputs.Textbox(type="auto", label="Predicted probability")

app = gr.Interface(
    fn=predict_wfh,  # pass the callable itself, not a one-element list
    inputs=[
        gr.inputs.Textbox(lines=10, label="Input text"),
        gr.inputs.Slider(0, 1, 0.001, label="Classification threshold"),
    ],
    outputs=[label, text_output],
    theme="huggingface",
    title=title,
    description=description,
    article=article,
    examples=examples,
    allow_flagging="manual",
    flagging_options=["mistake", "borderline"],
)

app.launch()