Michael-Geis commited on
Commit
7af8946
1 Parent(s): 3050c48

added slider bar for user controlled confidence on tags

Browse files
Files changed (3) hide show
  1. app.py +10 -6
  2. model.py +2 -2
  3. postprocess.py +4 -2
app.py CHANGED
@@ -46,16 +46,16 @@ def parse_title(input_title):
46
  return (title, subject_tags)
47
 
48
 
49
- def outputs_from_id(input_id):
50
  title, true_tags = parse_id(input_id)
51
- predicted_tags = predict_from_text(title)
52
 
53
  return title, predicted_tags, true_tags
54
 
55
 
56
- def outputs_from_title(input_title):
57
  title, true_tags = parse_title(input_title)
58
- predicted_tags = predict_from_text(title)
59
 
60
  return title, predicted_tags, true_tags
61
 
@@ -72,6 +72,8 @@ with gr.Blocks() as demo:
72
  id_true = gr.Textbox(label="True tags")
73
  id_button = gr.Button("Predict")
74
 
 
 
75
  gr.Examples(
76
  examples=[
77
  "1706.03762",
@@ -102,11 +104,13 @@ with gr.Blocks() as demo:
102
  )
103
 
104
  id_button.click(
105
- outputs_from_id, inputs=id_input, outputs=[id_title, id_predict, id_true]
 
 
106
  )
107
  title_button.click(
108
  outputs_from_title,
109
- inputs=title_input,
110
  outputs=[title_title, title_predict, title_true],
111
  )
112
 
 
46
  return (title, subject_tags)
47
 
48
 
49
+ def outputs_from_id(input_id, threshold_probability):
50
  title, true_tags = parse_id(input_id)
51
+ predicted_tags = predict_from_text(title, threshold_probability)
52
 
53
  return title, predicted_tags, true_tags
54
 
55
 
56
+ def outputs_from_title(input_title, threshold_probability):
57
  title, true_tags = parse_title(input_title)
58
+ predicted_tags = predict_from_text(title, threshold_probability)
59
 
60
  return title, predicted_tags, true_tags
61
 
 
72
  id_true = gr.Textbox(label="True tags")
73
  id_button = gr.Button("Predict")
74
 
75
+ threshold_probability = gr.Slider(minimum=0, maximum=1)
76
+
77
  gr.Examples(
78
  examples=[
79
  "1706.03762",
 
104
  )
105
 
106
  id_button.click(
107
+ outputs_from_id,
108
+ inputs=[id_input, threshold_probability],
109
+ outputs=[id_title, id_predict, id_true],
110
  )
111
  title_button.click(
112
  outputs_from_title,
113
+ inputs=[title_input, threshold_probability],
114
  outputs=[title_title, title_predict, title_true],
115
  )
116
 
model.py CHANGED
@@ -7,7 +7,7 @@ from preprocess import cleanse
7
  from postprocess import postprocess
8
 
9
 
10
- def predict_from_text(input_text):
11
  ## Load model and create pipeline
12
  tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
13
  model = AutoModelForSequenceClassification.from_pretrained(
@@ -19,7 +19,7 @@ def predict_from_text(input_text):
19
  clean_title = cleanse(input_text)
20
  model_output = pipe(clean_title)
21
 
22
- prediction = postprocess(model_output)
23
 
24
  if len(prediction) == 0:
25
  predict_output = "No matching tags."
 
7
  from postprocess import postprocess
8
 
9
 
10
+ def predict_from_text(input_text, threshold_probability):
11
  ## Load model and create pipeline
12
  tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
13
  model = AutoModelForSequenceClassification.from_pretrained(
 
19
  clean_title = cleanse(input_text)
20
  model_output = pipe(clean_title)
21
 
22
+ prediction = postprocess(model_output, threshold_probability=threshold_probability)
23
 
24
  if len(prediction) == 0:
25
  predict_output = "No matching tags."
postprocess.py CHANGED
@@ -1,12 +1,14 @@
1
  import json
2
 
3
 
4
- def postprocess(model_output):
5
  with open("./data/arxiv-label-dict.json", "r") as file:
6
  subject_dict = json.loads(file.read())
7
 
8
  predicted_tags = [
9
- result["label"] for result in model_output[0] if result["score"] > 0.5
 
 
10
  ]
11
 
12
  return sorted([subject_dict[tag] for tag in predicted_tags])
 
1
  import json
2
 
3
 
4
+ def postprocess(model_output, threshold_probability):
5
  with open("./data/arxiv-label-dict.json", "r") as file:
6
  subject_dict = json.loads(file.read())
7
 
8
  predicted_tags = [
9
+ result["label"]
10
+ for result in model_output[0]
11
+ if result["score"] > threshold_probability
12
  ]
13
 
14
  return sorted([subject_dict[tag] for tag in predicted_tags])