"""Streamlit page that classifies arXiv article topics.

The user either types a title (plus an optional summary) or supplies an
article URL/id. The resulting text is passed to ``ArxivModel.get_idx_class``,
and the top predicted categories are rendered with their percentages.
"""
import streamlit as st
from torch.nn import Softmax

from model import ArxivModel, load_model
from tokenizer import get_tokenizer
from lables import num_to_classes, taxonomy
from parser import get_text_title


def _cache_resource(func):
    """Cache *func*'s result across Streamlit reruns when supported.

    Streamlit re-executes this whole script on every widget interaction, so
    without caching the model would be reloaded each time. ``st.cache_resource``
    is used when the installed Streamlit provides it. Otherwise *func* is
    returned unchanged, which keeps the previous load-every-rerun behaviour.
    """
    cache = getattr(st, "cache_resource", None)
    return func if cache is None else cache(func)


@_cache_resource
def _load_resources():
    """Load the model and tokenizer and wrap them in an ``ArxivModel``.

    Returns:
        Tuple ``(model, tokenizer, arxiv_model)``.
    """
    # NOTE(review): a cached resource is shared by all sessions; this assumes
    # ArxivModel inference does not mutate per-request state — confirm.
    loaded_model = load_model()
    loaded_tokenizer = get_tokenizer()
    return loaded_model, loaded_tokenizer, ArxivModel(loaded_model, loaded_tokenizer)


model, tokenizer, arxiv_model = _load_resources()
# NOTE(review): not used anywhere in this script; kept in case other code
# imports it from here.
softmax = Softmax(dim=1)

st.markdown("### Classification of article topics")

col1, col2 = st.columns(2)
with col1:
    title_text = st.text_area("Write title of article", key='arxiv_title_input')
with col2:
    summary_text = st.text_area("Write summary of article (optional)", key='arxiv_sum_input')

# Each widget needs a unique key; two buttons sharing ``key=1`` make Streamlit
# raise DuplicateWidgetID.
click_button_text = st.button('Submit title and summary', key='arxiv_text_submit')
# The summary is only appended after an explicit submit; otherwise the title
# alone is classified on every rerun. Title and summary are tab-separated.
if click_button_text and summary_text.strip():
    text = title_text.strip() + '\t' + summary_text.strip()
else:
    text = title_text.strip()
text = text.strip()

id_url = st.text_input("Write article's url or id", key='arxiv_id_input').strip()
click_button_url = st.button('Submit id', key='arxiv_id_submit')
if click_button_url and id_url:
    res = get_text_title(id_url)
    if res is not None:
        # Fetched article overrides whatever was typed into the text areas.
        text = (res[0].strip() + '\t' + res[1].strip()).strip()
    else:
        st.error("Incorrect url or id")
        text = ""

if text:
    # From the unpacking below, get_idx_class yields (class_index, probability)
    # pairs; the meaning of ``thr`` lives in model.py — TODO confirm.
    predictions = arxiv_model.get_idx_class(text, thr=0.95)[:10]
    for idx, prob in predictions:
        category = num_to_classes[idx]
        if category not in taxonomy:
            # Categories without a human-readable name are not shown.
            continue
        st.markdown(f"{taxonomy[category]} \t {round(prob * 100, 1)}%")