"""Streamlit app that shows predicted topic classes for an article's title/summary."""

import time

import streamlit as st
from torch.nn import Softmax

from model import ArxivModel, load_model
from tokenizer import get_tokenizer
from lables import num_to_classes

# Threshold forwarded to ArxivModel.get_idx_class as ``thr``.
# NOTE(review): presumably a cumulative/minimum class probability — confirm in model.py.
CLASS_THRESHOLD = 0.95

# Streamlit re-runs this whole script on every widget interaction, so heavy
# resources must be cached or they are reloaded on each rerun.
# ``st.cache_resource`` exists in Streamlit >= 1.18; on older versions fall back
# to an identity decorator, which reproduces the previous (uncached) behavior.
_cache_resource = getattr(st, "cache_resource", lambda func: func)


@_cache_resource
def _load_resources():
    """Load the model and tokenizer, print their load times, and wrap them.

    Returns:
        A ``(model, tokenizer, arxiv_model)`` tuple, where ``arxiv_model`` is
        ``ArxivModel(model, tokenizer)``.
    """
    start = time.perf_counter()
    loaded_model = load_model()
    print("Model:", (time.perf_counter() - start))

    start = time.perf_counter()
    loaded_tokenizer = get_tokenizer()
    print("Tokenizer:", (time.perf_counter() - start))

    return loaded_model, loaded_tokenizer, ArxivModel(loaded_model, loaded_tokenizer)


model, tokenizer, arxiv_model = _load_resources()

# NOTE(review): not used anywhere in this script; kept so the module-level
# name remains available in case something else relies on it.
softmax = Softmax(dim=1)

st.markdown("### Classification of article topics")

text = st.text_area("Write title and (optional) summary of article").strip()
if text:
    # Render one line per class index returned by the model.
    for idx in arxiv_model.get_idx_class(text, thr=CLASS_THRESHOLD):
        st.markdown(num_to_classes[idx])
else:
    st.markdown("")