# soltustik's picture
# Update app.py
# 8f248f0
import streamlit as st
import torch
import torch.nn.functional as F
from transformers import DistilBertTokenizerFast, DistilBertForSequenceClassification
st.markdown("### Predict tag from title/abstract")
st.markdown("<img width=200px src='https://media.wired.com/photos/592700e3cfe0d93c474320f1/master/pass/faces-icon.jpg'>", unsafe_allow_html=True)


def _cache_resource(func):
    """Cache *func*'s return value across Streamlit script reruns.

    Streamlit re-executes this whole script on every widget interaction, so
    without caching the model weights would be reloaded from disk each time.
    Uses ``st.cache_resource`` when this Streamlit version provides it and
    falls back to the older ``st.cache(allow_output_mutation=True)`` otherwise.
    """
    if hasattr(st, "cache_resource"):
        return st.cache_resource(func)
    return st.cache(allow_output_mutation=True)(func)


@_cache_resource
def _load_model_and_tokenizer():
    """Load the fine-tuned classifier from '.' and the base DistilBERT tokenizer.

    Returns:
        (model, tokenizer) tuple; the model is already switched to eval mode.
    """
    clf = DistilBertForSequenceClassification.from_pretrained('.')
    # Inference only: put the model in eval mode (e.g. disables dropout).
    clf.eval()
    tok = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')
    return clf, tok


# Module-level names used by predict_tag().
model, tokenizer = _load_model_and_tokenizer()
def predict_tag(title, abstract, threshold=0.95):
    """Predict the most likely tags for a paper from its title and abstract.

    Tags are taken in order of decreasing probability, and selection stops as
    soon as the probabilities collected so far sum to more than ``threshold``
    (a top-p style cut-off). For ``threshold >= 0`` at least one tag is
    always returned.

    Args:
        title: Title text.
        abstract: Abstract text.
        threshold: Cumulative-probability cut-off. Defaults to 0.95, the value
            that was previously hard-coded.

    Returns:
        dict mapping tag name (via ``model.config.id2label``) to its softmax
        probability as a Python float, in decreasing order of probability.
    """
    # NOTE(review): the tokenizer already prepends its own [CLS] token; the
    # literal ' [CLS] ' here serves as the title/abstract separator, where
    # [SEP] would be conventional. Kept as-is, presumably to match the input
    # format the model was fine-tuned on; confirm before changing.
    text = title + ' [CLS] ' + abstract
    encoding = tokenizer(text, padding=True, truncation=True, max_length=512, return_tensors='pt')
    with torch.no_grad():
        output = model(**encoding)
    # A single text was encoded, so take the one row of class probabilities.
    probs = F.softmax(output.logits, dim=1)[0]

    sorted_probs, sorted_ids = probs.sort(descending=True)
    tags = {}
    total_prob = 0.0
    for prob, label_id in zip(sorted_probs.tolist(), sorted_ids.tolist()):
        # Stop once the tags chosen so far already cover more than `threshold`.
        if total_prob > threshold:
            break
        total_prob += prob
        tags[model.config.id2label[label_id]] = prob
    return tags
title = st.text_area("TITLE HERE")
abstract = st.text_area("ABSTRACT HERE")
# Streamlit also runs this on first page load, before anything has been typed;
# only classify once the user has entered some text, otherwise show no tags.
if title.strip() or abstract.strip():
    result_dict = predict_tag(title, abstract)
else:
    result_dict = {}
# predict_tag returns tags already ordered by decreasing probability.
for tag, prob in result_dict.items():
    st.markdown(f"{tag}: {prob * 100:.2f}%")