"""Streamlit app that classifies an article's topics.

The user enters a title and an optional summary. The combined text is passed
to ``ArxivModel.get_idx_class``, and up to ``MAX_SHOWN`` predicted topics are
rendered with their probabilities as percentages.
"""
import streamlit as st
from torch.nn import Softmax

from model import ArxivModel, load_model
from tokenizer import get_tokenizer
from lables import num_to_classes, taxonomy

# Passed through as ``thr`` to ArxivModel.get_idx_class; its exact meaning is
# defined there (presumably a probability cutoff — TODO confirm).
PROB_THRESHOLD = 0.95
# Upper bound on the number of predicted topics rendered.
MAX_SHOWN = 10


def _cache_resource(func):
    """Cache *func*'s return value across Streamlit reruns.

    Streamlit re-executes this whole script on every widget interaction, so
    without caching the model and tokenizer would be reloaded each time.
    Uses ``st.cache_resource`` when the installed Streamlit provides it and
    falls back to the legacy ``st.cache`` API on older releases.
    """
    cache = getattr(st, "cache_resource", None)
    if cache is not None:
        return cache(func)
    # Legacy API: allow_output_mutation skips hashing the returned objects.
    return st.cache(allow_output_mutation=True)(func)


@_cache_resource
def _load_arxiv_model():
    """Load the model and tokenizer once and wrap them in an ``ArxivModel``.

    Returns:
        Tuple ``(model, tokenizer, arxiv_model)``.
    """
    loaded_model = load_model()
    loaded_tokenizer = get_tokenizer()
    return loaded_model, loaded_tokenizer, ArxivModel(loaded_model, loaded_tokenizer)


def _topic_label(class_name):
    """Return the display label for *class_name* from ``taxonomy``.

    The first taxonomy entry whose element ``[0]`` contains *class_name*
    wins, and its element ``[1]`` is returned. When no entry matches, the
    raw class name is returned so the prediction is still shown rather than
    silently rendered as an empty line.
    """
    for tax in taxonomy:
        if class_name in tax[0]:
            return tax[1]
    return class_name


model, tokenizer, arxiv_model = _load_arxiv_model()
# NOTE(review): unused in this script; kept only to preserve the module-level name.
softmax = Softmax(dim=1)

st.markdown("### Classification of article topics")

title_text = st.text_area("Write title of article")
summary_text = st.text_area("Write summary of article (optional)")

# Join title and summary; the outer strip removes the separator space when
# one of the two fields is empty.
text = (title_text.strip() + " " + summary_text.strip()).strip()

if text:
    # get_idx_class yields (class index, probability) pairs.
    predictions = arxiv_model.get_idx_class(text, thr=PROB_THRESHOLD)
    for idx, prob in predictions[:MAX_SHOWN]:
        label = _topic_label(num_to_classes[idx])
        st.markdown("{} \t {}%".format(label, round(prob * 100, 1)))