strangekitten's picture
Update app.py
7c706d6
raw
history blame contribute delete
No virus
1.37 kB
import streamlit as st
from transformers import DistilBertTokenizer, DistilBertModel
import torch
import torch.nn as nn
from utils import get_answer_with_desc
# Path of the file whose state_dict is loaded into the linear classification head.
MY_LINEAR_NAME = "my_linear_logits_3"

st.markdown("## Classifying articles on computer science!")
st.markdown("<img width=400px src='https://img.freepik.com/free-photo/young-pretty-student-overwhelmed-with-books_272645-183.jpg?size=626&ext=jpg'>", unsafe_allow_html=True)

# st.text_area requires a height of at least 68 px; smaller values (the original
# used 20) raise StreamlitAPIException on current Streamlit releases.
title = st.text_area("Enter the title of the article", height=68)
abstract = st.text_area("Enter the abstract of the article", height=350)
top_k = st.slider('How many topics from top to show?', 1, 10, 3)

# Title and abstract are joined into a single input text for classification.
text = title + " " + abstract
# st.cache is deprecated; st.cache_resource is the replacement for shared,
# unhashable objects such as models (it never hashes or copies the return value,
# which is what allow_output_mutation=True achieved). Older Streamlit versions
# without cache_resource fall back to the legacy decorator.
_cache_model = getattr(st, "cache_resource", None) or st.cache(allow_output_mutation=True)


@_cache_model
def get_model_tokenizer_linear():
    """Load the DistilBERT encoder, its tokenizer and the linear classification head.

    Wrapped in a Streamlit cache so the weights are not reloaded on every
    script rerun.

    Returns:
        dict with keys "model", "tokenizer" and "my_linear", suitable for
        unpacking as keyword arguments into get_answer_with_desc.
    """
    tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-cased")
    model = DistilBertModel.from_pretrained("distilbert-base-cased")
    n_classes = 40
    # Take the input size from the encoder's config (768 for distilbert-base)
    # rather than hard-coding it, so the head always matches the hidden size.
    my_linear = nn.Linear(in_features=model.config.dim, out_features=n_classes, bias=True)
    # Load the head's saved weights, mapping every tensor onto the CPU
    # regardless of the device it was saved from.
    my_linear.load_state_dict(torch.load(MY_LINEAR_NAME, map_location=torch.device("cpu")))
    return {"model": model, "tokenizer": tokenizer, "my_linear": my_linear}
# `text` always contains the separator space between title and abstract, so
# check for any non-whitespace content instead of comparing lengths; this also
# rejects a title/abstract made only of spaces or newlines.
if not text.strip():
    st.markdown("Input is empty, write something!")
else:
    # Show each item yielded by get_answer_with_desc as a level-4 heading.
    for ms in get_answer_with_desc(text, top_k=top_k, **get_model_tokenizer_linear()):
        st.markdown("#### " + ms)