strangekitten committed
Commit 5fb16da
1 Parent(s): 9b1dcbc

Update utils.py

Files changed (1)
  1. utils.py +3 -2
utils.py CHANGED
@@ -11,7 +11,7 @@ def get_most_probability_terms(logits, top_k=5):
     return indices[:, :top_k]
 
 @st.cache(suppress_st_warning=True)
-def predict(text, model, tokenizer, my_linear, top_k=3):
+def predict(text, model, tokenizer, my_linear, top_k=10):
     tokens = tokenizer.encode(text)
     with torch.no_grad():
         outputs = model(torch.as_tensor([tokens]))[0]
@@ -19,8 +19,9 @@ def predict(text, model, tokenizer, my_linear, top_k=3):
     return np.array(classes)[get_most_probability_terms(logits, top_k).cpu().numpy()][0]
 
 
+@st.cache(suppress_st_warning=True)
 def get_answer_with_desc(text, model, tokenizer, my_linear, top_k=3):
-    codes = predict(text, model, tokenizer, my_linear, top_k=10)
+    codes = predict(text, model, tokenizer, my_linear)
     answer = ["Possible text topics:"]
     for code in codes[:top_k]:
         answer += [code + ": " + classes_desc[code]]