File size: 2,126 Bytes
ba06738
f19f070
ba06738
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import torch
import gradio as gr
from transformers import TextClassificationPipeline, DistilBertTokenizer, DistilBertForSequenceClassification

# Fine-tuned DistilBERT classifier hosted on the Hugging Face Hub.
model_path = "yabramuvdi/distilbert-wfh"

# NOTE(review): the tokenizer is loaded from the base "distilbert-base-uncased"
# checkpoint rather than from model_path -- presumably the fine-tuned repo
# reuses the base vocabulary; confirm.
tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
model = DistilBertForSequenceClassification.from_pretrained(model_path)

# Text-classification pipeline that reports a score for every label
# (return_all_scores=True), not just the top one; predict_wfh indexes into
# that per-label list.
# NOTE: newer transformers releases deprecate return_all_scores in favour of
# top_k=None, which also changes the output's nesting and ordering -- update
# predict_wfh accordingly if migrating.
classifier = TextClassificationPipeline(
    model=model,
    tokenizer=tokenizer,
    return_all_scores=True,
)
                                        
# Static text shown on the demo page.
title = "Work From Home Predictor"
description = "Demo application that predicts the presence of work from home in any given sequence of text."
article = ""    # text shown at the end of the app (currently empty)

# Each example row supplies one value per Interface input, in order:
# [input text, classification threshold].
examples = [
    ["This is a work from home position", 0.998],
    ["This position does not allow working from home", 0.998],
    ]

#%%

def predict_wfh(input_text, input_slider):
    """Classify ``input_text`` as work-from-home (WFH) or not.

    Args:
        input_text: Text to score with the module-level ``classifier``
            pipeline.
        input_slider: Decision threshold. The text is labelled WFH only when
            the model's WFH score is strictly greater than this value.

    Returns:
        A 2-tuple ``(labels, message)``. ``labels`` maps
        ``"Not work from home"`` and ``"Work from home"`` to 0/1, with exactly
        one of them set to 1. ``message`` reports the WFH score rounded to
        three decimals.
    """
    # classifier is built with return_all_scores=True; [0] selects the
    # per-label score list for this single input.
    predictions = classifier(input_text)[0]

    # NOTE(review): assumes label index 1 is the WFH class -- confirm against
    # the model's id2label config.
    prob_wfh = predictions[1]["score"]

    # Hard 0/1 decision at the user-selected threshold (ties count as not WFH).
    wfh = 1 if prob_wfh > input_slider else 0

    # Built-in round: numpy is not imported in this file, so np.round would
    # raise NameError on every request.
    return (
        {"Not work from home": 1 - wfh, "Work from home": wfh},
        f"Probability of WFH: {round(prob_wfh, 3)}",
    )


# Output widgets, matching predict_wfh's 2-tuple return:
# the 0/1 decision dict shown as a label, and the probability message.
label = gr.outputs.Label(num_top_classes=1, type="confidences", label="Binary classification")
text_output = gr.outputs.Textbox(type="auto", label="Predicted probability")

# Inputs are listed in predict_wfh's argument order: the text to classify and
# the decision threshold (slider from 0 to 1; the third positional argument is
# the step).
# fn is passed as a bare callable: Gradio 2.x wraps it in a list itself, and
# later Gradio versions require a single callable rather than a list.
app = gr.Interface(fn=predict_wfh,
                   inputs=[gr.inputs.Textbox(lines=10, label="Input text"),
                           gr.inputs.Slider(0, 1, 0.001, label="Classification threshold")],
                   outputs=[label, text_output],
                   theme="huggingface",
                   title=title,
                   description=description,
                   article=article,
                   examples=examples,
                   allow_flagging="manual",
                   flagging_options=["mistake", "borderline"]
                   )

app.launch()