File size: 2,848 Bytes
b8769be
4b59c2a
dee08fe
491d5a1
2c5279b
c24d1f1
b8769be
c24d1f1
2c5279b
 
 
 
 
e62c2f3
2c5279b
 
 
0fbdf0a
2c5279b
 
 
 
beb608f
2c5279b
 
 
 
 
 
 
 
 
fd871ea
2c5279b
 
 
 
 
 
 
62096a0
8393245
af657dc
f07e807
 
 
26fd4ec
af657dc
 
62096a0
d52f486
2c5279b
 
5e4fa04
f6a020d
5492f24
5e4fa04
e999374
 
 
a332a0d
62096a0
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import streamlit as st
import numpy as np
import pandas as pd
import torch
import transformers
import tokenizers

@st.cache(suppress_st_warning=True, hash_funcs={tokenizers.Tokenizer: lambda _: None})
def load_model():
    """Load the DistilBERT tokenizer and the fine-tuned 8-label classifier on CPU.

    The classifier head is built with ``num_labels=8``. Its weights are then
    replaced by the state dict stored in ``model_weights2.pt``, and the model
    is switched to inference mode.

    The function is wrapped in ``st.cache`` so Streamlit reuses the loaded
    objects across reruns. The ``hash_funcs`` entry presumably exists so that
    ``st.cache`` does not try to hash the ``tokenizers.Tokenizer`` inside the
    returned tokenizer (TODO confirm).

    Returns:
        ``(tokenizer, model)``
    """
    from transformers import AutoTokenizer, AutoModelForSequenceClassification

    checkpoint = 'distilbert-base-cased'
    text_tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    classifier = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=8)

    # map_location forces the saved tensors onto the CPU, even if they were
    # saved from a GPU.
    cpu_device = torch.device('cpu')
    classifier.load_state_dict(torch.load('model_weights2.pt', map_location=cpu_device))

    # Inference mode: disables dropout so predictions are deterministic.
    classifier.eval()
    return text_tokenizer, classifier
    
@st.cache(suppress_st_warning=True, hash_funcs={tokenizers.Tokenizer: lambda _: None})
def predict(title, summary, tokenizer, model):
    """Classify an article and return its most probable categories.

    Categories are taken in descending probability order until their
    cumulative probability reaches 0.95. The result is therefore the smallest
    top-k set covering 95% of the predicted mass, and is never longer than
    the number of classes.

    Args:
        title: Article title.
        summary: Article summary; may be an empty string.
        tokenizer: Tokenizer returned by ``load_model``.
        model: Sequence-classification model returned by ``load_model``.

    Returns:
        A pair ``(prediction, prediction_probs)``:
        - ``prediction``: category names looked up in the module-level
          ``label_to_theme``.
        - ``prediction_probs``: the matching probabilities formatted as
          percentage strings, e.g. ``"87.12%"``.
    """
    text = title + "\n" + summary
    # Truncate to the tokenizer's model_max_length. DistilBERT has a fixed
    # number of position embeddings, so an untruncated long summary makes
    # the forward pass raise instead of producing a prediction.
    tokens = tokenizer.encode(text, truncation=True)
    with torch.no_grad():
        logits = model(torch.as_tensor([tokens]))[0]
        probs = torch.softmax(logits[-1, :], dim=-1).data.cpu().numpy()

    prediction = []
    prediction_probs = []
    cumulative = 0.0
    # Walk the class indices from most to least probable. A bounded for-loop
    # (instead of an unbounded while) cannot index past the last class, even
    # if float rounding or NaNs keep the running sum below the threshold.
    for idx in np.flip(np.argsort(probs)):
        if cumulative >= 0.95:
            break
        prediction.append(label_to_theme[idx])
        prediction_probs.append("{:.2f}%".format(100 * probs[idx]))
        cumulative += probs[idx]

    return prediction, prediction_probs
   
@st.cache(suppress_st_warning=True)
def get_results(prediction, prediction_probs):
    """Assemble the results table shown to the user.

    Args:
        prediction: Category names, most probable first.
        prediction_probs: Confidence strings aligned with ``prediction``.

    Returns:
        A DataFrame with ``Category`` and ``Confidence`` columns. The index
        starts at 1 so the rows read as a ranking.
    """
    columns = {'Category': prediction, 'Confidence': prediction_probs}
    table = pd.DataFrame(columns)
    row_count = table.shape[0]
    table.index = np.arange(1, row_count + 1)
    return table
    
# Maps classifier output index -> human-readable arXiv category.
# NOTE(review): the order presumably must match the label encoding used when
# model_weights2.pt was trained (note that Physics is last) — do not reorder
# without confirming.
label_to_theme = dict(enumerate([
    'Computer science',
    'Economics',
    'Electrical Engineering and Systems Science',
    'Math',
    'Quantitative biology',
    'Quantitative Finance',
    'Statistics',
    'Physics',
]))

# --- Page layout -----------------------------------------------------------
st.title("Arxiv articles classification")
st.markdown("<h1 style='text-align: center;'><img width=300px src='https://media.wired.com/photos/592700e3cfe0d93c474320f1/191:100/w_1200,h_630,c_limit/faces-icon.jpg'>", unsafe_allow_html=True)
st.markdown("This is an interface that can determine the article's category based on its title and summary. Though it can work with title only, it is recommended that you provide summary if possible - this will result in a better prediction quality.")

# Cached by st.cache inside load_model, so reruns reuse the loaded objects.
tokenizer, model = load_model()

# --- Inputs ----------------------------------------------------------------
title = st.text_area(label='Title', height=100)
summary = st.text_area(label='Summary (optional)', height=250)
button = st.button('Run')

# --- Prediction ------------------------------------------------------------
if button:
    # Validate before running the model so rejected input never costs a
    # forward pass. Stripping keeps the "\n" separator and surrounding
    # whitespace from counting toward the minimum length.
    if len((title + "\n" + summary).strip()) < 20:
        st.error("Your input is too short. It is probably not a real article, please try again.")
    else:
        prediction, prediction_probs = predict(title, summary, tokenizer, model)
        ans = get_results(prediction, prediction_probs)
        st.subheader('Results:')
        st.write(ans)