Update app.py
Browse files
app.py
CHANGED
@@ -20,17 +20,20 @@ ID2CLS = {
|
|
20 |
def classify(text, tokenizer, model):
|
21 |
if not text:
|
22 |
return [""]
|
23 |
-
|
24 |
-
|
|
|
|
|
|
|
25 |
total = 0
|
26 |
-
ans = []
|
27 |
|
28 |
for p in probabilities.argsort()[::-1]:
|
29 |
-
|
30 |
-
|
31 |
-
ans += [f'{ID2CLS[p]}: {round(probabilities[p] * 100, 2)}%']
|
32 |
|
33 |
-
|
|
|
|
|
34 |
|
35 |
|
36 |
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
|
@@ -43,5 +46,4 @@ st.markdown("## Article classifier")
|
|
43 |
title = st.text_area("title")
|
44 |
text = st.text_area("article")
|
45 |
|
46 |
-
|
47 |
-
st.markdown(prediction)
|
|
|
20 |
def classify(text, tokenizer, model):
    """Classify ``text`` and render the most probable classes in the app.

    Classes are written with ``st.markdown`` in order of decreasing
    probability, one per call, until the cumulative probability of the
    classes shown so far exceeds 0.95.

    Args:
        text: The string to classify.
        tokenizer: Callable invoked on a one-element batch ``[text]``; its
            result must provide ``'input_ids'`` and ``'attention_mask'``.
        model: Callable invoked on the token ids and attention mask; its
            result must expose ``.logits``.

    Returns:
        ``[""]`` when ``text`` is empty. Otherwise ``None``: the results are
        rendered as a side effect rather than returned.
    """
    if not text:
        return [""]

    batch = tokenizer(
        [text], truncation=True, padding=True, max_length=256, return_tensors="pt"
    )

    # Inference only: disable autograd so no graph is built for the forward
    # pass, which also makes an explicit .detach() unnecessary.
    with torch.no_grad():
        outputs = model(batch['input_ids'], attention_mask=batch['attention_mask'])
        # Softmax over dim 1, then take row 0 (the single input in the batch).
        probabilities = torch.softmax(outputs.logits, dim=1).cpu().numpy()[0]

    total = 0
    # argsort is ascending, so reverse it to visit the most likely class first.
    for p in probabilities.argsort()[::-1]:
        field = f'{ID2CLS[p]}: {round(probabilities[p] * 100, 2)} %'
        st.markdown(field)

        # Stop once the classes shown so far cover more than 95% of the mass.
        total += probabilities[p]
        if total > 0.95:
            break
|
37 |
|
38 |
|
39 |
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
|
|
|
46 |
title = st.text_area("title")
|
47 |
text = st.text_area("article")
|
48 |
|
49 |
+
# Classify the title and article body together; classify() renders its
# results itself, so the return value is not used here.
# NOTE(review): title and text are concatenated with no separator, so the last
# word of the title runs into the first word of the article — confirm this
# matches the input format the model expects.
classify(title + text, tokenizer, model)
|
|