strangekitten commited on
Commit
55d7b7c
1 Parent(s): 362dc22

Upload utils.py

Browse files
Files changed (1) hide show
  1. utils.py +26 -0
utils.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+
4
+ from .classes_desc import classes, classes_desc
5
+
6
+
7
+ def get_most_probability_terms(logits, top_k=5):
8
+ _, indices = torch.sort(logits, dim=1, descending=True)
9
+ return indices[:, :top_k]
10
+
11
+
12
+ def predict(text, model, tokenizer, my_linear, top_k=3):
13
+ tokens = tokenizer.encode(text)
14
+ with torch.no_grad():
15
+ outputs = model(torch.as_tensor([tokens]))[0]
16
+ logits = my_linear(torch.sum(outputs, dim=1))
17
+ return np.array(classes)[get_most_probability_terms(logits, top_k).cpu().numpy()][0]
18
+
19
+
20
+ def get_answer_with_desc(text, model, tokenizer, my_linear, top_k=3):
21
+ codes = predict(text, model, tokenizer, my_linear, top_k=top_k)
22
+ answer = ["Possible text topics:"]
23
+ for code in codes:
24
+ answer += [code + ": " + classes_desc[code]]
25
+
26
+ return "\n".join(answer)