wrapper228's picture
Update app.py
6a8a7ab
import streamlit as st
import pickle
import torch
import numpy as np
from transformers import TrainingArguments, Trainer, AutoTokenizer, DataCollatorWithPadding, AutoModelForSequenceClassification
from PIL import Image
# Load the pickled label mapping once at start-up. Its values are the topic
# names that predict_topic_by_title_and_abstract pairs with the model outputs.
with open('labels.pickle', 'rb') as labels_file:
    labels = pickle.load(labels_file)
# @st.cache
def predict_topic_by_title_and_abstract(text, threshold=95):
    """Classify *text* and write the most likely topics to the Streamlit page.

    The text is tokenized (truncated to the tokenizer's maximum length), run
    through the module-level ``model`` without gradient tracking, and the
    logits of the first (only) sequence are turned into percentages with a
    softmax. Topics are then listed from most to least probable until the
    probabilities already shown add up to more than ``threshold`` percent.

    Args:
        text: Article title and/or abstract as a single string.
        threshold: Cumulative probability (in percent) after which no further
            topics are listed. Defaults to 95, the original hard-coded value.

    Returns:
        None. Results are emitted through ``st.write``.
    """
    tokens = tokenizer(text, return_tensors='pt', truncation=True)
    with torch.no_grad():
        logits = model(**tokens).logits

    # Softmax over the class axis of the single input sequence, as percent.
    probs = torch.nn.functional.softmax(logits[0], dim=0).numpy() * 100

    # NOTE(review): assumes labels.values() is ordered by class index so that
    # each probability lines up with its topic name -- confirm against the
    # code that produced labels.pickle.
    ranked = sorted(zip(probs, labels.values()), reverse=True)

    # Iterate over the ranked list itself instead of indexing with a counter:
    # the loop can no longer run past the end if the cumulative probability
    # never exceeds the threshold (e.g. fewer labels than model classes).
    cumulative = 0.0
    for prob, label in ranked:
        if cumulative > threshold:
            break
        st.write(f"it's topic \"{label}\" with probability {np.round(prob, 1)!s}%")
        cumulative += prob
# Tokenizer for the classifier. The commented-out loader below suggests the
# model was fine-tuned from the same "distilbert-base-uncased" checkpoint.
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

# Alternative loading path (state dict instead of a pickled model object),
# kept for reference:
#model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=8)
#model.load_state_dict(
#    torch.load(
#        "./trained_model"
#    )
#)

# NOTE(security): pickle.load can execute arbitrary code embedded in the file.
# Only ship a trained_model.pickle that was produced by a trusted source.
with open('trained_model.pickle', 'rb') as handle:
    model = pickle.load(handle)

# Pickling preserves the module's training flag. Force inference mode so that
# dropout is disabled and predictions are deterministic even if the model was
# saved straight after training.
model.eval()
# --- Page header -----------------------------------------------------------
image = Image.open('logo.png')
st.image(image)
st.markdown("##### This app predicts the probabilities of the article belonging to the following topics: \'biology\', \'computer science\', \'economics\', \'electrics\', \'finance\', \'math\', \'physics\', \'statistics\'.")
st.markdown("##### To get an article topic prediction, please write down it's title, abstract, or both.")

# Give the text areas a grey background via injected CSS.
st.markdown('<style>textarea { background: #E8E8E8 !important;}</style>', unsafe_allow_html=True)

# --- User input --------------------------------------------------------------
title = st.text_area("Write article title:", height=30)
abstract = st.text_area("Write article abstract:", height=60)

# Strip punctuation and other symbols but KEEP whitespace. Filtering on
# str.isalnum alone also removed every space, gluing the whole title and
# abstract into a single token before it reached the tokenizer.
raw_text = title + " " + abstract
input_text = ''.join(ch for ch in raw_text if ch.isalnum() or ch.isspace())

# Only run the model when at least one word remains after cleaning.
if input_text.split():
    predict_topic_by_title_and_abstract(input_text)