Darkhan committed on
Commit
45c713e
1 Parent(s): 5c59066

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -9
app.py CHANGED
@@ -20,17 +20,20 @@ ID2CLS = {
20
  def classify(text, tokenizer, model):
21
  if not text:
22
  return [""]
23
- tokens = tokenizer([text], truncation=True, padding=True, max_length=256, return_tensors="pt")['input_ids']
24
- probabilities = torch.softmax(model(tokens).logits, dim=1).detach().cpu().numpy()[0]
 
 
 
25
  total = 0
26
- ans = []
27
 
28
  for p in probabilities.argsort()[::-1]:
29
- if probabilities[p] + total < 0.9:
30
- total += probabilities[p]
31
- ans += [f'{ID2CLS[p]}: {round(probabilities[p] * 100, 2)}%']
32
 
33
- return ans
 
 
34
 
35
 
36
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
@@ -43,5 +46,4 @@ st.markdown("## Article classifier")
43
  title = st.text_area("title")
44
  text = st.text_area("article")
45
 
46
- for prediction in classify(title + text, tokenizer, model):
47
- st.markdown(prediction)
 
20
  def classify(text, tokenizer, model):
21
  if not text:
22
  return [""]
23
+
24
+ batch = tokenizer([text], truncation=True, padding=True, max_length=256, return_tensors="pt")
25
+ outputs = model(batch['input_ids'], attention_mask=batch['attention_mask'])
26
+
27
+ probabilities = torch.softmax(outputs.logits, dim=1).detach().cpu().numpy()[0]
28
  total = 0
 
29
 
30
  for p in probabilities.argsort()[::-1]:
31
+ field = f'{ID2CLS[p]}: {round(probabilities[p] * 100, 2)} %'
32
+ st.markdown(field)
 
33
 
34
+ total += probabilities[p]
35
+ if total > 0.95:
36
+ break
37
 
38
 
39
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
 
46
  title = st.text_area("title")
47
  text = st.text_area("article")
48
 
49
+ classify(title + text, tokenizer, model)