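"""Streamlit demo that classifies arXiv articles into one of eight subject
categories (Biology, Computer science, Economics, Electrics, Finance, Math,
Physics, Statistics) using a fine-tuned DistilBERT model."""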
import streamlit as st
import torch
import numpy as np
from PIL import Image
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Arxiv subject categories the classifier was fine-tuned on, in label order.
categories = ['Biology', 'Computer science', 'Economics', 'Electrics', 'Finance',
              'Math', 'Physics', 'Statistics']
labels = list(range(len(categories)))

def print_probs(logits):
  """Write the top categories until their cumulative probability exceeds 95%."""
  probs = torch.nn.functional.softmax(logits, dim=0).numpy() * 100
  ranked = sorted(zip(probs, labels), reverse=True)
  total = 0
  i = 0
  while total <= 95:
    prob, idx = ranked[i]
    st.write(f"{categories[idx]}: {np.round(prob, 1)}%")
    total += prob
    i += 1
    
# @st.cache
def make_prediction(text):
  """Tokenize the input text, run the classifier, and display category probabilities."""
  tokenized_text = tokenizer(text, return_tensors='pt')
  with torch.no_grad():
    pred_logits = model(**tokenized_text).logits
  st.write("Category probabilities:")
  print_probs(pred_logits[0])



# Load the base tokenizer and model, then restore the fine-tuned classifier weights.
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=8)

model_name = "trained_model2"
model_path = model_name + '.zip'
model.load_state_dict(
    torch.load(
        model_path,
        map_location=torch.device("cpu")
    )
)

# MAIN
image = Image.open('logo.png')
st.image(image)

# st.markdown("<img src='https://centroderecursosmarista.org/wp-content/uploads/2013/05/arvix.jpg' class='center'>", unsafe_allow_html=True)
st.markdown("# Arxiv.org category classifier")
st.markdown("# ")

st.markdown("### Article Title")
title = st.text_area("Enter the title of the research article to classify", height=50)

st.markdown("### Article Abstract")
abstract = st.text_area("Enter the article abstract", height=400)

# Combine the title and abstract and classify once the user has entered something.
text = " ".join(part for part in (title.strip(), abstract.strip()) if part)
if text:
  make_prediction(text)