File size: 2,414 Bytes
4355387
 
de1989c
e1fd88d
4355387
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
675d241
c31188d
4355387
c31188d
 
675d241
c31188d
e1fd88d
c31188d
4355387
4915257
f8aac2d
4915257
4355387
f8aac2d
4355387
 
 
 
2b0ca1b
4355387
 
 
 
 
 
c31188d
 
4355387
 
 
2b0ca1b
4355387
 
 
 
 
 
5409e3f
 
 
 
de1989c
 
 
 
 
5409e3f
d276604
40d8d53
7d6731a
5409e3f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import streamlit as st
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModel, DistilBertForSequenceClassification

# Model identifier handed to `from_pretrained` for both the tokenizer and the
# classifier (formatted as a Hugging Face Hub "user/repo" id).
my_model_name = "istassiy/ysda_2022_ml2_hw3_distilbert_base_uncased"

# arXiv archive prefix -> coarse topic shown to the user. Every
# physics-related archive is folded into the single 'physics' topic.
arxiv_code_to_topic = {
  'cs': 'computer science',
  'q-bio': 'biology',
  'q-fin': 'finance',
  **dict.fromkeys(
    (
      'astro-ph', 'cond-mat', 'gr-qc', 'hep-ex', 'hep-lat', 'hep-ph',
      'hep-th', 'math-ph', 'nlin', 'nucl-ex', 'nucl-th', 'quant-ph',
      'physics',
    ),
    'physics',
  ),
  'eess': 'electrical engineering',
  'econ': 'economics',
  'math': 'mathematics',
  'stat': 'statistics',
}

# Distinct topics in alphabetical order. Prediction code pairs the i-th model
# output with the i-th entry of this list.
# NOTE(review): that assumes the checkpoint's label ids were assigned in this
# same alphabetical order — confirm against the model's config.id2label.
sorted_arxiv_topics = sorted({*arxiv_code_to_topic.values()})

# Number of distinct topics (presumably the classifier's output width).
NUM_LABELS = len(sorted_arxiv_topics)

@st.cache(allow_output_mutation=True)
def load_tokenizer():
  """Return the tokenizer that belongs to ``my_model_name``.

  Wrapped in Streamlit's ``st.cache`` so that script reruns reuse the already
  loaded tokenizer instead of fetching it again.
  NOTE(review): ``st.cache`` is deprecated in newer Streamlit releases in
  favour of ``st.cache_resource``; migrate once the pinned version allows it.
  """
  return AutoTokenizer.from_pretrained(my_model_name)

@st.cache(allow_output_mutation=True)
def load_model():
  """Return the DistilBERT sequence classifier stored at ``my_model_name``.

  Wrapped in Streamlit's ``st.cache`` so that script reruns reuse the already
  loaded model instead of loading the weights again.
  NOTE(review): ``st.cache`` is deprecated in newer Streamlit releases in
  favour of ``st.cache_resource``; migrate once the pinned version allows it.
  """
  return DistilBertForSequenceClassification.from_pretrained(my_model_name)

def sigmoid(x):
  """Element-wise logistic sigmoid, ``1 / (1 + exp(-x))``.

  Evaluated as ``exp(-log(1 + exp(-x)))`` with ``np.logaddexp`` so that large
  negative inputs never overflow ``exp``. The naive form overflows, with a
  RuntimeWarning, once ``-x`` exceeds roughly 88 for float32 or 709 for
  float64.

  Args:
    x: scalar or array-like of logits.

  Returns:
    NumPy scalar or array with the same shape (and floating dtype) as ``x``,
    with values in [0, 1].
  """
  # logaddexp(0, -x) == log(1 + exp(-x)), computed without overflow.
  return np.exp(-np.logaddexp(0, np.negative(x)))

def get_top_predictions(predictions, *, threshold=0.95, labels=None):
  """Reduce raw per-topic logits to the most likely topics.

  Each logit is passed through ``sigmoid`` and the scores are rescaled to sum
  to 1. Topics are then taken in decreasing-probability order until their
  cumulative probability first exceeds ``threshold``. The topic that crosses
  the threshold is included.

  Args:
    predictions: 1-D array of logits, one per topic. Entry ``i`` is paired
      with ``labels[i]``.
      NOTE(review): with the default labels this assumes the model's label ids
      follow the alphabetical ``sorted_arxiv_topics`` order; confirm against
      the checkpoint's ``config.id2label``.
    threshold: cumulative-probability cut-off (default 0.95, the previously
      hard-coded value).
    labels: topic names aligned with ``predictions``. Defaults to
      ``sorted_arxiv_topics``.

  Returns:
    dict mapping topic name -> normalized probability (NumPy float scalars),
    inserted in decreasing-probability order.

  Raises:
    ValueError: if ``labels`` and ``predictions`` differ in length. ``zip``
      would otherwise silently drop the unmatched entries.
  """
  if labels is None:
    labels = sorted_arxiv_topics

  probs = sigmoid(predictions)
  # Rescale the independent per-topic sigmoid scores so they sum to 1 and can
  # be read as a distribution over topics.
  probs = probs / np.sum(probs)

  if len(labels) != len(probs):
    raise ValueError(
      f"got {len(probs)} predictions for {len(labels)} topic labels"
    )

  res = {}
  total_prob = 0
  for topic, prob in sorted(zip(labels, probs), key=lambda item: item[1], reverse=True):
    total_prob += prob
    res[topic] = prob
    # Stop as soon as the selected topics cover more than `threshold` of the
    # probability mass; the topic that crossed it has already been kept.
    if total_prob > threshold:
      break
  return res

# Both loaders are cached by Streamlit, so reruns do not reload the weights.
tokenizer = load_tokenizer()
model = load_model()

st.markdown("# Scientific paper classificator")
st.markdown(
    "Fill in paper summary and / or title below and then press open area on the page to submit inputs:",
    unsafe_allow_html=False
)

paper_title = st.text_area("Paper title")
paper_summary = st.text_area("Paper abstract")

# Whitespace-only input carries no text for the model, so treat it as empty.
if not paper_title.strip() and not paper_summary.strip():
  st.markdown("Must have non-empty title or summary")
else:
  # Title and abstract are joined with "." into one sequence.
  # NOTE(review): keep this in sync with how the model's training inputs were
  # built — that format is not visible from here.
  model_inputs = tokenizer(
    [paper_title + "." + paper_summary],
    padding=True, truncation=True, return_tensors="pt"
  )
  # Inference only: no autograd graph is needed for the forward pass.
  with torch.no_grad():
    raw_predictions = model(**model_inputs)
  # With no labels passed, element 0 of the model output is the logits
  # tensor; the second [0] selects the single row of the batch.
  results = get_top_predictions(raw_predictions[0][0].numpy())
  st.markdown("The following are probabilities for paper topics:")
  for topic, prob in sorted(results.items(), key=lambda item: item[1], reverse=True):
    st.markdown(f"{topic}: {prob}")