# Page header captured along with the source — Spaces status: Runtime error
import streamlit as st | |
import pickle | |
import torch | |
import numpy as np | |
from transformers import TrainingArguments, Trainer, AutoTokenizer, DataCollatorWithPadding, AutoModelForSequenceClassification | |
from PIL import Image | |
# Read the label mapping that ships next to the app. Its values are the
# topic names paired with the classifier outputs further down.
with open('labels.pickle', 'rb') as labels_file:
    labels = pickle.load(labels_file)
# Not wrapped in a Streamlit cache decorator: this function renders its
# results with st.write instead of returning them.
def predict_topic_by_title_and_abstract(text, min_total_percent=95):
    """Classify ``text`` and write the most likely topics to the page.

    The text is tokenized (with truncation enabled), passed through the
    module-level ``model`` and the logits are turned into percentages with
    a softmax. Topics are then listed from most to least probable until
    the percentages already listed add up to more than
    ``min_total_percent``.

    Args:
        text: The article title and/or abstract to classify.
        min_total_percent: Cumulative probability (in percent) that the
            listed topics must exceed before the listing stops.

    Returns:
        None. One ``st.write`` line is emitted per listed topic.
    """
    tokens = tokenizer(text, return_tensors='pt', truncation=True)
    with torch.no_grad():
        logits = model(**tokens).logits
    # logits[0] is the first row of the batch (presumably the only one for a
    # single input string); the softmax output is scaled to percent.
    probs = torch.nn.functional.softmax(logits[0], dim=0).numpy() * 100

    # Highest probability first; each entry is a (probability, label) pair.
    ranked = sorted(zip(probs, labels.values()), reverse=True)

    covered = 0
    # Iterating over the pairs (instead of indexing with a counter) stops
    # cleanly at the end of the list even if the threshold is never passed,
    # e.g. when there are fewer labels than model outputs.
    for prob, label in ranked:
        if covered > min_total_percent:
            break
        st.write("it's topic \"" + label + "\" with probability " + str(np.round(prob, 1)) + "%")
        covered += prob
# Tokenizer and classifier share the same base checkpoint; the classifier
# head is sized for the 8 topics and then overwritten with the fine-tuned
# weights stored in ./trained_model.
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=8)
model.load_state_dict(
    torch.load(
        "./trained_model",
        # The app never moves the model off the CPU, so map every stored
        # tensor to CPU. Without this, a checkpoint saved from a GPU cannot
        # be deserialized on a host that has no CUDA device.
        map_location="cpu",
    )
)
# ---- Page header: logo and usage instructions ----
image = Image.open('logo.png')
st.image(image)
st.markdown("##### This app predicts the probabilities of the article belonging to the following topics: \'biology\', \'computer science\', \'economics\', \'electrics\', \'finance\', \'math\', \'physics\', \'statistics\'.")
st.markdown("##### To get an article topic prediction, please write down it's title, abstract, or both.")
# Inject CSS so that every <textarea> on the page gets a gray background.
st.markdown('<style>textarea { background: gray !important;}</style>', unsafe_allow_html=True)

# ---- Input widgets ----
# st.text_area has no `unsafe_allow_html` parameter (that flag belongs to
# st.markdown / st.write); passing it raises TypeError, so it is omitted.
# NOTE(review): some Streamlit releases reject text_area heights below
# 68 px — confirm these heights against the installed version.
title = st.text_area("Write article title:", height=30)
abstract = st.text_area("Write article abstract:", height=60)

# ---- Prediction ----
input_text = title + " " + abstract
# Only classify once the user has typed something other than whitespace.
if input_text.strip():
    predict_topic_by_title_and_abstract(input_text)