oracat committed on
Commit
8dcbe0a
1 Parent(s): 429523a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -2
app.py CHANGED
@@ -14,13 +14,40 @@ def prepare_model():
14
  return (tokenizer, model)
15
 
16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  def process(text):
18
  """
19
  Translate incoming text to tokens and classify it
20
  """
21
  pipe = pipeline("text-classification", model=model, tokenizer=tokenizer)
22
  result = pipe(text)[0]
23
- return result["label"]
24
 
25
 
26
  tokenizer, model = prepare_model()
@@ -105,4 +132,4 @@ text = "\n".join([title, abstract])
105
  ## Output
106
 
107
  if len(text.strip()) > 0:
108
- st.markdown(f"<h4>Predicted class: {process(text)}</h4>", unsafe_allow_html=True)
 
14
  return (tokenizer, model)
15
 
16
 
17
def top_pct(preds, threshold=0.95):
    """
    Output top predictions and their scores

    Predictions are ranked by descending ``"score"`` and kept, best first,
    until their cumulative score reaches ``threshold``.

    Args:
        preds: iterable of mappings, each with a numeric ``"score"`` entry.
        threshold: cumulative score at which to stop adding predictions.

    Returns:
        A new list with the leading predictions in descending score order.
        All predictions are returned when the threshold is never reached,
        and an empty list is returned for empty input.
    """
    ranked = sorted(preds, key=lambda x: -x["score"])

    cum_score = 0
    for count, item in enumerate(ranked, start=1):
        cum_score += item["score"]
        if cum_score >= threshold:
            # Keep the item that pushed the running total over the threshold.
            return ranked[:count]

    # Threshold never reached, or nothing to rank: return everything we have
    # rather than relying on a loop variable that may never have been bound.
    return ranked
32
+
33
+
34
def format_predictions(preds) -> str:
    """
    Prepare predictions and their scores for printing to the user

    Args:
        preds: iterable of mappings with ``"label"`` and ``"score"`` entries.

    Returns:
        One numbered line per prediction (numbering starts at 1), with the
        label in bold and the score in italics rounded to two decimals.
        Every line ends with a newline; empty input gives an empty string.
    """
    # Build all lines and join once instead of growing a string with "+=".
    return "".join(
        f"{rank}. **{item['label']}** *(score {item['score']:.2f})*\n"
        for rank, item in enumerate(preds, start=1)
    )
42
+
43
+
44
  def process(text):
45
  """
46
  Translate incoming text to tokens and classify it
47
  """
48
  pipe = pipeline("text-classification", model=model, tokenizer=tokenizer)
49
  result = pipe(text)[0]
50
+ return format_predictions(top_pct(result))
51
 
52
 
53
  tokenizer, model = prepare_model()
 
132
  ## Output
133
 
134
  if len(text.strip()) > 0:
135
+ st.markdown(f"{process(text)}", unsafe_allow_html=True)