File size: 1,931 Bytes
eff32f1
5aaeb9a
 
 
 
 
7ed1e45
 
24e782a
8922b9c
5aaeb9a
 
 
 
 
 
 
 
2421d35
f401857
5aaeb9a
 
8922b9c
9c95ebe
5aaeb9a
 
 
 
b1cc36e
5aaeb9a
 
 
 
 
 
 
e0b0602
5aaeb9a
 
 
6adbb0f
 
5aaeb9a
 
eff32f1
f401857
c8d4078
 
f401857
c8d4078
7721a54
c8d4078
47f0e0c
7721a54
f401857
 
f51d611
f401857
 
f51d611
f401857
50ac7e6
 
5aaeb9a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import streamlit as st
import torch
import numpy as np
from transformers import TrainingArguments, \
                         Trainer, AutoTokenizer, DataCollatorWithPadding, \
                         AutoModelForSequenceClassification 
# Display names of the target classes; print_probs looks a name up by label id.
categories = ['Biology', 'Computer science', 'Economics', 'Electrics', 'Finance',
               'Math', 'Physics', 'Statistics']
# Integer label ids 0..len(categories)-1, in the same order as `categories`.
labels = list(range(len(categories)))

def print_probs(logits):
  """Write the most probable categories and their percentages to the page.

  Softmax is applied to ``logits`` along dimension 0 and scaled to
  percentages. Categories are then written with ``st.write`` from most to
  least probable, stopping as soon as the percentages already written add up
  to more than 95.

  Args:
    logits: tensor of raw class scores, paired positionally with ``labels``.
      Assumed to be a 1-D CPU tensor with one score per category (``.numpy()``
      is called on the softmax result) -- confirm against the caller.
  """
  probs = torch.nn.functional.softmax(logits, dim=0).numpy() * 100
  # Tuples sort by probability first, so reverse order puts the best class first.
  ranked = sorted(zip(probs, labels), reverse=True)
  covered = 0  # running total of the percentages written so far
  shown = 0
  # The length guard prevents an IndexError if the total never passes 95.
  while shown < len(ranked) and covered <= 95:
    prob, idx = ranked[shown]
    # str() is kept explicit so the numpy scalar is rendered exactly as before.
    st.write(categories[idx] + ": " + str(np.round(prob, 1)) + "%")
    covered += prob
    shown += 1
    
# @st.cache
def make_prediction(text):
  """Classify ``text`` and render the category probabilities on the page.

  The text is tokenized with the module-level ``tokenizer``, passed through
  the module-level ``model`` with gradient tracking disabled, and the logits
  of the first (only) sequence in the batch are handed to ``print_probs``.

  Args:
    text: the article text to classify.
  """
  # truncation=True clips the input to the tokenizer's maximum length. Without
  # it a long abstract yields more tokens than the model accepts and the
  # forward pass fails.
  tokenized_text = tokenizer(text, return_tensors='pt', truncation=True)
  with torch.no_grad():
    pred_logits = model(**tokenized_text).logits
  st.markdown("### Category probability:")
  print_probs(pred_logits[0])



# Tokenizer and classifier are both built from the same base checkpoint.
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

# The number of output classes is derived from the category list so the
# classification head cannot drift out of sync with `categories`.
model = AutoModelForSequenceClassification.from_pretrained(
    "distilbert-base-uncased",
    num_labels=len(categories),
)
model_name = "trained_model2"
model_path = model_name + '.zip'
# Overwrite the base weights with the fine-tuned state dict, mapped onto the CPU.
model.load_state_dict(
    torch.load(
        model_path,
        map_location=torch.device("cpu")
    )
)

# MAIN
# Page flow: show the logo, collect title and abstract, then classify them.
from PIL import Image
image = Image.open('logo.png')

st.image(image)

# st.markdown("<img src='https://centroderecursosmarista.org/wp-content/uploads/2013/05/arvix.jpg' class='center'>", unsafe_allow_html=True)
# st.markdown("# Arxiv.org category classifier")
st.markdown("# ")

st.markdown("### Article Title")
# The title gets its own variable so the abstract widget no longer overwrites it.
title = st.text_area("Введите название научной статьи для классификации", height=20)

st.markdown("### Article Abstract")
abstract = st.text_area("Введите описание статьи", height=200)

# Join whichever fields the user filled in.
# NOTE(review): title-then-abstract separated by a space is an assumption
# about the training input format -- confirm against the training script.
text = " ".join(
    part.strip() for part in (title, abstract) if part and part.strip()
)

# Widgets return an empty string until the user types, so test for content
# rather than None to avoid classifying empty input on first load.
if text:
    make_prediction(text)