File size: 1,778 Bytes
eff32f1
5aaeb9a
 
 
 
 
7ed1e45
 
24e782a
8922b9c
5aaeb9a
 
 
 
 
 
 
 
2421d35
f401857
5aaeb9a
 
8922b9c
9c95ebe
5aaeb9a
 
 
 
1044b9f
5aaeb9a
 
88367f6
 
 
 
 
 
 
 
 
 
 
eff32f1
f401857
c8d4078
 
f401857
c8d4078
7721a54
f401857
88367f6
9ea329a
f401857
 
088f056
9ea329a
88367f6
676d046
88367f6
9ea329a
 
5aaeb9a
88367f6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import streamlit as st
import torch
import numpy as np
from transformers import TrainingArguments, \
                         Trainer, AutoTokenizer, DataCollatorWithPadding, \
                         AutoModelForSequenceClassification 
# Display names of the classes; print_probs looks a name up by its position
# in this list.
categories = [
    'Biology',
    'Computer science',
    'Economics',
    'Electrics',
    'Finance',
    'Math',
    'Physics',
    'Statistics',
]
# Integer class ids 0..len(categories)-1, paired with probabilities in print_probs.
labels = list(range(len(categories)))

def print_probs(logits, threshold=95):
  """Write the most probable categories to the Streamlit page.

  The logits are converted to percentages with a softmax over dim 0, ranked
  from most to least probable, and written one per line as
  "<category>: <percent>%". Output stops once the percentages already
  written add up to more than `threshold`.

  Args:
    logits: 1-D tensor of raw class scores. Entries are paired positionally
      with the module-level `labels`; entries beyond `len(labels)` are
      dropped by `zip`.
    threshold: cumulative percentage after which no further categories are
      written. Defaults to 95, the previously hard-coded value.
  """
  # detach() makes the NumPy conversion safe even if the caller did not run
  # the model under torch.no_grad().
  probs = torch.nn.functional.softmax(logits.detach(), dim=0).numpy() * 100
  ranked = sorted(zip(probs, labels), reverse=True)

  cumulative = 0
  # Iterating over `ranked` (rather than indexing with a free-running counter)
  # guarantees the loop ends when the list is exhausted.
  for prob, idx in ranked:
    if cumulative > threshold:
      break
    st.write(categories[idx] + ": " + str(np.round(prob, 1)) + "%")
    cumulative += prob
    
# @st.cache
def make_prediction(text):
  """Classify `text` and render the category probabilities on the page.

  The text is tokenized with the module-level `tokenizer`, run through the
  module-level `model` without gradient tracking, and the logits of the
  single input are passed to `print_probs` under a
  "Category prediction" heading.

  Args:
    text: the string to classify.
  """
  # truncation=True clips the input to the tokenizer's configured maximum
  # length; without it a long abstract produces more tokens than the model
  # has position embeddings for and the forward pass fails.
  tokenized_text = tokenizer(text, return_tensors='pt', truncation=True)
  with torch.no_grad():
    pred_logits = model(**tokenized_text).logits
  st.markdown("### Category prediction:")
  # The batch holds exactly one sequence, so row 0 is the prediction for `text`.
  print_probs(pred_logits[0])

# Fine-tuned weights are read from "<model_name>.zip" in the working directory.
model_name = "trained_model2"
model_path = f"{model_name}.zip"

# Tokenizer and 8-class classification model built from the base checkpoint.
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
model = AutoModelForSequenceClassification.from_pretrained(
    "distilbert-base-uncased",
    num_labels=8,
)

# Replace the freshly initialised weights with the saved state dict, mapping
# all tensors onto the CPU.
model.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")))

# MAIN
# Page layout: logo, a title input, an abstract input, then the prediction.
from PIL import Image
image = Image.open('logo.png')

st.image(image)
st.markdown("# ")
st.markdown("### Article Title")

# Label (Russian): "Enter the title of the scientific article to classify".
# NOTE(review): recent Streamlit releases may reject text_area heights this
# small -- confirm against the Streamlit version this app is deployed with.
text1 = st.text_area("Введите название научной статьи для классификации", height=20)

st.markdown("### Article Abstract")

# Label (Russian): "Enter the article description".
text2 = st.text_area("Введите описание статьи", height=200)

# Join title and abstract with a space so the last word of the title does not
# fuse with the first word of the abstract; blank or whitespace-only fields
# are skipped.
common_text = " ".join(part.strip() for part in (text1, text2) if part.strip())

# Only classify once the user has typed something other than whitespace.
if common_text:
    make_prediction(common_text)