yabramuvdi commited on
Commit
ba06738
1 Parent(s): 97c6f77

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +58 -0
app.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import gradio
3
+ from transformers import TextClassificationPipeline, DistilBertTokenizer, DistilBertForSequenceClassification
4
+
5
+ # model path in hugginface
6
+ model_path = "yabramuvdi/distilbert-wfh"
7
+ tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
8
+ model = DistilBertForSequenceClassification.from_pretrained(model_path)
9
+
10
+ # create a pipeline for predictions
11
+ classifier = TextClassificationPipeline(model=model,
12
+ tokenizer=tokenizer,
13
+ return_all_scores=True)
14
+
15
+ # basic elements of page
16
+ title = "Work From Home Predictor"
17
+ description = "Demo application that predicts the pressence of work from home in any given sequence of text."
18
+ article = "" # text at the end of the app
19
+ examples = [
20
+ ["This is a work from home position", 0.998],
21
+ ["This position does not allow working from home", 0.998],
22
+ ]
23
+
24
+ #%%
25
+
26
+ def predict_wfh(input_text, input_slider):
27
+
28
+ # get scores from model
29
+ predictions = classifier(input_text)[0]
30
+
31
+ # use selected threshold to classify as WFH
32
+ prob_wfh = predictions[1]["score"]
33
+ if prob_wfh > input_slider:
34
+ wfh = 1
35
+ no_wfh = 0
36
+ else:
37
+ wfh = 0
38
+ no_wfh = 1
39
+
40
+ return({"Not work from home": no_wfh, "Work from home": wfh}, f"Probability of WFH: {np.round(prob_wfh, 3)}")
41
+
42
+
43
+ label = gr.outputs.Label(num_top_classes=1, type="confidences", label="Binary classification")
44
+ text_output = gr.outputs.Textbox(type="auto", label="Predicted probability")
45
+
46
+ app = gr.Interface(fn=[predict_wfh],
47
+ inputs=[gr.inputs.Textbox(lines=10, label="Input text"), gr.inputs.Slider(0, 1, 0.001, label="Classification threshold")],
48
+ outputs=[label, text_output],
49
+ theme="huggingface",
50
+ title=title,
51
+ description=description,
52
+ article=article,
53
+ examples=examples,
54
+ allow_flagging="manual",
55
+ flagging_options=["mistake", "borderline"]
56
+ )
57
+
58
+ app.launch()