import streamlit as st
import torch
import numpy as np
from transformers import AutoTokenizer, AutoModelForSequenceClassification
categories = ['Biology', 'Computer science', 'Economics', 'Electrics', 'Finance',
               'Math', 'Physics', 'Statistics']
labels = [i for i in range(len(categories))]

def print_probs(logits):
  # Convert logits to percentage probabilities and print the most likely
  # categories until their cumulative probability exceeds 95%.
  probs = torch.nn.functional.softmax(logits, dim=0).numpy() * 100
  ans = sorted(zip(probs, labels), reverse=True)
  total = 0
  i = 0
  while total <= 95 and i < len(ans):
    prob, idx = ans[i]
    st.write(f"{categories[idx]}: {np.round(prob, 1)}%")
    total += prob
    i += 1
    
# @st.cache
def make_prediction(text):
  # Tokenize the input (truncating to the model's maximum length) and run the
  # classifier with gradient tracking disabled.
  tokenized_text = tokenizer(text, return_tensors='pt', truncation=True)
  with torch.no_grad():
    pred_logits = model(**tokenized_text).logits
  st.markdown("### Category prediction:")
  print_probs(pred_logits[0])



# Load the base DistilBERT tokenizer and an 8-way classification head, then
# restore the fine-tuned weights from the local checkpoint on CPU.
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=8)
model_name = "trained_model2"
model_path = model_name + '.zip'
model.load_state_dict(
    torch.load(
        model_path,
        map_location=torch.device("cpu")
    )
)
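# Switch to evaluation mode so dropout layers are disabled at inference time.
model.eval()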

# MAIN
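# Page layout: logo at the top, then text inputs for the article title and
# abstract; a prediction is shown as soon as any text is entered.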
from PIL import Image
image = Image.open('logo.png')

st.image(image)

# st.markdown("<img src='https://centroderecursosmarista.org/wp-content/uploads/2013/05/arvix.jpg' class='center'>", unsafe_allow_html=True)
# st.markdown("# Arxiv.org category classifier")
st.markdown("# ")

st.markdown("### Article Title")
text1 = st.text_area("Enter the title of the scientific article to classify", height=20)

st.markdown("### Article Abstract")

text2 = st.text_area("Enter the article abstract", height=200)
# Join the title and abstract into a single input for the classifier.
common_text = (text1 + " " + text2).strip()
if common_text != "":
    make_prediction(common_text)