Spaces:
Runtime error
Runtime error
import streamlit as st | |
import torch | |
import numpy as np | |
from classes_desc import classes, classes_desc | |
def get_most_probability_terms(logits, top_k=5):
    """Return the indices of the ``top_k`` largest logits in each row.

    Args:
        logits: Tensor with at least two dimensions; values are ranked
            along dimension 1.
        top_k: Number of leading indices to keep along dimension 1.

    Returns:
        Tensor of indices into dimension 1 of ``logits``, ordered from the
        highest logit to the lowest and truncated to ``top_k`` columns.
    """
    # Rank every column in descending order, then keep the leading top_k.
    ranked = torch.argsort(logits, dim=1, descending=True)
    return ranked[:, :top_k]
def predict(text, model, tokenizer, my_linear, top_k=10):
    """Return the ``top_k`` highest-scoring class labels for ``text``.

    The text is encoded with ``tokenizer``, fed to ``model`` as a batch of
    one, and the first element of the model output is summed over
    dimension 1 before ``my_linear`` turns it into logits.

    Args:
        text: Input passed to ``tokenizer.encode``.
        model: Callable taking a tensor of token ids; element ``[0]`` of its
            result is used (presumably per-token hidden states shaped
            (batch, seq, dim) -- TODO confirm against the model in use).
        tokenizer: Object exposing ``encode(text)``.
        my_linear: Callable mapping the pooled output to class logits.
        top_k: Number of labels to return.

    Returns:
        NumPy array holding the entries of ``classes`` selected by the
        top-ranked indices of the first logits row, best first.
    """
    token_ids = tokenizer.encode(text)
    # Wrapping the ids in a list adds the leading batch dimension of size 1.
    batch = torch.as_tensor([token_ids])

    # Inference only: no gradients are needed.
    with torch.no_grad():
        hidden = model(batch)[0]
        pooled = torch.sum(hidden, dim=1)
        logits = my_linear(pooled)

    best_indices = get_most_probability_terms(logits, top_k).cpu().numpy()
    labels = np.array(classes)
    # The index array has one row per batch item; keep the first row.
    return labels[best_indices][0]
def get_answer_with_desc(text, model, tokenizer, my_linear, top_k=10):
    """Build a human-readable list of the most likely topics for ``text``.

    Args:
        text: Input text to classify.
        model: Model forwarded unchanged to ``predict``.
        tokenizer: Tokenizer forwarded unchanged to ``predict``.
        my_linear: Classification head forwarded unchanged to ``predict``.
        top_k: Number of topics to include in the answer.

    Returns:
        List of strings: the header ``"Possible text topics:"`` followed by
        one ``"<code>: <description>"`` line per predicted code, best first.

    Raises:
        KeyError: If a predicted code has no entry in ``classes_desc``.
    """
    # Forward top_k explicitly: otherwise predict() falls back to its own
    # default and any request for more topics than that is silently capped.
    codes = predict(text, model, tokenizer, my_linear, top_k=top_k)

    answer = ["Possible text topics:"]
    answer.extend(f"{code}: {classes_desc[code]}" for code in codes)
    return answer