"""Streamlit app for classifying computer-science articles by topic.

The user enters an article title and abstract. The concatenated text is
passed, together with a DistilBERT tokenizer/encoder and a trained linear
head, to ``utils.get_answer_with_desc``. Each string it yields is rendered
as a markdown heading.
"""
import streamlit as st
from transformers import DistilBertTokenizer, DistilBertModel
import torch
import torch.nn as nn

from utils import get_answer_with_desc

# Path to the saved state dict of the linear classification head.
MY_LINEAR_NAME = "my_linear_logits_3"
# Number of outputs the saved head was trained with; must match the checkpoint.
N_CLASSES = 40
# Hugging Face checkpoint used for both the tokenizer and the encoder.
PRETRAINED_NAME = "distilbert-base-cased"

# Prefer st.cache_resource (newer Streamlit); fall back to the legacy st.cache
# on versions that lack it. allow_output_mutation=True stops legacy st.cache
# from hashing the returned model objects to detect mutation on every rerun.
if hasattr(st, "cache_resource"):
    _cache_model = st.cache_resource
else:
    _cache_model = st.cache(allow_output_mutation=True)

st.markdown("## Classifying articles on computer science!")
# NOTE(review): this HTML payload is empty, so the call renders nothing;
# content (e.g. an image or style tag) may have been lost — confirm.
st.markdown("", unsafe_allow_html=True)

# Recent Streamlit versions reject st.text_area heights below 68 px.
title = st.text_area("Enter the title of the article", height=68)
abstract = st.text_area("Enter the abstract of the article", height=350)
top_k = st.slider('How many topics from top to show?', 1, 10, 3)

text = title + " " + abstract


@_cache_model
def get_model_tokenizer_linear():
    """Load (once per process, via the cache decorator) the inference components.

    Returns:
        A dict with keys ``"model"`` (DistilBertModel), ``"tokenizer"``
        (DistilBertTokenizer) and ``"my_linear"`` (an ``nn.Linear`` mapping the
        encoder's hidden size to ``N_CLASSES`` outputs, with weights loaded
        from ``MY_LINEAR_NAME`` onto CPU). The dict is unpacked as keyword
        arguments into ``get_answer_with_desc``.
    """
    tokenizer = DistilBertTokenizer.from_pretrained(PRETRAINED_NAME)
    model = DistilBertModel.from_pretrained(PRETRAINED_NAME)
    # Size the head from the encoder config rather than hard-coding 768.
    my_linear = nn.Linear(
        in_features=model.config.dim, out_features=N_CLASSES, bias=True
    )
    my_linear.load_state_dict(
        torch.load(MY_LINEAR_NAME, map_location=torch.device("cpu"))
    )
    return {"model": model, "tokenizer": tokenizer, "my_linear": my_linear}


# Treat whitespace-only input (including both fields left empty) as empty,
# instead of running the model on it.
if not text.strip():
    st.markdown("Input is empty, write something!")
else:
    for ms in get_answer_with_desc(
        text, top_k=top_k, **get_model_tokenizer_linear()
    ):
        st.markdown("#### " + ms)