# Streamlit app: predict tags for a paper from its title and abstract using a
# fine-tuned DistilBERT sequence-classification model stored in the current
# working directory.

import streamlit as st
import torch
import torch.nn.functional as F
from transformers import DistilBertTokenizerFast, DistilBertForSequenceClassification

st.markdown("### Predict tag from title/abstract")
# NOTE(review): this renders an empty markdown element; it looks like a
# leftover placeholder for custom HTML/CSS. Kept so the page layout is unchanged.
st.markdown("", unsafe_allow_html=True)


def _cache_resource(func):
    """Wrap *func* with Streamlit's resource cache when one is available.

    Newer Streamlit releases expose ``st.cache_resource``; older ones expose
    ``st.experimental_singleton``. If neither exists, *func* is returned
    unchanged so the app still runs (just without caching).
    """
    decorator = getattr(st, "cache_resource", None) or getattr(
        st, "experimental_singleton", None
    )
    return decorator(func) if decorator is not None else func


@_cache_resource
def _load_model_and_tokenizer():
    """Load the classifier and tokenizer once and put the model in eval mode.

    Streamlit re-executes the whole script on every widget interaction, so
    without caching the weights would be re-read from disk on each rerun.
    """
    loaded_model = DistilBertForSequenceClassification.from_pretrained('.')
    loaded_tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')
    loaded_model.eval()
    return loaded_model, loaded_tokenizer


model, tokenizer = _load_model_and_tokenizer()


def predict_tag(title, abstract, threshold=0.95):
    """Return the most probable tags for a paper.

    Labels are taken in order of decreasing probability until their
    cumulative probability exceeds *threshold*; the label that crosses the
    threshold is included.

    Args:
        title: Paper title.
        abstract: Paper abstract.
        threshold: Cumulative probability mass at which to stop adding
            labels. Defaults to 0.95.

    Returns:
        A dict mapping label name (from ``model.config.id2label``) to its
        softmax probability as a float, in descending probability order.
    """
    # NOTE(review): the literal ' [CLS] ' separator is kept as-is; presumably
    # it matches the input format used when the model was fine-tuned — confirm.
    text = title + ' [CLS] ' + abstract
    text_encoding = tokenizer(
        text,
        padding=True,
        truncation=True,
        max_length=512,
        return_tensors='pt',
    )

    with torch.no_grad():
        output = model(**text_encoding)

    # A single text is encoded, so row 0 holds its class distribution.
    prediction = F.softmax(output.logits, dim=1)[0]
    sorted_probs, sorted_indices = prediction.sort(descending=True)

    labels = {}
    total_prob = 0
    for prob, index in zip(sorted_probs, sorted_indices):
        # Checked before adding, so the label that pushes the running total
        # past the threshold is still part of the result.
        if total_prob > threshold:
            break
        total_prob += prob
        labels[model.config.id2label[index.item()]] = prob.item()
    return labels


title = st.text_area("TITLE HERE")
abstract = st.text_area("ABSTRACT HERE")

if title.strip() or abstract.strip():
    result_dict = predict_tag(title, abstract)
    for tag, probability in result_dict.items():
        st.markdown(f"{tag}: {probability * 100:.2f}%")
else:
    # Nothing to classify yet: skip inference instead of showing tags that
    # were predicted for an empty input.
    result_dict = {}
    st.info("Enter a title and/or an abstract to get tag predictions.")