# wrapper228 — "Update app.py" (commit 318a0ff)
import streamlit as st
import pickle
import torch
import numpy as np
from transformers import TrainingArguments, Trainer, AutoTokenizer, DataCollatorWithPadding, AutoModelForSequenceClassification
from PIL import Image
# Topic names reported to the user. The values are later zipped with the
# model's class probabilities, so their iteration order must match the class
# indices (presumably a dict saved at training time — TODO confirm).
# NOTE(review): pickle can execute arbitrary code on load; this file must ship
# with the app and never be taken from user input.
with open('labels.pickle', 'rb') as label_file:
    labels = pickle.load(label_file)
# @st.cache
def predict_topic_by_title_and_abstract(text):
tokens = tokenizer(text, return_tensors='pt', truncation=True)
with torch.no_grad():
logits = model(**tokens).logits
probs = torch.nn.functional.softmax(logits[0], dim=0).numpy() * 100
ans = list(zip(probs,labels.values()))
ans.sort(reverse=True)
sum = 0
i = 0
while sum <= 95:
prob, label = ans[i]
st.write("it's topic \"" + label + "\" with probability "+ str(np.round(prob,1)) + "%")
sum += prob
i += 1
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=8)
model.load_state_dict(
torch.load(
"./trained_model"
)
)
# Page layout: logo, description, two text inputs, and the prediction output.
image = Image.open('logo.png')
st.image(image)
st.markdown(
    "##### This app predicts the probabilities of the article belonging to the following topics: "
    "'biology', 'computer science', 'economics', 'electrics', 'finance', 'math', 'physics', 'statistics'."
)
st.markdown("##### To get an article topic prediction, please write down its title, abstract, or both.")
st.markdown('<style>textarea { background: #E8E8E8 !important;}</style>', unsafe_allow_html=True)
# NOTE(review): some newer Streamlit releases reject small `height` values for
# st.text_area — confirm against the pinned Streamlit version.
title = st.text_area("Write article title:", height=30)
abstract = st.text_area("Write article abstract:", height=60)
input_text = title + " " + abstract
# Drop punctuation and symbols but keep whitespace. The previous
# filter(str.isalnum, ...) also removed spaces, which glued the title and
# abstract into a single "word" and made the " " separator and the .split()
# check below meaningless.
input_text = ''.join(ch for ch in input_text if ch.isalnum() or ch.isspace())
# Only predict when at least one word remains after cleaning.
if input_text.split():
    predict_topic_by_title_and_abstract(input_text)