import torch
from transformers import DistilBertForSequenceClassification, DistilBertTokenizerFast
import streamlit as st


def get_text(title: str, abstract: str):
    # Combine whatever the user provided; return None when both fields are empty.
    if abstract and title:
        text = abstract + ' ' + title
    elif title:
        text = title
    elif abstract:
        text = abstract
    else:
        text = None
    return text


def get_labels(text, model, tokenizer, count_labels=8):
    # Tokenize the input and run a single forward pass through the classifier.
    tokens = tokenizer(text, return_tensors='pt')
    with torch.no_grad():
        outputs = model(**tokens)
    probs = torch.nn.functional.softmax(outputs.logits, dim=-1)

    labels = [
        'Computer_science',
        'Economics',
        'Electrical_Engineering_and_Systems_Science',
        'Mathematics',
        'Physics',
        'Quantitative_Biology',
        'Quantitative_Finance',
        'Statistics',
    ]

    # Sort labels by descending probability and keep the smallest set whose
    # cumulative probability exceeds 0.95.
    sort_lst = sorted(zip(probs.detach().numpy()[0], labels), key=lambda x: -x[0])
    cumsum = 0.0
    result_labels = []
    for prob, label in sort_lst:
        cumsum += prob
        result_labels.append(label)
        if cumsum > 0.95:
            break
    return result_labels


@st.cache(allow_output_mutation=True)
def load_model():
    # Load the base DistilBERT checkpoint and the fine-tuned classification head.
    tokenizer = DistilBertTokenizerFast.from_pretrained("distilbert-base-cased")
    model = DistilBertForSequenceClassification.from_pretrained(
        "distilbert-base-cased", num_labels=8
    )
    model.load_state_dict(torch.load('weight_model'))
    model.eval()
    return model, tokenizer
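

# --- Usage sketch (assumption, not part of the original module) ---
# A minimal example of how the helpers above could be wired into the Streamlit
# page. The widget labels and page layout below are illustrative assumptions,
# not the author's actual UI.
model, tokenizer = load_model()

st.title('arXiv topic classifier')
title = st.text_input('Paper title')
abstract = st.text_area('Paper abstract')

text = get_text(title, abstract)
if text is not None:
    predicted = get_labels(text, model, tokenizer)
    st.write('Most likely subject areas (95% cumulative probability):')
    for label in predicted:
        st.write('-', label.replace('_', ' '))
else:
    st.write('Enter a title and/or an abstract to get a prediction.')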