strangekitten commited on
Commit
bb5f62e
·
1 Parent(s): cd5348d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -8
app.py CHANGED
@@ -16,12 +16,15 @@ abstract = st.text_area("Enter the abstract of the article", height=350)
16
  top_k = st.slider('How many topics from top to show?', 1, 10, 3)
17
  text = title + " " + abstract
18
 
19
- tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-cased")
20
- model = DistilBertModel.from_pretrained("distilbert-base-cased")
21
-
22
- n_classes = 40
23
- my_linear = nn.Linear(in_features=768, out_features=n_classes, bias=True)
24
- my_linear.load_state_dict(torch.load(MY_LINEAR_NAME, map_location=torch.device('cpu')))
25
-
26
- for ms in get_answer_with_desc(text, model, tokenizer, my_linear, top_k=top_k):
 
 
 
27
  st.markdown("#### " + ms)
 
16
  top_k = st.slider('How many topics from top to show?', 1, 10, 3)
17
  text = title + " " + abstract
18
 
19
+ @st.cache(suppress_st_warning=True)
20
+ def get_model_tokenizer_linear():
21
+ tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-cased")
22
+ model = DistilBertModel.from_pretrained("distilbert-base-cased")
23
+
24
+ n_classes = 40
25
+ my_linear = nn.Linear(in_features=768, out_features=n_classes, bias=True)
26
+ my_linear.load_state_dict(torch.load(MY_LINEAR_NAME, map_location=torch.device('cpu')))
27
+ return {"model": model, "tokenizer": tokenizer, "my_linear": my_linear}
28
+
29
+ for ms in get_answer_with_desc(text, top_k=top_k, **get_model_tokenizer_linear()):
30
  st.markdown("#### " + ms)