File size: 1,873 Bytes
3ea75ed
 
 
 
 
 
 
 
3a9fe38
6a228c8
 
 
 
 
 
 
3a9fe38
 
f56b4fd
fa574e6
09ffd87
3ea75ed
09ffd87
3ea75ed
3a9fe38
3ea75ed
 
 
 
 
dd32181
3ea75ed
 
 
 
 
6a8a7ab
 
3ea75ed
 
 
 
cfde6c7
 
3ea75ed
00c24c9
19bf1f9
f494a7e
3ea75ed
f494a7e
3ea75ed
 
 
318a0ff
31552bd
f720766
fa574e6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import streamlit as st
import pickle
import torch
import numpy as np
from transformers import TrainingArguments, Trainer, AutoTokenizer, DataCollatorWithPadding, AutoModelForSequenceClassification 
from PIL import Image


# Topic names keyed by the stringified position of the corresponding
# classifier output ("0" .. "7"); the prediction function pairs them with the
# softmax probabilities in this order.
labels = {
    str(index): topic
    for index, topic in enumerate(
        (
            "biology",
            "computer science",
            "economics",
            "electrics",
            "finance",
            "math",
            "physics",
            "statistics",
        )
    )
}

def predict_topic_by_title_and_abstract(text, model, threshold=95):
    """Classify *text* and write the most likely topics to the Streamlit page.

    Topics are written from most to least probable. Writing stops as soon as
    the cumulative probability of the topics already shown exceeds
    *threshold* percent (or every topic has been shown).

    Args:
        text: Article title and/or abstract to classify.
        model: Sequence-classification model whose output logits are ordered
            like the module-level ``labels`` mapping (index i <-> key str(i)).
        threshold: Cumulative probability, in percent, after which no further
            topics are written. Defaults to 95, the app's original cutoff.

    Uses the module-level ``tokenizer`` and ``labels``.
    """
    # Put the encoded inputs on the same device as the (unpickled) model.
    tokens = tokenizer(text, return_tensors='pt', truncation=True).to(model.device)
    with torch.no_grad():
        logits = model(**tokens).logits
    # A single text is encoded, so row 0 holds its logits; move to CPU before
    # converting to NumPy (``.numpy()`` rejects GPU tensors).
    probs = torch.nn.functional.softmax(logits[0], dim=0).cpu().numpy() * 100
    names = [labels[str(i)] for i in range(len(probs))]
    ranked = sorted(zip(probs, names), key=lambda pair: pair[0], reverse=True)

    cumulative = 0.0
    # Iterating the ranked list (instead of indexing in a while-loop) means the
    # loop can never run past the last topic, even if float rounding keeps the
    # running total at or below the threshold.
    for prob, label in ranked:
        if cumulative > threshold:
            break
        st.write("it's topic \"" + str(label) + "\" with probability " + str(np.round(prob, 1)) + "%")
        cumulative += prob

# Streamlit re-executes this whole script on every widget interaction. Cache
# the heavy resources when the installed Streamlit offers ``st.cache_resource``;
# on older versions fall back to loading them on each run.
_cache_resource = getattr(st, "cache_resource", lambda func: func)


@_cache_resource
def _load_tokenizer():
    """Load the DistilBERT tokenizer used to encode the user's text."""
    return AutoTokenizer.from_pretrained("distilbert-base-uncased")


@_cache_resource
def _load_model():
    """Unpickle the trained classifier and switch it to inference mode."""
    # NOTE(review): pickle can execute arbitrary code while loading; only ship
    # a trained_model.pickle that you produced yourself.
    with open('trained_model.pickle', 'rb') as handle:
        loaded = pickle.load(handle)
    # A model pickled right after training may still be in train() mode, which
    # keeps dropout active and makes predictions non-deterministic.
    loaded.eval()
    return loaded


def _clean_text(text):
    """Replace every non-alphanumeric character with a space and collapse whitespace."""
    # Dropping non-alphanumerics outright (the previous behaviour) also removed
    # the spaces, gluing all words into one token the model cannot understand.
    spaced = "".join(ch if ch.isalnum() else " " for ch in text)
    return " ".join(spaced.split())


tokenizer = _load_tokenizer()
model = _load_model()

with Image.open('logo.png') as image:
    st.image(image)
st.markdown("##### This app predicts the probabilities of the article belonging to the following topics: 'biology', 'computer science', 'economics', 'electrics', 'finance', 'math', 'physics', 'statistics'.")
st.markdown("##### To get an article topic prediction, please write down its title, abstract, or both.")

st.markdown('<style>textarea { background: #E8E8E8 !important;}</style>', unsafe_allow_html=True)

title = st.text_area("Write article title:", height=30)

abstract = st.text_area("Write article abstract:", height=60)

input_text = _clean_text(title + " " + abstract)

# After cleaning, a non-empty string is guaranteed to contain at least one word.
if input_text:
    predict_topic_by_title_and_abstract(input_text, model)