File size: 2,249 Bytes
6aca386
 
 
 
 
58e61ff
6aca386
 
 
 
 
 
 
58e61ff
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6aca386
 
 
 
 
 
 
 
 
58e61ff
6aca386
 
 
 
 
 
 
58e61ff
 
 
 
 
 
 
 
 
 
 
 
6aca386
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import streamlit as st
import pandas as pd
import torch
from transformers import DistilBertTokenizerFast
from transformers import DistilBertForSequenceClassification
import os


def getTop95(predictions, threshold=0.95):
    """Return indices of the most probable classes covering ``threshold`` mass.

    Classes are taken in descending order of probability until their
    cumulative probability reaches ``threshold``.

    Args:
        predictions: 1-D tensor of class probabilities (the caller passes a
            softmax output).
        threshold: Cumulative probability to reach. Defaults to 0.95, which
            is the value the function name refers to.

    Returns:
        1-D tensor of class indices, most probable first. If the threshold is
        never reached (e.g. float rounding leaves the total just below it),
        every index is returned instead of ``None``. An empty input yields an
        empty tensor.
    """
    # A single sort plus a cumulative sum replaces one topk call per k.
    sorted_probs, sorted_ids = torch.sort(predictions, descending=True)
    cumulative = torch.cumsum(sorted_probs, dim=0)
    reached = (cumulative >= threshold).nonzero()
    if reached.numel() == 0:
        return sorted_ids
    count = reached[0].item() + 1
    return sorted_ids[:count]
			

@st.cache(allow_output_mutation=True, show_spinner=False)
def _load_model_resources():
    """Load the class table, tokenizer and fine-tuned model once.

    Returns:
        ``(classes, tokenizer, model)``. ``classes`` is the DataFrame read
        from ``classes.csv``. The model weights are loaded from the current
        working directory with one output label per row of ``classes``, and
        the model is switched to eval mode.
    """
    # allow_output_mutation: the model/tokenizer are not meant to be re-hashed
    # on every cache hit.
    classes = pd.read_csv('classes.csv')
    tokenizer = DistilBertTokenizerFast.from_pretrained("distilbert-base-cased")
    model = DistilBertForSequenceClassification.from_pretrained(
        os.getcwd(),
        num_labels=len(classes),
    )
    model.eval()
    return classes, tokenizer, model


@st.cache(show_spinner=False)
def predict(text):
    """Predict ArXiv tags for an article.

    Args:
        text: Text to classify. The caller in this script passes
            ``title + '|' + summary``.

    Returns:
        Array of values from the ``tag`` column of ``classes.csv``. It holds
        the classes selected by ``getTop95`` (smallest set of most probable
        classes reaching 0.95 cumulative probability), in the order
        ``getTop95`` returns them.
    """
    classes, tokenizer, model = _load_model_resources()
    # return_tensors="pt" yields batch-shaped (1, seq_len) tensors directly.
    encoded = tokenizer(text, truncation=True, padding=True, return_tensors="pt")
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        logits = model(encoded['input_ids'], encoded['attention_mask'])[0][0]
    # Explicit dim avoids the implicit-dimension softmax deprecation.
    predictions = torch.softmax(logits, dim=-1)
    ids = getTop95(predictions)
    return classes.tag.to_numpy()[ids]


# Browser tab title and icon for the app.
st.set_page_config(
    page_title="ArXiv classificator",
    page_icon=":book:",
)

# Page heading and usage instructions shown above the input form.
st.header("Theme classification of ArXiv articles")
st.markdown("""

Please enter a title and a summary (at least one is required) and oracul will predict the classes of the article according to the ArXiv taxonomy.

""")

# Number of top predictions listed individually; the rest are shown inline.
MOST_LIKELY_COUNT = 5

# The form groups both inputs behind a single "Analyze" button.
with st.form(key='input_form'):
    title = st.text_input(label='Enter title of the article here')
    summary = st.text_area("Enter summary of the article here")
    submit = st.form_submit_button(label='Analyze')

if submit:
    # Whitespace-only input carries no text for the model; treat it as empty.
    title, summary = title.strip(), summary.strip()
    if not title and not summary:
        st.markdown('Please enter at least one: title or summary')
    else:
        with st.spinner(text='Oracul thinks, please wait for his wise prediction'):
            # The model input is title and summary joined by '|'.
            prediction = predict(title + '|' + summary)
        st.markdown("Most likely it is:")
        for tag in prediction[:MOST_LIKELY_COUNT]:
            st.markdown(f"- {tag}")
        # Only show the secondary list when it actually has entries.
        if len(prediction) > MOST_LIKELY_COUNT:
            st.markdown("Other possible variants:")
            st.write(', '.join(map(str, prediction[MOST_LIKELY_COUNT:])))
        

# Inject raw CSS that sets `visibility: hidden` on the element with id
# "MainMenu" and on the page <footer>. unsafe_allow_html is required for the
# <style> tag to be rendered rather than escaped.
hide_streamlit_style = """

    <style>

    #MainMenu {visibility: hidden;}

    footer {visibility: hidden;}

    </style>

"""
st.markdown(body=hide_streamlit_style, unsafe_allow_html=True)