wrapper228's picture
Update app.py
7e51dc1
raw
history blame
1.39 kB
import streamlit as st
import pickle
import torch
import numpy as np
from transformers import TrainingArguments, Trainer, AutoTokenizer, DataCollatorWithPadding, AutoModelForSequenceClassification
from PIL import Image
# @st.cache
def predict_topic_by_title_and_abstract(text):
    """Classify *text* and write the most probable topics to the Streamlit page.

    The text is tokenized, run through the module-level ``model`` and the
    logits of the first (only) sequence are turned into percentages with a
    softmax. Topics are then written with ``st.write`` in descending order of
    probability until the percentages already written add up to more than 95.

    Args:
        text: Article title and/or abstract as a single string.

    Returns:
        None. The result is rendered through ``st.write``.
    """
    # truncation=True clips the input to the tokenizer's maximum length;
    # without it, a long abstract yields more tokens than the model accepts
    # and the forward pass fails.
    tokenized_text = tokenizer(text, return_tensors='pt', truncation=True)
    with torch.no_grad():
        logits = model(**tokenized_text).logits
    probs = torch.nn.functional.softmax(logits[0], dim=0).numpy() * 100

    # NOTE(review): `labels` is a module-level mapping that is not defined in
    # the visible part of this file; its values are presumably the topic names
    # in class-index order -- confirm where it is loaded.
    ranked = sorted(zip(probs, labels.values()), reverse=True)

    # Percentage mass already reported. Iterating over `ranked` (instead of
    # indexing in an unbounded while-loop) cannot run past the end of the list
    # when fewer labels than classes are available.
    covered = 0
    for prob, label in ranked:
        if covered > 95:
            break
        st.write("it's topic \"" + label + "\" with probability "+ str(np.round(prob,1)) + "%")
        covered += prob
# --- Model setup -----------------------------------------------------------
# NOTE(review): Streamlit re-executes this script on every widget interaction,
# so the tokenizer and model are rebuilt each time; consider caching them.
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=8)
# map_location="cpu" lets a checkpoint that was saved from a GPU be loaded on
# a host without CUDA; the model itself is never moved off the CPU here.
model.load_state_dict(
    torch.load(
        "./trained_model",
        map_location="cpu",
    )
)

# --- Page layout -----------------------------------------------------------
image = Image.open('logo.png')
st.image(image)
st.markdown("### To get an article topic prediction, please write down it's title, abstract, or both.")
title = st.text_area("Write article title:", height=30)
abstract = st.text_area("Write article abstract:", height=60)

# The two fields are joined with a space, so the combined string is never
# empty on its own; strip it before testing so that nothing is predicted
# while both fields are still blank.
input_text = (title + " " + abstract).strip()
if input_text:
    predict_topic_by_title_and_abstract(input_text)