File size: 1,450 Bytes
626b84e
a9b547e
626b84e
a9b547e
626b84e
8b4be17
 
626b84e
 
c9e28be
a9b547e
d2be1e7
626b84e
8b4be17
c9e28be
a9b547e
d2be1e7
 
971f3b2
8ca41df
bdc975c
8ca41df
 
 
885d46e
e985515
8b6e6a3
e985515
 
 
d52a320
c04e9dc
e25d038
d47011f
e985515
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import torch
import streamlit as st
from transformers import DistilBertForSequenceClassification, DistilBertTokenizer

@st.cache  # NOTE(review): st.cache is deprecated for models — st.cache_resource is preferred if the Streamlit version allows; confirm before switching
def create_model():
  """Build a 7-label DistilBERT classifier and load fine-tuned weights.

  Returns:
    A DistilBertForSequenceClassification in eval mode, with weights
    restored from 'mymodule.pt' (loaded onto CPU).
  """
  model = DistilBertForSequenceClassification.from_pretrained("distilbert-base-cased", num_labels=7)
  # map_location='cpu' so the checkpoint loads even on GPU-less hosts.
  # NOTE(review): consider torch.load(..., weights_only=True) on torch>=1.13
  # for safer deserialization of the state dict.
  m_state_dict = torch.load('mymodule.pt', map_location=torch.device('cpu'))
  model.load_state_dict(m_state_dict)
  # Bug fix: without eval(), dropout stays active and inference is
  # nondeterministic.
  model.eval()
  return model

# --- Streamlit UI: collect a paper title/abstract and show predicted tags ---
st.markdown("### arXiv paper tag classification!")
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-cased')

model = create_model()

title = st.text_area("Enter your paper title")
summary = st.text_area("Enter your paper abstract(optional)", help="a.k.a. summary")
if st.button("Submit"):
  if title == "":
    st.error("Please enter your paper title")
  else:
    text = 'Title is ' + title + '. Abstract is ' + summary
    inputs = tokenizer(text, return_tensors="pt")
    # Bug fix: inference should not build an autograd graph.
    with torch.no_grad():
      logits = model(**inputs)['logits']
    probas = torch.nn.functional.softmax(logits, dim=1)
    probas, indices = torch.sort(probas, descending=True)
    # Class index -> human-readable arXiv top-level category.
    arxiv_notation = ['Physics', 'Computer Science', 'Electrical Engineering and Systems Science', 'Math', 'Quantitative Biology', 'Quantitative Finance', 'Statistics']
    st.markdown("Top 95% of tags for your paper:")
    total_proba = 0.0
    ind = 0
    # Emit the most likely tags until their cumulative probability reaches
    # 95%. Bound by the label count so float round-off can never push the
    # index past the end of the list.
    while total_proba < 0.95 and ind < len(arxiv_notation):
      # .item() converts 0-dim tensors to Python scalars so the message
      # shows a plain number instead of a tensor repr.
      p = probas[0][ind].item()
      tag = arxiv_notation[indices[0][ind].item()]
      st.markdown(f"* Paper tag is {tag} with probability {p:.3f}")
      total_proba += p
      ind += 1