Spaces:
Runtime error
Runtime error
strangekitten
commited on
Commit
•
55d7b7c
1
Parent(s):
362dc22
Upload utils.py
Browse files
utils.py
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import numpy as np
|
3 |
+
|
4 |
+
from .classes_desc import classes, classes_desc
|
5 |
+
|
6 |
+
|
7 |
+
def get_most_probability_terms(logits, top_k=5):
|
8 |
+
_, indices = torch.sort(logits, dim=1, descending=True)
|
9 |
+
return indices[:, :top_k]
|
10 |
+
|
11 |
+
|
12 |
+
def predict(text, model, tokenizer, my_linear, top_k=3):
|
13 |
+
tokens = tokenizer.encode(text)
|
14 |
+
with torch.no_grad():
|
15 |
+
outputs = model(torch.as_tensor([tokens]))[0]
|
16 |
+
logits = my_linear(torch.sum(outputs, dim=1))
|
17 |
+
return np.array(classes)[get_most_probability_terms(logits, top_k).cpu().numpy()][0]
|
18 |
+
|
19 |
+
|
20 |
+
def get_answer_with_desc(text, model, tokenizer, my_linear, top_k=3):
|
21 |
+
codes = predict(text, model, tokenizer, my_linear, top_k=top_k)
|
22 |
+
answer = ["Possible text topics:"]
|
23 |
+
for code in codes:
|
24 |
+
answer += [code + ": " + classes_desc[code]]
|
25 |
+
|
26 |
+
return "\n".join(answer)
|