File size: 2,510 Bytes
b8769be
4b59c2a
dee08fe
491d5a1
2c5279b
c24d1f1
b8769be
c24d1f1
2c5279b
 
 
 
 
 
 
 
 
0fbdf0a
2c5279b
 
 
 
beb608f
2c5279b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e999374
 
 
 
 
4dc6495
f07e807
 
 
26fd4ec
d52f486
 
 
2c5279b
 
5e4fa04
f6a020d
5492f24
5e4fa04
e999374
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import streamlit as st
import numpy as np
import pandas as pd
import torch
import transformers
import tokenizers

@st.cache(suppress_st_warning=True, hash_funcs={tokenizers.Tokenizer: lambda _: None})
def load_model():
    """Create the tokenizer and the 8-class classifier, then load fine-tuned weights.

    The base checkpoint is ``distilbert-base-cased``. Fine-tuned parameters are
    read from ``model_weights.pt`` (relative to the working directory) and mapped
    onto the CPU. The model is put into evaluation mode before it is returned.
    The ``hash_funcs`` entry makes ``st.cache`` treat tokenizer objects as
    unhashed (mapped to ``None``).

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    from transformers import AutoTokenizer, AutoModelForSequenceClassification

    checkpoint = 'distilbert-base-cased'
    cpu = torch.device('cpu')

    tok = AutoTokenizer.from_pretrained(checkpoint)
    classifier = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=8)

    # Map tensors to CPU so a checkpoint saved on GPU still loads on a CPU-only host.
    state = torch.load('model_weights.pt', map_location=cpu)
    classifier.load_state_dict(state)

    # Inference only: disables dropout and similar training-time behaviour.
    classifier.eval()
    return tok, classifier
    
@st.cache(suppress_st_warning=True, hash_funcs={tokenizers.Tokenizer: lambda _: None})
def predict(title, summary, tokenizer, model, threshold=0.95):
    """Classify an article and return its most likely topics.

    Topics are collected in descending order of probability until their
    cumulative probability reaches ``threshold``, or until every class has
    been used.

    Args:
        title: Article title.
        summary: Article summary; may be an empty string.
        tokenizer: Hugging Face tokenizer matching ``model``.
        model: Sequence-classification model whose output class indices are
            keys of the module-level ``label_to_theme`` mapping.
        threshold: Cumulative probability at which to stop adding topics
            (default 0.95).

    Returns:
        A ``(prediction, prediction_probs)`` tuple of topic names and their
        probabilities, most likely first.
    """
    text = title + "\n" + summary
    # Truncate to the model's maximum input length; without it, inputs longer
    # than the model's maximum sequence length make the forward pass fail.
    tokens = tokenizer.encode(text, truncation=True)
    with torch.no_grad():
        logits = model(torch.as_tensor([tokens]))[0]
        # Batch of one: take its (only, i.e. last) row of class logits.
        probs = torch.softmax(logits[-1, :], dim=-1).data.cpu().numpy()

    # Class indices ordered from most to least probable.
    classes = np.flip(np.argsort(probs))
    prediction = []
    prediction_probs = []
    sum_probs = 0
    # Iterating over `classes` bounds the loop even if the probabilities
    # never reach the threshold (e.g. NaN logits).
    for cls in classes:
        if sum_probs >= threshold:
            break
        prediction.append(label_to_theme[cls])
        prediction_probs.append(probs[cls])
        sum_probs += probs[cls]

    return prediction, prediction_probs
   
@st.cache(suppress_st_warning=True)
def get_results(prediction, prediction_probs):
    """Build a styled results table of topics and confidences.

    Args:
        prediction: Topic names, most likely first.
        prediction_probs: Probabilities as fractions, aligned with
            ``prediction`` (e.g. 0.875 is shown as ``"87.50%"``).

    Returns:
        A pandas ``Styler`` with columns ``Topic`` and ``Confidence`` (a
        percentage string with two decimals) and the index hidden.
    """
    # Build a new list: rebinding a loop variable would not change the input.
    confidences = [f"{100 * prob:.2f}%" for prob in prediction_probs]
    styler = pd.DataFrame({
        'Topic': prediction,
        'Confidence': confidences,
    }).style
    # Styler.hide_index() was deprecated in pandas 1.4 and removed in 2.0;
    # Styler.hide(axis="index") is its replacement where available.
    if hasattr(styler, 'hide'):
        return styler.hide(axis='index')
    return styler.hide_index()
    
# Classifier output index -> topic name; predict() reads this mapping.
label_to_theme = {0: 'Computer science', 1: 'Economics', 2: 'Electrical Engineering and Systems Science', 3: 'Math',
                  4: 'Quantitative biology', 5: 'Quantitative Finance', 6: 'Statistics', 7: 'Physics'}

st.title("Arxiv articles classification")
st.markdown("This is an interface that can determine the article's topic based on its title and summary. Though it can work with title only, it is recommended that you provide summary if possible - this will result in a better prediction quality.")

tokenizer, model = load_model()

title = st.text_area(label='Title', height=100)
summary = st.text_area(label='Summary (optional)', height=250)
button = st.button('Run')

if button:
    if not title.strip():
        # The title is the one required input; avoid classifying empty text.
        st.warning('Please enter a title before running the classifier.')
    else:
        prediction, prediction_probs = predict(title, summary, tokenizer, model)
        ans = get_results(prediction, prediction_probs)
        st.write('Results: ', ans)