Spaces:
Runtime error
Runtime error
import streamlit as st | |
import torch | |
import torch.nn.functional as F | |
from transformers import DistilBertTokenizerFast, DistilBertForSequenceClassification | |
st.markdown("### Predict tag from title/abstract")
st.markdown("<img width=200px src='https://media.wired.com/photos/592700e3cfe0d93c474320f1/master/pass/faces-icon.jpg'>", unsafe_allow_html=True)


def _load_model_and_tokenizer():
    """Load the fine-tuned tag classifier and its tokenizer.

    The classifier weights/config are read from the current directory ('.');
    the tokenizer is the stock ``distilbert-base-uncased`` one. The model is
    put into eval mode before being returned.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    loaded_model = DistilBertForSequenceClassification.from_pretrained('.')
    loaded_model.eval()
    loaded_tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')
    return loaded_model, loaded_tokenizer


# Streamlit re-executes this whole script on every widget interaction, so
# without caching the model would be reloaded from disk on each rerun.
# ``st.cache_resource`` (newer Streamlit releases) keeps one loaded instance
# across reruns; on versions without it, fall back to loading every run.
_cache_resource = getattr(st, "cache_resource", None)
if _cache_resource is not None:
    _load_model_and_tokenizer = _cache_resource(_load_model_and_tokenizer)

model, tokenizer = _load_model_and_tokenizer()
def predict_tag(title, abstract, threshold=0.95, max_length=512):
    """Predict tags for a paper from its title and abstract.

    The title and abstract are joined into one string, tokenized (truncated to
    ``max_length`` tokens) and classified by the module-level ``model``.
    Labels are collected in order of decreasing softmax probability until
    their cumulative probability exceeds ``threshold``. The label that pushes
    the total over the threshold is included.

    Args:
        title: Paper title text.
        abstract: Paper abstract text.
        threshold: Cumulative-probability cut-off for the returned labels.
        max_length: Maximum number of tokens passed to the model.

    Returns:
        Dict mapping label name (from ``model.config.id2label``) to its
        probability in [0, 1], ordered from most to least likely.
    """
    # NOTE(review): ' [CLS] ' as a title/abstract separator is unusual ([SEP]
    # is the conventional choice). It is kept as-is on the assumption that the
    # model was fine-tuned on this exact input format. Confirm before changing.
    text = title + ' [CLS] ' + abstract
    encoding = tokenizer(text, padding=True, truncation=True, max_length=max_length, return_tensors='pt')
    with torch.no_grad():
        output = model(**encoding)
    # A single text was encoded, so take row 0 of the per-label probabilities.
    probs = F.softmax(output.logits, dim=1)[0]

    sorted_probs, sorted_ids = probs.sort(descending=True)
    tags = {}
    total_prob = 0.0
    for prob, label_id in zip(sorted_probs.tolist(), sorted_ids.tolist()):
        # Check before adding so the label that crosses the threshold is kept.
        if total_prob > threshold:
            break
        total_prob += prob
        tags[model.config.id2label[label_id]] = prob
    return tags
title = st.text_area("TITLE HERE")
abstract = st.text_area("ABSTRACT HERE")

# Both text areas start out empty on first page load. Running the model on an
# empty query would display meaningless tags, so only predict once the user
# has entered some non-whitespace text.
if title.strip() or abstract.strip():
    result_dict = predict_tag(title, abstract)
else:
    result_dict = {}

for tag, prob in result_dict.items():
    st.markdown(f"{tag}: {prob * 100:.2f}%")