lab2 / app.py
IlyaUsmanov's picture
Update app.py
e25d038
import torch
import streamlit as st
from transformers import DistilBertForSequenceClassification, DistilBertTokenizer
# Streamlit reruns this script on every interaction, so the model must be cached
# across reruns. `st.cache_resource` is the API meant for unhashable global
# resources such as torch models; the deprecated `st.cache` would try to hash the
# returned model on each call to detect mutation. Fall back to the legacy API
# (with mutation checking disabled) only on Streamlit versions without
# `cache_resource`.
_cache_model = getattr(st, "cache_resource", None) or st.cache(allow_output_mutation=True)


@_cache_model
def create_model():
    """Build the 7-class DistilBERT classifier and load the fine-tuned weights.

    The base architecture comes from the ``distilbert-base-cased`` checkpoint;
    its weights are then replaced by the state dict stored in the local file
    ``mymodule.pt``, loaded onto the CPU.

    Returns:
        DistilBertForSequenceClassification with the fine-tuned weights, ready
        for inference.
    """
    model = DistilBertForSequenceClassification.from_pretrained(
        "distilbert-base-cased", num_labels=7
    )
    # NOTE: torch.load unpickles the file; only point it at checkpoints you trust.
    state_dict = torch.load("mymodule.pt", map_location=torch.device("cpu"))
    model.load_state_dict(state_dict)
    # Make sure dropout is disabled: this model is only used for inference.
    model.eval()
    return model
st.markdown("### arXiv paper tag classification!")

# Streamlit re-executes the whole script on every widget interaction; cache the
# tokenizer (like the model) so it is not reloaded from disk on each rerun.
# Prefer `st.cache_resource`; fall back to the legacy `st.cache` on old versions.
_cache_tokenizer = getattr(st, "cache_resource", None) or st.cache(allow_output_mutation=True)


@_cache_tokenizer
def _load_tokenizer():
    """Return the tokenizer of the ``distilbert-base-cased`` checkpoint the classifier is built on."""
    return DistilBertTokenizer.from_pretrained("distilbert-base-cased")


tokenizer = _load_tokenizer()
model = create_model()
title = st.text_area("Enter your paper title")
summary = st.text_area("Enter your paper abstract (optional)", help="a.k.a. summary")
# Display names for the classifier outputs; position i names logit i.
# NOTE(review): order assumed to match the label ids used when training
# mymodule.pt — confirm.
ARXIV_NOTATION = [
    'Physics',
    'Computer Science',
    'Electrical Engineering and Systems Science',
    'Math',
    'Quantitative Biology',
    'Quantitative Finance',
    'Statistics',
]
# Tags are listed, most probable first, until their cumulative probability
# reaches this share.
TOP_PROBA_THRESHOLD = 0.95


def _top_tags(probas, threshold=TOP_PROBA_THRESHOLD):
    """Return the most probable tags covering ``threshold`` of the probability mass.

    Args:
        probas: 1-D tensor of class probabilities, one per ARXIV_NOTATION entry.
        threshold: cumulative probability at which to stop adding tags.

    Returns:
        List of ``(tag_name, probability)`` pairs as Python floats, sorted by
        descending probability. Tags are added while the running total is below
        ``threshold``, and the list never exceeds the number of classes.
    """
    sorted_probas, indices = torch.sort(probas, descending=True)
    tags = []
    total_proba = 0.0
    for proba, index in zip(sorted_probas.tolist(), indices.tolist()):
        if total_proba >= threshold:
            break
        tags.append((ARXIV_NOTATION[index], proba))
        total_proba += proba
    return tags


if st.button("Submit"):
    # Reject whitespace-only titles as well as empty ones.
    if not title.strip():
        st.error("Please enter your paper title")
    else:
        text = 'Title is ' + title + '. Abstract is ' + summary
        # Truncate to the model's maximum length: long abstracts would otherwise
        # exceed DistilBERT's position embeddings and crash the forward pass.
        inputs = tokenizer(text, return_tensors="pt", truncation=True)
        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            logits = model(**inputs)['logits']
        # Single input, so take row 0 of the (1, num_labels) batch.
        probas = torch.softmax(logits, dim=1)[0]
        st.markdown(f"Top {TOP_PROBA_THRESHOLD:.0%} of tags for your paper:")
        for tag, proba in _top_tags(probas):
            # Format as a plain number instead of the raw tensor repr.
            st.markdown(f"* Paper tag is {tag} with probability {proba:.4f}")