File size: 2,458 Bytes
9aede51
d854c4f
ba06738
f19f070
ba06738
 
9aede51
 
 
 
ba06738
 
6587e6f
 
ba06738
 
 
 
 
 
 
 
 
 
 
f3c8e25
 
ba06738
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8d0f5dc
ba06738
 
 
 
 
 
 
9aede51
 
ba06738
 
b0b80b0
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import os
import numpy as np
import torch
import gradio as gr
from transformers import TextClassificationPipeline, DistilBertTokenizer, DistilBertForSequenceClassification

# Hugging Face access token taken from the environment (None when unset).
# It is used both for the flagging dataset and for downloading the model.
HF_TOKEN = os.environ.get("HF_TOKEN")

# Flagged examples are written to the Hub dataset "wfh-problematic".
hf_saver = gr.HuggingFaceDatasetSaver(HF_TOKEN, "wfh-problematic")

# Fine-tuned DistilBERT checkpoint on the Hugging Face Hub. The token is passed
# presumably because the repository may be private -- confirm.
model_path = "yabramuvdi/distilbert-wfh"
tokenizer = DistilBertTokenizer.from_pretrained(
    model_path,
    use_auth_token=HF_TOKEN,
)
model = DistilBertForSequenceClassification.from_pretrained(
    model_path,
    use_auth_token=HF_TOKEN,
)

# Inference pipeline. return_all_scores=True makes it report a score for every
# class rather than only the top-scoring one.
classifier = TextClassificationPipeline(
    tokenizer=tokenizer,
    model=model,
    return_all_scores=True,
)
                                        
# Static text shown on the Gradio page.
title = "Work From Home Predictor"
description = "Demo application that predicts the presence of work from home in any given sequence of text."
article = ""    # text shown at the end of the app (currently empty)

# Pre-filled examples. Each row is [input text, classification threshold],
# one value per input component of the interface.
examples = [
    ["This is a work from home position.", 0.9],
    ["This position does not allow remote work.", 0.5],
]

#%%

def predict_wfh(input_text, input_slider):
    """Classify ``input_text`` as work-from-home or not, using a user-chosen threshold.

    Args:
        input_text: Free text entered in the Gradio textbox.
        input_slider: Decision threshold from the slider. The text is labelled
            "Work from home" only when the model's WFH probability is strictly
            greater than this value.

    Returns:
        A tuple ``(labels, message)``. ``labels`` maps both class names to a
        hard 0/1 score (shown by the Label output). ``message`` reports the WFH
        probability rounded to 3 decimals (shown by the Textbox output).
    """
    # truncation=True: the model has a maximum input length (presumably 512
    # tokens for DistilBERT), and the multi-line textbox can exceed it. Without
    # truncation, long inputs make the model call fail instead of scoring the
    # leading part of the text.
    predictions = classifier(input_text, truncation=True)[0]

    # The module-level pipeline is built with return_all_scores=True, so
    # `predictions` holds one {"label", "score"} dict per class. Index 1 is
    # treated as the WFH class (TODO confirm against the model's id2label).
    prob_wfh = predictions[1]["score"]

    # Apply the user-selected threshold to get a hard decision.
    wfh = 1 if prob_wfh > input_slider else 0
    labels = {"Not work from home": 1 - wfh, "Work from home": wfh}

    return labels, f"Probability of WFH: {np.round(prob_wfh, 3)}"


# Output components, in the same order as predict_wfh's return values:
# a Label with the hard 0/1 decision and a Textbox with the probability message.
label = gr.outputs.Label(num_top_classes=1, type="confidences", label="Binary classification")
text_output = gr.outputs.Textbox(type="auto", label="Predicted probability")

# Pass the prediction function itself, not a one-element list. A single callable
# is the documented form of `fn`. Gradio 2.x wraps it in a list internally, and
# Gradio 3.x raises on a list of functions.
app = gr.Interface(fn=predict_wfh,
                   inputs=[gr.inputs.Textbox(lines=10, label="Input text"),
                           gr.inputs.Slider(0, 1, 0.001, label="Classification threshold", default=0.998)],
                   outputs=[label, text_output],
                   theme="huggingface",
                   title=title,
                   description=description,
                   article=article,
                   examples=examples,
                   # Manual flagging: users can flag an output as a "mistake",
                   # and the sample is stored through hf_saver.
                   allow_flagging="manual",
                   flagging_options=["mistake"],
                   flagging_callback=hf_saver
                   )

# Start the web app. To put it behind a login, pass auth=(username, password)
# to launch(). Read the credentials from the environment (as HF_TOKEN is read)
# rather than hard-coding them in source.
app.launch()