Spaces:
Runtime error
Runtime error
import torch | |
import streamlit as st | |
from transformers import DistilBertForSequenceClassification, DistilBertTokenizer | |
@st.cache_resource
def create_model(checkpoint_path='mymodule.pt'):
    """Build the 7-way DistilBERT tag classifier and load its fine-tuned weights.

    Streamlit re-executes the whole script on every widget interaction, so the
    result is cached with ``st.cache_resource`` (keyed on ``checkpoint_path``):
    the base model is built and the checkpoint read once per server process
    instead of on every rerun.

    Args:
        checkpoint_path: Path to a ``state_dict`` saved with ``torch.save``.
            Defaults to ``'mymodule.pt'`` (the previously hard-coded path).

    Returns:
        A ``DistilBertForSequenceClassification`` with ``num_labels=7``,
        weights loaded onto the CPU, in eval mode.
    """
    model = DistilBertForSequenceClassification.from_pretrained("distilbert-base-cased", num_labels=7)
    # map_location lets a checkpoint saved on a GPU load on a CPU-only host.
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    m_state_dict = torch.load(checkpoint_path, map_location=torch.device('cpu'))
    model.load_state_dict(m_state_dict)
    # Inference only: make sure dropout layers are disabled.
    model.eval()
    return model
st.markdown("### arXiv paper tag classification!")


@st.cache_resource
def _load_tokenizer():
    """Load the DistilBERT tokenizer once per server process.

    Streamlit reruns the script on every interaction; caching avoids calling
    ``from_pretrained`` again on each rerun.
    """
    return DistilBertTokenizer.from_pretrained('distilbert-base-cased')


tokenizer = _load_tokenizer()
model = create_model()
title = st.text_area("Enter your paper title")
summary = st.text_area("Enter your paper abstract(optional)", help="a.k.a. summary")
# Display names for the classifier's output indices: logit i is shown as
# ARXIV_NOTATION[i].
# NOTE(review): presumably this matches the label order used in training -- verify.
ARXIV_NOTATION = [
    'Physics',
    'Computer Science',
    'Electrical Engineering and Systems Science',
    'Math',
    'Quantitative Biology',
    'Quantitative Finance',
    'Statistics',
]
# Show the most likely tags until their cumulative probability reaches this.
PROBA_THRESHOLD = 0.95
# DistilBERT's position embeddings cover at most 512 tokens.
MAX_TOKENS = 512

if st.button("Submit"):
    if not title.strip():
        st.error("Please enter your paper title")
    else:
        text = f"Title is {title}. Abstract is {summary}"
        # Truncate instead of crashing when the abstract exceeds the model limit.
        inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=MAX_TOKENS)
        # Pure inference: skip building the autograd graph.
        with torch.no_grad():
            logits = model(**inputs)['logits']
        # Single input, so take row 0 of the (1, num_labels) batch.
        probas = torch.softmax(logits, dim=1)[0]
        probas, indices = torch.sort(probas, descending=True)
        st.markdown("Top 95% of tags for your paper:")
        total_proba = 0.0
        # zip bounds the loop by the number of labels, so float round-off in
        # the cumulative sum can never cause an out-of-range index.
        for proba, index in zip(probas.tolist(), indices.tolist()):
            if total_proba >= PROBA_THRESHOLD:
                break
            st.markdown(f"* Paper tag is {ARXIV_NOTATION[index]} with probability {proba:.2%}")
            total_proba += proba