File size: 1,354 Bytes
fbb2b39
6f6d6c2
 
fbb2b39
6f6d6c2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45c713e
 
 
 
 
6f6d6c2
 
 
45c713e
 
6f6d6c2
45c713e
 
 
6f6d6c2
 
 
 
5c59066
6f6d6c2
 
 
 
 
 
 
45c713e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Pretrained checkpoint supplying the tokenizer vocabulary and base architecture.
MODEL_NAME = 'bert-base-cased'
# Local file holding the fine-tuned weights; it is passed to torch.load.
MODEL_PATH = 'bert_model'

# Maps classifier output index -> human-readable category.
# List position i must line up with logit i of the fine-tuned head,
# so do not reorder these entries.
ID2CLS = dict(enumerate([
    'Computer Science',
    'Economics',
    'Electrical Engineering and Systems Science',
    'Mathematics',
    'Physics',
    'Quantitative Biology',
    'Quantitative Finance',
    'Statistics',
]))


def classify(text, tokenizer, model, *, threshold=0.95, max_length=256):
    """Classify ``text`` and render the most likely categories as Markdown.

    Categories are written with ``st.markdown`` in descending order of
    predicted probability. Output stops after the category that pushes the
    cumulative probability above ``threshold``.

    Args:
        text: The string to classify.
        tokenizer: Tokenizer matching ``model``; called with
            ``return_tensors="pt"``.
        model: Sequence-classification model whose output ``logits`` columns
            are indexed by the keys of ``ID2CLS``.
        threshold: Cumulative-probability cutoff (default 0.95, as before).
        max_length: Token limit passed to the tokenizer; longer input is
            truncated (default 256, as before).

    Returns:
        The list of rendered lines, in display order. For empty or
        whitespace-only input nothing is classified and ``[""]`` is returned
        (the original return value for empty input, kept for compatibility).
    """
    # Whitespace-only input carries no content worth sending to the model.
    if not text or not text.strip():
        return [""]

    batch = tokenizer([text], truncation=True, padding=True,
                      max_length=max_length, return_tensors="pt")

    # Inference only: disable autograd so no graph/activations are retained.
    with torch.no_grad():
        outputs = model(batch['input_ids'], attention_mask=batch['attention_mask'])

    # The batch holds a single text, so row 0 is its class distribution.
    probabilities = torch.softmax(outputs.logits, dim=1).cpu().numpy()[0]

    lines = []
    total = 0.0
    for idx in probabilities.argsort()[::-1]:
        prob = float(probabilities[idx])
        line = f'{ID2CLS[int(idx)]}: {round(prob * 100, 2)} %'
        st.markdown(line)
        lines.append(line)

        total += prob
        if total > threshold:
            break
    return lines


def _load_classifier():
    """Build the tokenizer and the fine-tuned classifier.

    Returns:
        A ``(tokenizer, model)`` pair. The model weights come from
        MODEL_PATH, are mapped to CPU, and the model is put in eval mode.
    """
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    # One output per known category, so the head always matches ID2CLS.
    model = AutoModelForSequenceClassification.from_pretrained(
        MODEL_NAME, num_labels=len(ID2CLS))
    # NOTE(review): torch.load unpickles MODEL_PATH; only use a trusted file.
    state_dict = torch.load(MODEL_PATH, map_location=torch.device('cpu'))
    model.load_state_dict(state_dict)
    model.eval()
    return tokenizer, model


# Streamlit re-executes this script on every widget interaction. Cache the
# heavyweight model so it is loaded once per server process, not per rerun.
if hasattr(st, 'cache_resource'):
    _load_classifier = st.cache_resource(_load_classifier)
else:  # older Streamlit releases without cache_resource
    _load_classifier = st.cache(allow_output_mutation=True)(_load_classifier)

tokenizer, model = _load_classifier()

st.markdown("## Article classifier")

title = st.text_area("title")
text = st.text_area("article")

# Join with a space so the title's last word doesn't fuse with the article's
# first word; blank parts are dropped so empty inputs stay empty.
query = " ".join(part for part in (title, text) if part.strip())
classify(query, tokenizer, model)