Spaces:
Runtime error
Runtime error
Michael-Geis
commited on
Commit
·
7af8946
1
Parent(s):
3050c48
added slider bar for user controlled confidence on tags
Browse files- app.py +10 -6
- model.py +2 -2
- postprocess.py +4 -2
app.py
CHANGED
@@ -46,16 +46,16 @@ def parse_title(input_title):
|
|
46 |
return (title, subject_tags)
|
47 |
|
48 |
|
49 |
-
def outputs_from_id(input_id):
|
50 |
title, true_tags = parse_id(input_id)
|
51 |
-
predicted_tags = predict_from_text(title)
|
52 |
|
53 |
return title, predicted_tags, true_tags
|
54 |
|
55 |
|
56 |
-
def outputs_from_title(input_title):
|
57 |
title, true_tags = parse_title(input_title)
|
58 |
-
predicted_tags = predict_from_text(title)
|
59 |
|
60 |
return title, predicted_tags, true_tags
|
61 |
|
@@ -72,6 +72,8 @@ with gr.Blocks() as demo:
|
|
72 |
id_true = gr.Textbox(label="True tags")
|
73 |
id_button = gr.Button("Predict")
|
74 |
|
|
|
|
|
75 |
gr.Examples(
|
76 |
examples=[
|
77 |
"1706.03762",
|
@@ -102,11 +104,13 @@ with gr.Blocks() as demo:
|
|
102 |
)
|
103 |
|
104 |
id_button.click(
|
105 |
-
outputs_from_id,
|
|
|
|
|
106 |
)
|
107 |
title_button.click(
|
108 |
outputs_from_title,
|
109 |
-
inputs=title_input,
|
110 |
outputs=[title_title, title_predict, title_true],
|
111 |
)
|
112 |
|
|
|
46 |
return (title, subject_tags)
|
47 |
|
48 |
|
49 |
+
def outputs_from_id(input_id, threshold_probability):
|
50 |
title, true_tags = parse_id(input_id)
|
51 |
+
predicted_tags = predict_from_text(title, threshold_probability)
|
52 |
|
53 |
return title, predicted_tags, true_tags
|
54 |
|
55 |
|
56 |
+
def outputs_from_title(input_title, threshold_probability):
|
57 |
title, true_tags = parse_title(input_title)
|
58 |
+
predicted_tags = predict_from_text(title, threshold_probability)
|
59 |
|
60 |
return title, predicted_tags, true_tags
|
61 |
|
|
|
72 |
id_true = gr.Textbox(label="True tags")
|
73 |
id_button = gr.Button("Predict")
|
74 |
|
75 |
+
threshold_probability = gr.Slider(minimum=0, maximum=1)
|
76 |
+
|
77 |
gr.Examples(
|
78 |
examples=[
|
79 |
"1706.03762",
|
|
|
104 |
)
|
105 |
|
106 |
id_button.click(
|
107 |
+
outputs_from_id,
|
108 |
+
inputs=[id_input, threshold_probability],
|
109 |
+
outputs=[id_title, id_predict, id_true],
|
110 |
)
|
111 |
title_button.click(
|
112 |
outputs_from_title,
|
113 |
+
inputs=[title_input, threshold_probability],
|
114 |
outputs=[title_title, title_predict, title_true],
|
115 |
)
|
116 |
|
model.py
CHANGED
@@ -7,7 +7,7 @@ from preprocess import cleanse
|
|
7 |
from postprocess import postprocess
|
8 |
|
9 |
|
10 |
-
def predict_from_text(input_text):
|
11 |
## Load model and create pipeline
|
12 |
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
|
13 |
model = AutoModelForSequenceClassification.from_pretrained(
|
@@ -19,7 +19,7 @@ def predict_from_text(input_text):
|
|
19 |
clean_title = cleanse(input_text)
|
20 |
model_output = pipe(clean_title)
|
21 |
|
22 |
-
prediction = postprocess(model_output)
|
23 |
|
24 |
if len(prediction) == 0:
|
25 |
predict_output = "No matching tags."
|
|
|
7 |
from postprocess import postprocess
|
8 |
|
9 |
|
10 |
+
def predict_from_text(input_text, threshold_probability):
|
11 |
## Load model and create pipeline
|
12 |
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
|
13 |
model = AutoModelForSequenceClassification.from_pretrained(
|
|
|
19 |
clean_title = cleanse(input_text)
|
20 |
model_output = pipe(clean_title)
|
21 |
|
22 |
+
prediction = postprocess(model_output, threshold_probability=threshold_probability)
|
23 |
|
24 |
if len(prediction) == 0:
|
25 |
predict_output = "No matching tags."
|
postprocess.py
CHANGED
@@ -1,12 +1,14 @@
|
|
1 |
import json
|
2 |
|
3 |
|
4 |
-
def postprocess(model_output):
|
5 |
with open("./data/arxiv-label-dict.json", "r") as file:
|
6 |
subject_dict = json.loads(file.read())
|
7 |
|
8 |
predicted_tags = [
|
9 |
-
result["label"]
|
|
|
|
|
10 |
]
|
11 |
|
12 |
return sorted([subject_dict[tag] for tag in predicted_tags])
|
|
|
1 |
import json
|
2 |
|
3 |
|
4 |
+
def postprocess(model_output, threshold_probability):
|
5 |
with open("./data/arxiv-label-dict.json", "r") as file:
|
6 |
subject_dict = json.loads(file.read())
|
7 |
|
8 |
predicted_tags = [
|
9 |
+
result["label"]
|
10 |
+
for result in model_output[0]
|
11 |
+
if result["score"] > threshold_probability
|
12 |
]
|
13 |
|
14 |
return sorted([subject_dict[tag] for tag in predicted_tags])
|