wrapper228 commited on
Commit
3ea75ed
1 Parent(s): 91c71aa

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +46 -0
app.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import pickle
3
+ import torch
4
+ import numpy as np
5
+ from transformers import TrainingArguments, Trainer, AutoTokenizer, DataCollatorWithPadding, AutoModelForSequenceClassification
6
+ from PIL import Image
7
+
8
+
9
+ # @st.cache
10
+ def predict_topic_by_title_and_abstract(text):
11
+ tokenized_text = tokenizer(text, return_tensors='pt')
12
+ with torch.no_grad():
13
+ logits = model(**tokenized_text).logits
14
+ probs = torch.nn.functional.softmax(logits[0], dim=0).numpy() * 100
15
+ ans = list(zip(probs,labels.values()))
16
+ ans.sort(reverse=True)
17
+ sum = 0
18
+ i = 0
19
+ while sum <= 95:
20
+ prob, label = ans[i]
21
+ st.write("it's topic \"" label + "\" with probability "+ str(np.round(prob,1)) + "%")
22
+ sum += prob
23
+ i += 1
24
+
25
+ tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
26
+
27
+ model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=8)
28
+ model.load_state_dict(
29
+ torch.load(
30
+ "./trained_model"
31
+ )
32
+ )
33
+
34
+ image = Image.open('logo.png')
35
+
36
+ st.image(image)
37
+ st.markdown("### To get an article topic prediction, please write down it's title, abstract, or both.")
38
+
39
+ title = st.text_area("Write article title:", height=30)
40
+
41
+ abstract = st.text_area("Write article abstract:", height=60)
42
+
43
+ input_text = title + " " + abstract
44
+
45
+ if input_text != "":
46
+ predict_topic_by_title_and_abstract(input_text)