# strangekitten's picture
# Update utils.py
# 0666c2a
import streamlit as st
import torch
import numpy as np
from classes_desc import classes, classes_desc
def get_most_probability_terms(logits, top_k=5):
    """Return the indices of the ``top_k`` largest entries along dim 1.

    Args:
        logits: Tensor with at least two dimensions; ranking happens along
            dim 1.
        top_k: How many leading indices to keep along dim 1.

    Returns:
        Integer tensor of positions along dim 1, ordered from the largest
        value to the smallest and truncated to the first ``top_k`` entries.
    """
    # Rank every position from highest to lowest value, then keep the head.
    ranked = torch.argsort(logits, dim=1, descending=True)
    return ranked[:, :top_k]
def predict(text, model, tokenizer, my_linear, top_k=10):
    """Return the ``top_k`` highest-scoring class codes for ``text``.

    The text is encoded with ``tokenizer``, passed through ``model`` as a
    batch of one, summed over dim 1, and projected by ``my_linear``; the
    best-scoring positions are then looked up in the module-level
    ``classes`` sequence.

    Args:
        text: Input string to classify.
        model: Callable applied to the token-id tensor; only the first
            element of its output is used.
        tokenizer: Object exposing ``encode(text)`` that yields token ids.
        my_linear: Callable mapping the summed model output to per-class
            scores.
        top_k: Number of class codes to return.

    Returns:
        NumPy array of entries from ``classes``, best score first.
    """
    token_ids = tokenizer.encode(text)
    # Wrap the ids in a list so the model receives a batch of size one.
    batch = torch.as_tensor([token_ids])

    with torch.no_grad():
        # Presumably element 0 is the hidden-state tensor shaped
        # (batch, seq, dim) -- TODO confirm against the model in use.
        hidden = model(batch)[0]
        pooled = torch.sum(hidden, dim=1)
        logits = my_linear(pooled)

    best = get_most_probability_terms(logits, top_k).cpu().numpy()
    # Index 0 selects the row belonging to the single batch item.
    return np.array(classes)[best][0]
def get_answer_with_desc(text, model, tokenizer, my_linear, top_k=10):
    """Build a displayable list of the most likely topics for ``text``.

    Args:
        text: Input string to classify.
        model: Passed through to ``predict``.
        tokenizer: Passed through to ``predict``.
        my_linear: Passed through to ``predict``.
        top_k: Maximum number of topics to list.

    Returns:
        List of strings: the header ``"Possible text topics:"`` followed by
        one ``"<code>: <description>"`` line per predicted class code, best
        match first. Descriptions are looked up in ``classes_desc``.
    """
    # Forward top_k: relying on predict()'s default of 10 silently capped
    # the answer at ten topics even when the caller asked for more.
    codes = predict(text, model, tokenizer, my_linear, top_k=top_k)

    answer = ["Possible text topics:"]
    answer.extend(code + ": " + classes_desc[code] for code in codes)
    return answer