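# Streamlit app: classify an arXiv-style article (title + summary) into one of
# eight subject areas with a fine-tuned DistilBERT model.
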
import streamlit as st
from transformers import DistilBertModel, DistilBertTokenizer
import torch
import matplotlib.image as mpimg

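# Paths and inference settings: header image, fine-tuned checkpoint, saved
# tokenizer vocabulary, CPU device, and DistilBERT's 512-token input limit.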
img = mpimg.imread('./460.jpeg')
model_path = './models/pytorch_distilbert.bin'
vocab_path = './models/vocab_distilbert.bin'
device = torch.device('cpu')
MAX_LEN = 512

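# Human-readable names for the classifier's eight output indices.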
labels_description = {0: 'Computer Science',
                      1: 'Economics',
                      2: 'Electrical Engineering and Systems Science',
                      3: 'Mathematics',
                      4: 'Physics',
                      5: 'Quantitative Biology',
                      6: 'Quantitative Finance',
                      7: 'Statistics'}


class DistillBERTClass(torch.nn.Module):
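    # Keep this class name (including the "Distill" spelling): torch.load below
    # unpickles the full model and needs the identical class definition in scope.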
    def __init__(self):
        super(DistillBERTClass, self).__init__()
        self.l1 = DistilBertModel.from_pretrained("distilbert-base-cased")
        self.pre_classifier = torch.nn.Linear(768, 768)
        self.dropout = torch.nn.Dropout(0.3)
        self.classifier = torch.nn.Linear(768, 8)  # one logit per subject area

    def forward(self, input_ids, attention_mask):
        output_1 = self.l1(input_ids=input_ids, attention_mask=attention_mask)
        hidden_state = output_1[0]  # (batch, seq_len, 768) last hidden states
        pooler = hidden_state[:, 0]  # representation of the [CLS] token
        pooler = self.pre_classifier(pooler)
        pooler = torch.nn.ReLU()(pooler)
        pooler = self.dropout(pooler)
        output = self.classifier(pooler)  # raw logits, one per subject area
        return output


def predict(text, model, human_readable=True):
    """Return the top subject areas covering at least 95% cumulative probability."""
    model.eval()
    text = " ".join(text.split())  # collapse whitespace and newlines
    inputs = tokenizer.encode_plus(
        text,
        add_special_tokens=True,
        max_length=MAX_LEN,
        padding='max_length',
        truncation=True
    )
    ids = torch.tensor(inputs['input_ids'], dtype=torch.long).unsqueeze(0)
    mask = torch.tensor(inputs['attention_mask'], dtype=torch.long).unsqueeze(0)
    with torch.no_grad():
        probs = torch.softmax(model(ids, mask), dim=-1)[0].tolist()
    # Sort (probability, label index) pairs from most to least likely.
    ranked = sorted(((p, i) for i, p in enumerate(probs)), reverse=True)
    cumulative = 0.0
    answer = []
    # Report labels until they account for at least 95% of the probability mass.
    for prob, label in ranked:
        if cumulative >= 0.95:
            break
        cumulative += prob
        if human_readable:
            answer.append(f"{labels_description[label]} {100 * prob:.2f}%")
        else:
            answer.append(label)
    return answer


# Load the saved tokenizer and the fully pickled model (recent PyTorch versions
# may require passing weights_only=False to torch.load for pickled modules).
tokenizer = DistilBertTokenizer.from_pretrained(vocab_path)
model = torch.load(model_path, map_location=device)

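# Streamlit UI: page header, input areas, and the Analyze button.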
st.markdown("### Hi! This is a service for determining the subject of an article.")
st.image(img)
st.markdown("### Just write the title and content in the areas below and click the \"Analyze\" button.")

text1 = st.text_area("Title")

text2 = st.text_area("Summary")

if st.button('Analyze'):
    with st.spinner("Wait..."):
        if text1 or text2:
            pred = predict(text1 + "\n" + text2, model)
            st.success("\n\n".join(pred))
        else:
            st.error("You haven't written anything.")