Spaces:
Runtime error
Runtime error
File size: 1,450 Bytes
626b84e a9b547e 626b84e a9b547e 626b84e 8b4be17 626b84e c9e28be a9b547e d2be1e7 626b84e 8b4be17 c9e28be a9b547e d2be1e7 971f3b2 8ca41df bdc975c 8ca41df 885d46e e985515 8b6e6a3 e985515 d52a320 c04e9dc e25d038 d47011f e985515 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 |
import torch
import streamlit as st
from transformers import DistilBertForSequenceClassification, DistilBertTokenizer
# Streamlit >= 1.18 provides st.cache_resource for shared, unhashable objects
# such as torch models. Older releases only have the legacy st.cache; there,
# allow_output_mutation=True stops it from hashing the returned model on
# every rerun.
_cache_model = getattr(st, "cache_resource", None) or st.cache(allow_output_mutation=True)


@_cache_model
def create_model():
    """Build the 7-label DistilBERT classifier and load its fine-tuned weights.

    The ``distilbert-base-cased`` architecture is created with a 7-way
    classification head. Its parameters are then overwritten with the state
    dict stored in ``mymodule.pt``, mapped onto the CPU.

    Returns:
        The DistilBertForSequenceClassification model, in evaluation mode.
    """
    model = DistilBertForSequenceClassification.from_pretrained("distilbert-base-cased", num_labels=7)
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    state_dict = torch.load('mymodule.pt', map_location=torch.device('cpu'))
    model.load_state_dict(state_dict)
    # from_pretrained already returns a model in eval mode. Make that explicit
    # so dropout stays disabled for inference after loading the weights.
    model.eval()
    return model
st.markdown("### arXiv paper tag classification!")

# Display names for the classifier outputs, in label-index order
# (7 entries, one per output of the 7-label model).
ARXIV_NOTATION = [
    'Physics',
    'Computer Science',
    'Electrical Engineering and Systems Science',
    'Math',
    'Quantitative Biology',
    'Quantitative Finance',
    'Statistics',
]
# Show the most likely tags until their cumulative probability reaches this mass.
PROBABILITY_MASS = 0.95
# DistilBERT's position embeddings cover at most 512 tokens; longer inputs
# must be truncated or the forward pass fails.
MAX_TOKENS = 512

# Cache the tokenizer across Streamlit reruns. Use st.cache_resource where it
# exists, and legacy st.cache (without hashing the output) on older releases.
_cache_tokenizer = getattr(st, "cache_resource", None) or st.cache(allow_output_mutation=True)


@_cache_tokenizer
def create_tokenizer():
    """Load the DistilBERT tokenizer once instead of on every script rerun."""
    return DistilBertTokenizer.from_pretrained('distilbert-base-cased')


def top_tags(probas, labels, threshold=PROBABILITY_MASS):
    """Pick the most likely labels until their cumulative probability reaches a threshold.

    Args:
        probas: per-label probabilities (floats), aligned by index with ``labels``.
        labels: label names.
        threshold: cumulative probability to reach before stopping.

    Returns:
        List of ``(label, probability)`` pairs, most likely first.
        - When ``labels`` is non-empty and ``threshold > 0``, the list holds at
          least one pair.
        - It never holds more than ``len(labels)`` pairs, even if float rounding
          leaves the running total just below ``threshold``.
    """
    ranked = sorted(zip(labels, probas), key=lambda pair: pair[1], reverse=True)
    selected = []
    total = 0.0
    for label, proba in ranked:
        if total >= threshold:
            break
        selected.append((label, proba))
        total += proba
    return selected


tokenizer = create_tokenizer()
model = create_model()
title = st.text_area("Enter your paper title")
summary = st.text_area("Enter your paper abstract(optional)", help="a.k.a. summary")
if st.button("Submit"):
    if not title.strip():
        st.error("Please enter your paper title")
    else:
        # Presumably mirrors the input format the classifier was fine-tuned on;
        # keep this template in sync with training.
        text = f"Title is {title}. Abstract is {summary}"
        inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=MAX_TOKENS)
        with torch.no_grad():  # inference only: skip building the autograd graph
            logits = model(**inputs)['logits']
        # A single text was encoded, so take row 0 of the batch and convert the
        # probabilities to plain floats for display.
        probas = torch.softmax(logits, dim=-1)[0].tolist()
        st.markdown(f"Top {PROBABILITY_MASS:.0%} of tags for your paper:")
        for tag, proba in top_tags(probas, ARXIV_NOTATION):
            st.markdown(f"* Paper tag is {tag} with probability {proba:.2%}")