strangekitten commited on
Commit
cd5348d
1 Parent(s): 0645c16

Update utils.py

Browse files
Files changed (1) hide show
  1. utils.py +5 -3
utils.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  import torch
2
  import numpy as np
3
 
@@ -8,7 +10,7 @@ def get_most_probability_terms(logits, top_k=5):
8
  _, indices = torch.sort(logits, dim=1, descending=True)
9
  return indices[:, :top_k]
10
 
11
-
12
  def predict(text, model, tokenizer, my_linear, top_k=3):
13
  tokens = tokenizer.encode(text)
14
  with torch.no_grad():
@@ -18,9 +20,9 @@ def predict(text, model, tokenizer, my_linear, top_k=3):
18
 
19
 
20
  def get_answer_with_desc(text, model, tokenizer, my_linear, top_k=3):
21
- codes = predict(text, model, tokenizer, my_linear, top_k=top_k)
22
  answer = ["Possible text topics:"]
23
- for code in codes:
24
  answer += [code + ": " + classes_desc[code]]
25
 
26
  return answer
 
1
+ import streamlit as st
2
+
3
  import torch
4
  import numpy as np
5
 
 
10
  _, indices = torch.sort(logits, dim=1, descending=True)
11
  return indices[:, :top_k]
12
 
13
+ @st.cache(suppress_st_warning=True)
14
  def predict(text, model, tokenizer, my_linear, top_k=3):
15
  tokens = tokenizer.encode(text)
16
  with torch.no_grad():
 
20
 
21
 
22
  def get_answer_with_desc(text, model, tokenizer, my_linear, top_k=3):
23
+ codes = predict(text, model, tokenizer, my_linear, top_k=10)
24
  answer = ["Possible text topics:"]
25
+ for code in codes[:top_k]:
26
  answer += [code + ": " + classes_desc[code]]
27
 
28
  return answer