File size: 803 Bytes
cd5348d
 
55d7b7c
 
 
88eb143
55d7b7c
 
 
 
 
 
5fb16da
55d7b7c
 
 
 
 
 
 
0666c2a
5fb16da
55d7b7c
cd5348d
55d7b7c
 
148d2f1
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import streamlit as st

import torch
import numpy as np

from classes_desc import classes, classes_desc


def get_most_probability_terms(logits, top_k=5):
  """Return, for each row, the indices of the highest-scoring entries.

  Args:
    logits: tensor with at least two dimensions; ranking is performed
      along dimension 1.
    top_k: how many leading indices to keep per row. If it exceeds the
      size of dimension 1, all indices are returned.

  Returns:
    Tensor of indices along dimension 1, ordered from the largest score
    to the smallest and truncated to the first `top_k` positions.
  """
  # Only the ordering is needed, so ask for the indices directly.
  ranked = torch.argsort(logits, dim=1, descending=True)
  return ranked[:, :top_k]

def predict(text, model, tokenizer, my_linear, top_k=10):
  """Return the `top_k` best-scoring class codes for `text`.

  Args:
    text: input passed to `tokenizer.encode`.
    model: callable applied to a batch of one token sequence; only the
      first element of its result is used.
    tokenizer: object exposing `encode(text)`.
    my_linear: callable mapping the pooled model output to per-class
      scores.
    top_k: number of class codes to return.

  Returns:
    NumPy array holding the entries of `classes` selected by the
    highest-scoring indices, best first.
  """
  token_ids = tokenizer.encode(text)
  # Wrap in a list so the model receives a batch containing one sequence.
  batch = torch.as_tensor([token_ids])
  with torch.no_grad():
    hidden_states = model(batch)[0]
    # Sum over dimension 1 (presumably the token axis -- TODO confirm
    # against the model's output layout).
    pooled = torch.sum(hidden_states, dim=1)
    scores = my_linear(pooled)
  best_indices = get_most_probability_terms(scores, top_k).cpu().numpy()
  labels = np.array(classes)
  # Fancy indexing keeps the batch dimension; take the single row.
  return labels[best_indices][0]


def get_answer_with_desc(text, model, tokenizer, my_linear, top_k=10):
  """Build a list of display lines describing the likely topics of `text`.

  Args:
    text: input text to classify.
    model: forwarded unchanged to `predict`.
    tokenizer: forwarded unchanged to `predict`.
    my_linear: forwarded unchanged to `predict`.
    top_k: number of topics to list.

  Returns:
    A list of strings: the header "Possible text topics:" followed by
    one "<code>: <description>" line per predicted class code, where the
    description is looked up in `classes_desc`.
  """
  # Forward top_k explicitly: without it `predict` falls back to its own
  # default of 10 and a larger request would be silently capped.
  codes = predict(text, model, tokenizer, my_linear, top_k=top_k)
  answer = ["Possible text topics:"]
  answer.extend(f"{code}: {classes_desc[code]}" for code in codes)
  return answer