Spaces:
Runtime error
Runtime error
File size: 803 Bytes
cd5348d 55d7b7c 88eb143 55d7b7c 5fb16da 55d7b7c 0666c2a 5fb16da 55d7b7c cd5348d 55d7b7c 148d2f1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 |
import streamlit as st
import torch
import numpy as np
from classes_desc import classes, classes_desc
def get_most_probability_terms(logits, top_k=5):
    """Return indices of the ``top_k`` highest-scoring entries along dim 1.

    Args:
        logits: Tensor of scores with at least two dimensions; entries are
            ranked along dim 1 (presumably (batch, num_classes) — confirm
            against callers).
        top_k: How many leading indices to keep per row. If it exceeds the
            size of dim 1, all indices are returned.

    Returns:
        Integer tensor of indices, ordered from highest to lowest score.
    """
    # argsort returns only the permutation indices; the sorted values are
    # never needed here.
    ranking = torch.argsort(logits, dim=1, descending=True)
    best = ranking[:, :top_k]
    return best
def predict(text, model, tokenizer, my_linear, top_k=10):
    """Classify ``text`` and return its ``top_k`` most likely class labels.

    The text is tokenized into a batch of one sequence and passed through
    ``model``. The first element of the model output is summed over dim 1
    and projected by ``my_linear`` to per-class scores. The highest-scoring
    indices are then mapped onto ``classes``.

    Args:
        text: Input string to classify.
        model: Encoder called on a (1, num_tokens) token-id tensor; its
            output ``[0]`` is presumably the last hidden state of shape
            (batch, seq, hidden) — TODO confirm.
        tokenizer: Object whose ``encode(text)`` returns a list of token ids.
        my_linear: Head mapping the pooled representation to class scores.
        top_k: Number of labels to return.

    Returns:
        NumPy array of up to ``top_k`` labels from ``classes``, best first.
    """
    token_ids = tokenizer.encode(text)
    batch = torch.as_tensor([token_ids])  # wrap as a batch of one sequence
    with torch.no_grad():
        hidden = model(batch)[0]
        pooled = hidden.sum(dim=1)
        scores = my_linear(pooled)
    label_table = np.array(classes)
    ranked = get_most_probability_terms(scores, top_k).cpu().numpy()
    # Fancy-index with the (1, top_k) index array, then drop the batch axis.
    return label_table[ranked][0]
def get_answer_with_desc(text, model, tokenizer, my_linear, top_k=10):
    """Build human-readable lines describing the likely topics of ``text``.

    Args:
        text: Input string to classify.
        model, tokenizer, my_linear: Forwarded unchanged to ``predict``.
        top_k: Number of topics to list.

    Returns:
        A list of strings. The first is a header; each following line has
        the form ``"<code>: <description>"``, ordered from most to least
        likely.
    """
    # Forward top_k so that requests for more than predict's default of 10
    # topics are not silently truncated to 10.
    codes = predict(text, model, tokenizer, my_linear, top_k=top_k)
    answer = ["Possible text topics:"]
    # str(code): labels come back as NumPy string scalars; the result is the
    # same text as concatenating them directly.
    answer.extend(str(code) + ": " + classes_desc[code] for code in codes[:top_k])
    return answer
|