wrapper228's picture
Update app.py
a3e66f4
raw history blame
No virus
1.87 kB
import streamlit as st
import pickle
import torch
import numpy as np
from transformers import TrainingArguments, Trainer, AutoTokenizer, DataCollatorWithPadding, AutoModelForSequenceClassification
from PIL import Image
# Topic names in class-index order; the string index of each name is its key,
# so labels["0"] == "biology", ..., labels["7"] == "statistics".
_TOPIC_NAMES = [
    "biology",
    "computer science",
    "economics",
    "electrics",
    "finance",
    "math",
    "physics",
    "statistics",
]
labels = {str(index): topic for index, topic in enumerate(_TOPIC_NAMES)}
def predict_topic_by_title_and_abstract(text, model, threshold=95):
    """Write the most likely topics for ``text`` to the Streamlit page.

    ``text`` is tokenized with the module-level ``tokenizer`` and run through
    ``model`` without gradient tracking. The logits of the first (only) input
    are turned into percentages with softmax. Topics are written in descending
    order of probability until their cumulative probability exceeds
    ``threshold`` percent.

    Args:
        text: Article title and/or abstract to classify.
        model: Sequence-classification model whose ``logits`` have one column
            per entry of the module-level ``labels``. Assumes column i matches
            ``labels[str(i)]`` (TODO confirm against the training code).
        threshold: Cumulative percentage after which no more topics are
            listed. Defaults to 95, the previously hard-coded value.
    """
    tokens = tokenizer(text, return_tensors='pt', truncation=True)
    with torch.no_grad():
        logits = model(**tokens).logits
    probs = torch.nn.functional.softmax(logits[0], dim=0).numpy() * 100
    # Pair each percentage with its topic name (labels are in index order)
    # and rank from most to least likely.
    ranked = sorted(zip(probs, labels.values()), reverse=True)

    cumulative = 0
    # Iterate the finite ranked list rather than indexing with an unbounded
    # counter. Float rounding can leave the total just under 100, and a
    # threshold >= 100 would otherwise run past the end and raise IndexError.
    for prob, label in ranked:
        if cumulative > threshold:
            break
        st.write(f'it\'s topic "{label}" with probability {np.round(prob, 1)!s}%')
        cumulative += prob
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
# SECURITY NOTE: pickle.load can execute arbitrary code embedded in the file.
# Only load a trusted, locally produced 'trained_model.pickle'.
with open('trained_model.pickle', 'rb') as handle:
    model = pickle.load(handle)
# NOTE(review): Streamlit reruns this script on every widget interaction, so
# the tokenizer and model are reloaded each time. Consider caching them
# (e.g. st.cache_resource) if the deployed Streamlit version supports it.

image = Image.open('logo.png')
st.image(image)
st.markdown("##### This app predicts the probabilities of the article belonging to the following topics: 'biology', 'computer science', 'economics', 'electrics', 'finance', 'math', 'physics', 'statistics'.")
st.markdown("##### To get an article topic prediction, please write down its title, abstract, or both.")
# Grey background for the input text areas.
st.markdown('<style>textarea { background: #E8E8E8 !important;}</style>', unsafe_allow_html=True)
title = st.text_area("Write article title:", height=30)
abstract = st.text_area("Write article abstract:", height=60)

input_text = title + " " + abstract
# Drop punctuation and symbols but keep whitespace. Filtering with
# str.isalnum alone also removed the spaces, gluing the whole title and
# abstract into a single "word" before tokenization.
input_text = ''.join(ch for ch in input_text if ch.isalnum() or ch.isspace())
# Only classify when at least one word survived the filtering.
if input_text.split():
    predict_topic_by_title_and_abstract(input_text, model)