File size: 2,375 Bytes
b8769be
4b59c2a
491d5a1
2c5279b
c24d1f1
b8769be
c24d1f1
2c5279b
 
 
 
 
 
 
 
 
0fbdf0a
2c5279b
 
 
 
beb608f
2c5279b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f05ebed
 
2c5279b
f07e807
 
 
26fd4ec
d52f486
 
 
2c5279b
 
5e4fa04
f6a020d
5e4fa04
2c5279b
 
ce8e0e2
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import streamlit as st
import numpy as np
import torch
import transformers
import tokenizers

# Tokenizer objects are given a constant hash (None) so st.cache does not
# attempt to hash them itself.
@st.cache(suppress_st_warning=True, hash_funcs={tokenizers.Tokenizer: lambda _: None})
def load_model():
    """Create the tokenizer and the 8-label sequence classifier for the app.

    The ``distilbert-base-cased`` checkpoint provides both the tokenizer and
    the base network (instantiated with ``num_labels=8``). The state dict
    stored in ``model_weights.pt`` is then loaded with tensors mapped to the
    CPU, and the network is put into evaluation mode.

    Returns:
        tuple: ``(tokenizer, model)``.
    """
    from transformers import AutoModelForSequenceClassification, AutoTokenizer

    checkpoint = 'distilbert-base-cased'
    cpu = torch.device('cpu')

    text_tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    classifier = AutoModelForSequenceClassification.from_pretrained(
        checkpoint,
        num_labels=8,
    )

    # Overwrite the freshly initialised weights with the ones saved on disk.
    state_dict = torch.load('model_weights.pt', map_location=cpu)
    classifier.load_state_dict(state_dict)
    classifier.eval()

    return text_tokenizer, classifier
    
# Tokenizer objects are given a constant hash (None) so st.cache does not
# attempt to hash them itself.
@st.cache(suppress_st_warning=True, hash_funcs={tokenizers.Tokenizer: lambda _: None})
def predict(title, summary, tokenizer, model):
    """Return the most probable topics for an article, covering 95% of the mass.

    The title and summary are joined with a newline, tokenized, and passed
    through ``model`` as a single-item batch. Class probabilities are obtained
    with a softmax over the logits; classes are then taken from most to least
    probable until their cumulative probability reaches 0.95.

    Args:
        title: Article title (string).
        summary: Article summary (string; may be empty).
        tokenizer: Object exposing ``encode(text, truncation=True)`` that
            returns a list of token ids.
        model: Callable returning a sequence whose first element holds the
            logits for the batch.

    Returns:
        tuple: ``(prediction, prediction_probs)`` - two parallel lists with
        the topic names (looked up in ``label_to_theme``) and their
        probabilities, ordered from most to least probable. At least one
        entry is always returned.
    """
    coverage = 0.95

    text = title + "\n" + summary
    # Clip the sequence to the tokenizer's configured maximum length so that
    # a very long summary cannot exceed what the model's position embeddings
    # accept (which would otherwise raise inside the forward pass).
    tokens = tokenizer.encode(text, truncation=True)
    with torch.no_grad():
        logits = model(torch.as_tensor([tokens]))[0]
        # The batch holds exactly one sequence, so the last row is that sequence.
        probs = torch.softmax(logits[-1, :], dim=-1).data.cpu().numpy()

    prediction = []
    prediction_probs = []
    sum_probs = 0
    # Iterate classes from most to least probable; the loop is bounded by the
    # number of classes, so it can never index past the end of the ranking.
    for class_id in np.flip(np.argsort(probs)):
        prediction.append(label_to_theme[int(class_id)])
        prediction_probs.append(probs[class_id])
        sum_probs += probs[class_id]
        # Same stop condition as ``while sum_probs < coverage`` (also stops
        # on NaN, for which the comparison is False).
        if not sum_probs < coverage:
            break

    return prediction, prediction_probs
   
@st.cache(suppress_st_warning=True) 
def get_results(prediction, prediction_probs):
    """Render topics and their probabilities as tab-separated text.

    Args:
        prediction: Iterable of topic names (strings).
        prediction_probs: Iterable of probabilities, parallel to
            ``prediction``; each value is converted with ``str``.

    Returns:
        str: A header line followed by one ``topic<TAB><TAB>prob`` line per
        pair, every line terminated by a newline. Pairs are formed with
        ``zip``, so extra items in the longer input are ignored.
    """
    header = "Topic:\t\tConfidence:\n"
    rows = [
        topic + "\t\t" + str(prob) + "\n"
        for topic, prob in zip(prediction, prediction_probs)
    ]
    return header + "".join(rows)
    
# Maps a class index produced by the classifier to its human-readable topic name.
label_to_theme = {
    0: 'Computer science',
    1: 'Economics',
    2: 'Electrical Engineering and Systems Science',
    3: 'Math',
    4: 'Quantitative biology',
    5: 'Quantitative Finance',
    6: 'Statistics',
    7: 'Physics',
}

# --- Page header -----------------------------------------------------------
st.title("Arxiv articles classification")
st.markdown(
    "This is an interface that can determine the article's topic based on "
    "its title and summary. Though it can work with title only, it is "
    "recommended that you provide summary if possible - this will result in "
    "a better prediction quality."
)

# --- Model -----------------------------------------------------------------
tokenizer, model = load_model()

# --- User inputs -----------------------------------------------------------
title = st.text_area(label='Title', height=100)
summary = st.text_area(label='Summary (optional)', height=250)

# --- Inference and output --------------------------------------------------
prediction, prediction_probs = predict(title, summary, tokenizer, model)
# NOTE(review): get_results emits tab/newline-separated text; confirm it
# renders as intended when passed through st.markdown.
st.markdown(get_results(prediction, prediction_probs))