Spaces:
Runtime error
Runtime error
import streamlit as st | |
import pickle | |
import torch | |
import numpy as np | |
from transformers import TrainingArguments, Trainer, AutoTokenizer, DataCollatorWithPadding, AutoModelForSequenceClassification | |
from PIL import Image | |
# @st.cache | |
def predict_topic_by_title_and_abstract(text):
    """Classify an article's text and display its most probable topics.

    Tokenizes ``text``, runs the module-level ``model`` on it, converts the
    resulting logits to percentages with a softmax, and writes topics to the
    Streamlit page in descending order of probability. Topics are written
    until the ones already shown add up to more than 95 percent.

    Args:
        text: Article title and/or abstract as a single string.

    Returns:
        None. The result is rendered with ``st.write``.

    Relies on the module-level names ``tokenizer``, ``model`` and ``labels``.
    """
    # NOTE(review): ``labels`` is not defined in the visible part of this
    # file. It is used here as a mapping whose values are topic names,
    # presumably ordered by class index -- confirm it is defined before this
    # function is called, otherwise this raises NameError.

    # truncation=True caps the input at the tokenizer's maximum model length;
    # without it, a long abstract can exceed what the model accepts and make
    # the forward pass fail.
    tokenized_text = tokenizer(text, return_tensors='pt', truncation=True)
    with torch.no_grad():
        logits = model(**tokenized_text).logits
    probs = torch.nn.functional.softmax(logits[0], dim=0).numpy() * 100

    # Pair each percentage with its topic name, highest percentage first.
    ranked = sorted(zip(probs, labels.values()), reverse=True)

    # Iterating over ``ranked`` (instead of indexing with an unbounded
    # counter) guarantees the loop stops at the end of the list even if the
    # running total never exceeds the threshold.
    cumulative = 0
    for prob, label in ranked:
        if cumulative > 95:
            break
        percent = str(np.round(prob, 1))
        st.write(f"it's topic \"{label}\" with probability {percent}%")
        cumulative += prob
# Build the tokenizer and the base 8-class DistilBERT classifier, then
# overwrite the classifier's weights with the fine-tuned checkpoint on disk.
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=8)
# map_location="cpu" lets a checkpoint that was saved from a GPU be loaded on
# a CPU-only host; without it torch.load tries to restore tensors onto the
# device they were saved from.
model.load_state_dict(
    torch.load(
        "./trained_model",
        map_location="cpu",
    )
)

# Page header: logo followed by usage instructions.
image = Image.open('logo.png')
st.image(image)
st.markdown("### To get an article topic prediction, please write down it's title, abstract, or both.")

# Collect the two free-text inputs.
title = st.text_area("Write article title:", height=30)
abstract = st.text_area("Write article abstract:", height=60)

# Joining with " " always yields a non-empty string, so strip the result
# before testing it; otherwise a prediction is made for blank input.
input_text = (title + " " + abstract).strip()
if input_text:
    predict_topic_by_title_and_abstract(input_text)