# Author: AleksBlacky
# Commit 3722795: change UI, add readable topic names
import pickle

import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
class ArxivClassifierModelsPipeline:
    """Two-model arXiv text classifier used by the Streamlit app.

    One sequence-classification model (loaded from ``models/scibert/``) scores
    topics; a second one (loaded from ``models/maintopic_clf/``) selects a
    single main topic. Label-decoding dictionaries are unpickled from the same
    directories.
    """

    def __init__(self):
        # Classification models; loading goes through st.cache-decorated methods.
        self.model_topic_clf = self.__load_topic_clf()
        self.model_maintopic_clf = self.__load_maintopic_clf()
        # Tokenizers for each model, fetched by hub name.
        # NOTE(review): presumably these are the checkpoints each classifier was
        # fine-tuned from — confirm they match the saved models' vocabularies.
        topic_clf_default_model = "allenai/scibert_scivocab_uncased"
        self.topic_tokenizer = AutoTokenizer.from_pretrained(topic_clf_default_model)
        maintopic_clf_default_model = "Wi/arxiv-topics-distilbert-base-cased"
        self.maintopic_tokenizer = AutoTokenizer.from_pretrained(maintopic_clf_default_model)
        # Topic model: class index -> topic label.
        self.decode_dict_topic = self._load_pickle('models/scibert/decode_dict_topic.pkl')
        # Main-topic model: class index -> main-topic label.
        self.decode_dict_maintopic = self._load_pickle('models/maintopic_clf/decode_dict_maintopic.pkl')
        # Main-topic label -> readable name (looked up for every prediction).
        self.main_topic_dict = self._load_pickle('models/maintopic_clf/main_topic_dict.pkl')
        # Topic label -> readable name; may not cover every label (see make_predict).
        self.topic_dict = self._load_pickle('models/scibert/topic_dict.pkl')

    @staticmethod
    def _load_pickle(path):
        """Unpickle and return the object stored at *path*."""
        # pickle can execute arbitrary code: only point this at trusted local
        # files shipped with the app, never at user-supplied data.
        with open(path, 'rb') as f:
            return pickle.load(f)

    @staticmethod
    def _class_probs(tokenizer, model, text):
        """Return *model*'s softmax class probabilities for *text* as a list of floats."""
        # Truncate so long inputs cannot index past the position-embedding
        # table (which would raise); inputs that already fit are unchanged.
        max_len = min(tokenizer.model_max_length, model.config.max_position_embeddings)
        tokens = tokenizer(text, return_tensors="pt", truncation=True, max_length=max_len)
        # Pure inference: no need to record an autograd graph.
        with torch.no_grad():
            outs = model(tokens.input_ids)
            return outs["logits"].softmax(dim=-1).tolist()[0]

    def make_predict(self, text, threshold=0.1):
        """Classify *text* into topics and a single main topic.

        Args:
            text: Raw text to classify. Text longer than the models' maximum
                input length is truncated.
            threshold: Minimum softmax probability (exclusive) for a topic to
                be included in the result. Defaults to 0.1.

        Returns:
            A tuple ``(topic_probs, main_topic)``:
            ``topic_probs`` maps every topic whose probability exceeds
            *threshold* to that probability, keyed by its readable name when
            ``topic_dict`` has one and by the raw label otherwise.
            ``main_topic`` is the readable name of the most probable main topic.
        """
        probs_topic = self._class_probs(self.topic_tokenizer, self.model_topic_clf, text)
        topic_probs = {}
        for idx, p in enumerate(probs_topic):
            if p > threshold:
                label = self.decode_dict_topic[idx]
                # Fall back to the raw label when no readable name is known.
                name = self.topic_dict[label] if label in self.topic_dict else label
                topic_probs[name] = p

        probs_maintopic = self._class_probs(self.maintopic_tokenizer, self.model_maintopic_clf, text)
        # Decode the predicted (argmax) class; previously class 0 was always
        # returned regardless of the model output.
        best_idx = max(range(len(probs_maintopic)), key=probs_maintopic.__getitem__)
        maintopic_label = self.decode_dict_maintopic[best_idx]
        return topic_probs, self.main_topic_dict[maintopic_label]

    @st.cache(suppress_st_warning=True)
    def __load_topic_clf(self):
        """Load the topic classifier from ``models/scibert/``.

        The body (including the ``st.write`` message) only runs when
        ``st.cache`` has no cached result for this call.
        """
        st.write("Loading model")
        return AutoModelForSequenceClassification.from_pretrained("models/scibert/")

    @st.cache(suppress_st_warning=True)
    def __load_maintopic_clf(self):
        """Load the main-topic classifier from ``models/maintopic_clf/``.

        The body (including the ``st.write`` message) only runs when
        ``st.cache`` has no cached result for this call.
        """
        # NOTE(review): st.cache hashes call arguments, including ``self``. By
        # the time this runs, ``self.model_topic_clf`` is already set, so hashing
        # may be costly. Consider a staticmethod or module-level loader — verify.
        st.write("Loading second model")
        return AutoModelForSequenceClassification.from_pretrained("models/maintopic_clf/")