strangekitten commited on
Commit
6163ea2
1 Parent(s): 0666c2a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -2
app.py CHANGED
@@ -25,9 +25,13 @@ def get_model_tokenizer_linear():
25
  my_linear = nn.Linear(in_features=768, out_features=n_classes, bias=True)
26
  my_linear.load_state_dict(torch.load(MY_LINEAR_NAME, map_location=torch.device('cpu')))
27
  return {"model": model, "tokenizer": tokenizer, "my_linear": my_linear}
28
-
 
 
 
 
29
  if len(text) == 1:
30
  st.markdown("Input is empty, write something!")
31
  else:
32
- for ms in get_answer_with_desc(text, top_k=top_k, **get_model_tokenizer_linear()):
33
  st.markdown("#### " + ms)
 
25
  my_linear = nn.Linear(in_features=768, out_features=n_classes, bias=True)
26
  my_linear.load_state_dict(torch.load(MY_LINEAR_NAME, map_location=torch.device('cpu')))
27
  return {"model": model, "tokenizer": tokenizer, "my_linear": my_linear}
28
+
29
+ @st.cache()
30
+ def predict_topics(top_k):
31
+ return get_answer_with_desc(text, top_k=top_k, **get_model_tokenizer_linear())[:top_k + 1]
32
+
33
  if len(text) == 1:
34
  st.markdown("Input is empty, write something!")
35
  else:
36
+ for ms in predict_topics(top_k):
37
  st.markdown("#### " + ms)