File size: 1,856 Bytes
3ea75ed
 
 
 
 
 
 
 
f56b4fd
 
 
3ea75ed
 
09ffd87
3ea75ed
09ffd87
3ea75ed
 
 
 
 
 
 
7e51dc1
3ea75ed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cfde6c7
 
3ea75ed
00c24c9
19bf1f9
f494a7e
3ea75ed
f494a7e
3ea75ed
 
 
318a0ff
31552bd
f720766
3ea75ed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import streamlit as st
import pickle
import torch
import numpy as np
from transformers import TrainingArguments, Trainer, AutoTokenizer, DataCollatorWithPadding, AutoModelForSequenceClassification 
from PIL import Image


# Load the pickled label collection used to name the classifier outputs.
# NOTE(review): pickle executes arbitrary code on load -- only ship a
# labels.pickle that was produced by a trusted source.
with open('labels.pickle', 'rb') as labels_file:
    labels = pickle.load(labels_file)

# @st.cache
def predict_topic_by_title_and_abstract(text, threshold=95):
    """Classify ``text`` and write the most probable topics to the Streamlit page.

    The text is tokenized with the module-level ``tokenizer`` (truncated to
    the model's maximum length), run through the module-level ``model``, and
    the class logits are converted to percentages with a softmax. Topics are
    then written with ``st.write`` in descending order of probability until
    the cumulative probability of the topics already shown exceeds
    ``threshold``.

    Args:
        text: The article title and/or abstract to classify.
        threshold: Cumulative probability, in percent, after which no further
            topics are shown. Defaults to 95.

    Returns:
        None. The results are emitted through ``st.write``.

    Note:
        Topic names come from ``labels.values()`` of the module-level
        ``labels`` object; this presumably yields names in class-index
        order -- verify against the training code.
    """
    tokens = tokenizer(text, return_tensors='pt', truncation=True)
    with torch.no_grad():
        logits = model(**tokens).logits
    # Softmax over the class scores of the single input, scaled to percent.
    probs = torch.nn.functional.softmax(logits[0], dim=0).numpy() * 100
    # Highest probability first; equal probabilities fall back to label order.
    ranked = sorted(zip(probs, labels.values()), reverse=True)
    cumulative = 0
    # Iterating over ``ranked`` (instead of indexing with an unbounded
    # counter) guarantees we stop at the last topic even if the threshold is
    # never exceeded, e.g. when there are fewer labels than model outputs.
    for prob, label in ranked:
        if cumulative > threshold:
            break
        st.write(f"it's topic \"{label}\" with probability {str(np.round(prob, 1))}%")
        cumulative += prob

# --- Model setup -----------------------------------------------------------
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

# Build the 8-class classification head on top of DistilBERT, then overwrite
# its weights with the fine-tuned state dict stored next to this script.
model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=8)
model.load_state_dict(
    torch.load(
        "./trained_model",
        # The model is never moved off the CPU, so map the tensors there;
        # without this a checkpoint saved on a GPU fails to load on a
        # CPU-only host.
        map_location="cpu",
    )
)

# --- Page layout -----------------------------------------------------------
image = Image.open('logo.png')

st.image(image)
st.markdown("##### This app predicts the probabilities of the article belonging to the following topics: \'biology\', \'computer science\', \'economics\', \'electrics\', \'finance\', \'math\', \'physics\', \'statistics\'.")
st.markdown("##### To get an article topic prediction, please write down it's title, abstract, or both.")

# Give the text areas a grey background.
st.markdown('<style>textarea { background: #E8E8E8 !important;}</style>', unsafe_allow_html=True)

title = st.text_area("Write article title:", height=30)

abstract = st.text_area("Write article abstract:", height=60)

# --- Prediction ------------------------------------------------------------
input_text = title + " " + abstract

# Drop punctuation but keep whitespace: filtering on ``str.isalnum`` alone
# also removes every space and fuses the whole input into a single "word",
# which destroys the word boundaries the tokenizer relies on.
input_text = ''.join(ch for ch in input_text if ch.isalnum() or ch.isspace())

# Only run the model once the user has typed something other than whitespace.
if input_text.strip():
    predict_topic_by_title_and_abstract(input_text)