import os

import pandas as pd
import streamlit as st
import torch
from transformers import DistilBertForSequenceClassification
from transformers import DistilBertTokenizerFast


def getTop95(predictions):
    """Return indices of the most probable classes whose total probability reaches 0.95."""
    # Start at k=1 (topk with k=0 selects nothing) and allow k to reach len(predictions).
    for k in range(1, len(predictions) + 1):
        vals, ids = torch.topk(predictions, k)
        if torch.sum(vals).item() >= 0.95:
            return ids
    # Rounding can keep the sum just below 0.95; fall back to all classes, most likely first.
    return torch.argsort(predictions, descending=True)


# Note: newer Streamlit releases replace st.cache with st.cache_data / st.cache_resource.
@st.cache(show_spinner=False)
def predict(text):
    """Predict ArXiv tags for `text`, formatted as 'title|summary'."""
    classes = pd.read_csv('classes.csv')
    tokenizer = DistilBertTokenizerFast.from_pretrained("distilbert-base-cased")

    # Use the function argument rather than the module-level title and summary.
    X = tokenizer(text, truncation=True, padding=True)
    tokens = torch.tensor(X['input_ids']).unsqueeze(0)
    mask = torch.tensor(X['attention_mask']).unsqueeze(0)

    # The fine-tuned weights are expected in the current working directory.
    model = DistilBertForSequenceClassification.from_pretrained(
        os.getcwd(), num_labels=len(classes)
    )
    model.eval()

    # Inference only, so gradients are not needed.
    with torch.no_grad():
        logits = model(tokens, mask)[0][0]

    # logits is 1-D (one score per class), so normalize over dim 0.
    predictions = torch.softmax(logits, dim=0)
    ids = getTop95(predictions)
    return classes.tag.to_numpy()[ids.numpy()]


st.set_page_config(
    page_title="ArXiv classifier",
    page_icon=":book:"
)

st.header("Theme classification of ArXiv articles")
st.markdown("""
Please enter a title and a summary (at least one is required) and the oracle will predict the classes of the article according to the ArXiv taxonomy.
""")

with st.form(key='input_form'):
    title = st.text_input(label='Enter title of the article here')
    summary = st.text_area("Enter summary of the article here")
    submit = st.form_submit_button(label='Analyze')

if submit:
    if not title and not summary:
        st.markdown('Please enter at least one: title or summary')
    else:
        with st.spinner(text='The oracle is thinking, please wait for its wise prediction'):
            prediction = predict(title + '|' + summary)
        st.markdown("Most likely it is:")
        for tag in prediction[:5]:
            st.markdown(f"- {tag}")
        # Show the remaining tags only when there are more than five.
        if len(prediction) > 5:
            st.markdown("Other possible variants:")
            st.write(', '.join(prediction[5:]))

hide_streamlit_style = """ """
st.markdown(hide_streamlit_style, unsafe_allow_html=True)