File size: 1,362 Bytes
02ea321
 
362dc22
 
 
 
89899dd
362dc22
 
 
02ea321
 
 
8b2c4ec
 
0645c16
362dc22
 
9b1dcbc
bb5f62e
 
 
 
 
 
 
 
 
3793f45
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import streamlit as st

from transformers import DistilBertTokenizer, DistilBertModel
import torch
import torch.nn as nn

from utils import get_answer_with_desc

# Path of the saved state_dict for the linear classification head
# (loaded in get_model_tokenizer_linear).
MY_LINEAR_NAME = "my_linear_logits_3"

# Decorative banner shown under the page title; rendered as raw HTML.
_BANNER_HTML = "<img width=400px src='https://img.freepik.com/free-photo/young-pretty-student-overwhelmed-with-books_272645-183.jpg?size=626&ext=jpg'>"

# Page header. Statement order matters: Streamlit renders elements in the
# order these calls execute.
st.markdown("## Classifying articles on computer science!")
st.markdown(_BANNER_HTML, unsafe_allow_html=True)

# User inputs.
# NOTE(review): recent Streamlit releases may reject very small text_area
# heights — confirm height=20 is accepted by the pinned Streamlit version.
title = st.text_area("Enter the title of the article", height=20)
abstract = st.text_area("Enter the abstract of the article", height=350)
top_k = st.slider(
    'How many topics from top to show?',
    min_value=1,
    max_value=10,
    value=3,
)

# Title and abstract are fed to the model as one string, separated by a space.
text = " ".join((title, abstract))

@st.cache(allow_output_mutation=True)
def get_model_tokenizer_linear():
    """Load the DistilBERT tokenizer, backbone and trained linear head.

    Cached by Streamlit so the weights are loaded once instead of on every
    script rerun. ``allow_output_mutation=True`` tells ``st.cache`` not to
    hash-check the returned objects for mutation.

    Returns:
        dict with keys ``"model"`` (DistilBertModel), ``"tokenizer"``
        (DistilBertTokenizer) and ``"my_linear"`` (``nn.Linear`` head whose
        weights are read from ``MY_LINEAR_NAME``). The keys match the keyword
        arguments the caller unpacks into ``get_answer_with_desc``.
    """
    checkpoint = "distilbert-base-cased"
    bert_tokenizer = DistilBertTokenizer.from_pretrained(checkpoint)
    bert_model = DistilBertModel.from_pretrained(checkpoint)

    # 40 output logits; presumably one per topic label — confirm against the
    # training code. in_features=768 presumably matches DistilBERT's hidden size.
    num_labels = 40
    head = nn.Linear(in_features=768, out_features=num_labels, bias=True)
    # map_location forces the checkpoint onto CPU even if it was saved from GPU.
    # NOTE(review): torch.load unpickles the file; only point MY_LINEAR_NAME
    # at trusted checkpoints.
    head_weights = torch.load(MY_LINEAR_NAME, map_location=torch.device('cpu'))
    head.load_state_dict(head_weights)

    return dict(model=bert_model, tokenizer=bert_tokenizer, my_linear=head)

# Show predictions only when the user has actually typed something.
# `text` is built as title + " " + abstract, so a length check alone misses
# whitespace-only input; strip before testing for emptiness.
if not text.strip():
    st.markdown("Input is empty, write something!")
else:
    # get_answer_with_desc (from utils) receives the cached model, tokenizer
    # and linear head as keyword arguments; presumably it yields one string
    # per top-k topic — confirm in utils. Each one is rendered as a heading.
    for ms in get_answer_with_desc(text, top_k=top_k, **get_model_tokenizer_linear()):
        st.markdown("#### " + ms)