import streamlit as st
import torch
import numpy as np
from classes_desc import classes, classes_desc

# Minimum number of candidates get_answer_with_desc() asks predict() for.
# Using a fixed floor means small top_k values share one st.cache entry.
_DEFAULT_FETCH_K = 10


def get_most_probability_terms(logits, top_k=5):
    """Return the indices of the ``top_k`` largest values along dim 1.

    Args:
        logits: tensor with at least two dimensions; values are ranked
            along dim 1.
        top_k: how many indices to keep per row. If it exceeds the size of
            dim 1, every index is returned (slicing does not raise).

    Returns:
        Tensor of indices sorted by descending value, truncated to
        ``top_k`` along dim 1.
    """
    # A full sort (rather than torch.topk) tolerates top_k > size of dim 1.
    sorted_indices = torch.sort(logits, dim=1, descending=True).indices
    return sorted_indices[:, :top_k]


# NOTE(review): st.cache is deprecated in recent Streamlit releases in favour of
# st.cache_data / st.cache_resource, and it hashes model/tokenizer/my_linear on
# every call. Confirm the behaviour against the pinned Streamlit version.
@st.cache(suppress_st_warning=True)
def predict(text, model, tokenizer, my_linear, top_k=3):
    """Return the ``top_k`` most probable class codes for ``text``.

    The text is tokenized with ``tokenizer.encode`` and run through ``model``
    as a single-item batch. The first element of the model output is summed
    over dim 1 and projected by ``my_linear`` into class logits. The highest
    scoring indices are then mapped through ``classes``.

    Args:
        text: input text to classify.
        model: encoder called as ``model(tensor)``; element ``[0]`` of its
            output is used (presumably the per-token hidden states — TODO
            confirm).
        tokenizer: object providing ``encode(text)`` that returns token ids.
        my_linear: classification head applied to the pooled encoder output.
        top_k: number of class codes to return.

    Returns:
        numpy array of the ``top_k`` class codes, most probable first.
    """
    tokens = tokenizer.encode(text)
    # Inference only: no autograd graph is needed for the encoder or the head.
    with torch.no_grad():
        outputs = model(torch.as_tensor([tokens]))[0]
        # Sum over dim 1 (assumed to be the token axis of a
        # (batch, seq, hidden) output — TODO confirm) to pool the sequence.
        logits = my_linear(torch.sum(outputs, dim=1))
    top_indices = get_most_probability_terms(logits, top_k).cpu().numpy()
    # The batch holds exactly one sequence, so row 0 is the answer for `text`.
    return np.array(classes)[top_indices][0]


def get_answer_with_desc(text, model, tokenizer, my_linear, top_k=3):
    """Return a human-readable list of the ``top_k`` predicted topics.

    Args:
        text: input text to classify.
        model, tokenizer, my_linear: forwarded unchanged to ``predict``.
        top_k: number of topics to list.

    Returns:
        List of strings: a header line followed by one ``"<code>: <description>"``
        line per predicted topic, most probable first.
    """
    # Always fetch at least _DEFAULT_FETCH_K codes so that small top_k values
    # reuse one cached prediction. Fetch more when the caller asks for more;
    # a fixed fetch size would silently cap the answer.
    fetch_k = max(top_k, _DEFAULT_FETCH_K)
    codes = predict(text, model, tokenizer, my_linear, top_k=fetch_k)
    answer = ["Possible text topics:"]
    answer.extend(f"{code}: {classes_desc[code]}" for code in codes[:top_k])
    return answer