File size: 3,660 Bytes
a40adc1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2c767c4
0f11251
 
 
 
 
28ef4f2
a40adc1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
474c845
a40adc1
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import streamlit as st
from transformers import DistilBertModel, DistilBertTokenizer
import torch

# Serialized artifacts, resolved relative to the process working directory.
model_path = './models/pytorch_distilbert.bin'  # pickled classifier (see load_model_and_tokenizer)
vocab_path = './models/vocab_distilbert.bin'    # tokenizer vocabulary file
# All inference in this app runs on the CPU.
device = torch.device('cpu')
# Token budget per input in predict(): longer texts are truncated, shorter ones padded.
MAX_LEN = 512

# Human-readable subject name for each output index of the classifier head.
labels_description = dict(enumerate((
    'Computer Science',
    'Economics',
    'Electrical Engineering and Systems Science',
    'Mathematics',
    'Physics',
    'Quantitative Biology',
    'Quantitative Finance',
    'Statistics',
)))


class DistillBERTClass(torch.nn.Module):
    """DistilBERT encoder with a two-layer head over the first token's hidden state.

    NOTE(review): the app presumably restores a pickled instance of this class
    via torch.load, so the class name and the attribute names l1,
    pre_classifier, dropout and classifier should stay stable — confirm
    against how the checkpoint was produced.
    """

    def __init__(self):
        super().__init__()
        # Creation order is kept as-is: it fixes parameter registration order
        # and the RNG draws used to initialise the linear layers.
        self.l1 = DistilBertModel.from_pretrained("distilbert-base-cased")
        self.pre_classifier = torch.nn.Linear(768, 768)
        self.dropout = torch.nn.Dropout(0.3)
        self.classifier = torch.nn.Linear(768, 8)

    def forward(self, input_ids, attention_mask):
        """Return one row of 8 class logits per input sequence."""
        encoded = self.l1(input_ids=input_ids, attention_mask=attention_mask)
        first_token = encoded[0][:, 0]  # hidden state of position 0 in each sequence
        projected = torch.relu(self.pre_classifier(first_token))
        return self.classifier(self.dropout(projected))


def predict(text, model, human_readable=True, *, threshold=0.95):
    """Classify ``text`` and return its most probable subjects.

    The text is whitespace-normalised, tokenized with the module-level
    ``tokenizer`` (padded/truncated to ``MAX_LEN`` tokens) and passed through
    ``model``. Classes are then taken in order of decreasing probability until
    their cumulative probability reaches ``threshold`` (or every class has
    been taken).

    Args:
        text: Raw input text (e.g. title and abstract).
        model: Module called as ``model(input_ids, attention_mask)`` that
            returns one row of class logits per input; it is switched to eval
            mode here.
        human_readable: If True, return strings such as
            ``"Physics - 87.12%"``; otherwise return the raw class indices.
        threshold: Cumulative probability at which to stop adding classes.

    Returns:
        A list of label strings or class indices, most probable first.
    """
    model.eval()
    text = " ".join(text.split())  # collapse runs of whitespace/newlines
    inputs = tokenizer.encode_plus(
        text,
        None,
        add_special_tokens=True,
        max_length=MAX_LEN,
        padding="max_length",  # replaces the deprecated pad_to_max_length=True
        truncation=True,
        return_tensors="pt",  # tensors already shaped (1, MAX_LEN), dtype long
    )
    with torch.no_grad():
        logits = model(inputs["input_ids"], inputs["attention_mask"])
        probs = torch.softmax(logits, dim=-1)[0].tolist()

    # (probability, class index) pairs, most probable first; equal
    # probabilities are ordered by larger class index first.
    ranked = sorted(((p, i) for i, p in enumerate(probs)), reverse=True)

    answer = []
    cumulative = 0.0
    index = 0
    # The index bound guards against float sums that never reach the threshold.
    while cumulative < threshold and index < len(ranked):
        prob, label = ranked[index]
        cumulative += prob
        if human_readable:
            answer.append(labels_description[label] + " - {:.2f}%".format(100 * prob))
        else:
            answer.append(label)
        index += 1

    return answer


# st.cache is deprecated since Streamlit 1.18; prefer st.cache_resource
# (meant for unhashable, shared objects such as models) when it exists.
if hasattr(st, "cache_resource"):
    _cache_model = st.cache_resource(show_spinner=False)
else:
    _cache_model = st.cache(show_spinner=False, allow_output_mutation=True)


@_cache_model
def load_model_and_tokenizer():
    """Load the classifier and its tokenizer once and cache them across reruns.

    Returns:
        ``(model, tokenizer)``: the object unpickled from ``model_path``
        (mapped onto ``device``) and a ``DistilBertTokenizer`` read from
        ``vocab_path``.
    """
    import inspect

    load_kwargs = {"map_location": device}
    # The checkpoint is a whole pickled model (the app calls it directly), and
    # torch >= 2.6 defaults to weights_only=True, which refuses to unpickle it.
    # Opt out explicitly on torch versions that know the flag.
    # SECURITY: this unpickles arbitrary objects - model_path must point to a
    # trusted local file.
    if "weights_only" in inspect.signature(torch.load).parameters:
        load_kwargs["weights_only"] = False
    net = torch.load(model_path, **load_kwargs)

    # NOTE(review): loading from a single vocab file gives no tokenizer
    # config, so DistilBertTokenizer defaults (e.g. do_lower_case) apply even
    # though the encoder is distilbert-base-cased; newer transformers versions
    # also deprecate single-file from_pretrained. Confirm this matches training.
    tok = DistilBertTokenizer.from_pretrained(vocab_path)
    return net, tok
           
           
model, tokenizer = load_model_and_tokenizer()

st.markdown("### Hi! This is a service for determining the subject of an article.")
# Derived from labels_description so the advertised topics always match the
# labels predict() can return.
st.markdown("It can predict the following topics:\n"
            + "".join("* {}\n".format(name) for name in labels_description.values()))
st.markdown("#### Just write the title and abstract in the areas below and click the \"Analyze\" button.")

title = st.text_area("Title")

abstract = st.text_area("Abstract")

if st.button('Analyze'):
    # Whitespace-only input counts as empty: predict() collapses whitespace,
    # so the model would otherwise receive no text at all.
    has_title = bool(title.strip())
    has_abstract = bool(abstract.strip())
    if not has_title and not has_abstract:
        st.error("You haven't written anything.")
    elif not has_title:
        st.error("You haven't written a title.")
    else:
        with st.spinner("Wait..."):
            pred = predict(title + "\n" + abstract, model.to(device))
        st.success("\n\n".join(pred))