Dzhamb's picture
Update utils.py
547522b
raw
history blame contribute delete
No virus
1.38 kB
import torch
from transformers import DistilBertForSequenceClassification, DistilBertTokenizerFast
import streamlit as st
def get_text(title: str, abstract: str):
    """Build the model input string from an article's title and abstract.

    When both are non-empty, the result is the abstract, a single space,
    then the title. When only one is non-empty, that one is returned
    unchanged. When neither is, ``None`` is returned.
    """
    if not (abstract or title):
        return None
    if abstract and title:
        return ' '.join((abstract, title))
    # Exactly one of the two is truthy here; return whichever it is.
    return title or abstract
def get_labels(text, model, tokenizer, count_labels=8, threshold=0.95):
    """Return the smallest set of top-ranked labels covering ``threshold`` probability.

    The text is tokenized and scored by ``model``. The softmax probabilities
    are ranked in descending order, and labels are collected until their
    cumulative probability exceeds ``threshold``. The label that crosses the
    threshold is included. If the threshold is never exceeded, every label
    is returned.

    Args:
        text: Input string (for example the output of ``get_text``).
            ``None`` yields an empty list.
        model: Callable that takes the tokenizer output as keyword arguments
            and returns an object with a ``logits`` tensor. The tensor's last
            dimension must line up with the 8 labels listed below; presumably
            this is the classifier built by ``load_model``. Verify against
            the caller.
        tokenizer: Callable accepting ``(text, return_tensors='pt',
            truncation=True)`` and returning a mapping of tensors.
        count_labels: Kept for backward compatibility; currently unused.
        threshold: Cumulative-probability cut-off (default 0.95).

    Returns:
        List of label names, most probable first.
    """
    if text is None:
        return []
    labels = ['Computer_science', 'Economics',
              'Electrical_Engineering_and_Systems_Science', 'Mathematics',
              'Physics', 'Quantitative_Biology', 'Quantitative_Finance',
              'Statistics']
    # truncation=True keeps long inputs within the model's maximum length.
    tokens = tokenizer(text, return_tensors='pt', truncation=True)
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        logits = model(**tokens).logits
    # Explicit dim: nn.Softmax() without dim relies on a deprecated implicit choice.
    # .cpu() makes the conversion work even if the model runs on a GPU.
    probs = torch.softmax(logits, dim=-1)[0].cpu().tolist()
    ranked = sorted(zip(probs, labels), key=lambda pair: pair[0], reverse=True)

    result_labels = []
    cumsum = 0.0
    for prob, label in ranked:
        result_labels.append(label)
        cumsum += prob
        if cumsum > threshold:
            break
    # Always return a list (the original fell through to None if never crossed).
    return result_labels
# Prefer st.cache_resource (Streamlit 1.18+), which returns the cached object
# itself rather than a copy, like allow_output_mutation=True did. Fall back to
# the deprecated st.cache only on older Streamlit releases. The `or`
# short-circuits, so st.cache is not touched when cache_resource exists.
_cache_resource = getattr(st, "cache_resource", None) or st.cache(allow_output_mutation=True)


@_cache_resource
def load_model(weights_path='weight_model'):
    """Build the DistilBERT classifier, load fine-tuned weights, and return it.

    The base architecture and tokenizer come from the "distilbert-base-cased"
    checkpoint, with an 8-label classification head. The head's weights (and
    the rest of the parameters) are then overwritten by the state dict stored
    at ``weights_path``. The result is cached across Streamlit reruns.

    Args:
        weights_path: Path to a file produced by ``torch.save`` containing a
            state dict for the model (default ``'weight_model'``, resolved
            relative to the current working directory).

    Returns:
        ``(model, tokenizer)`` tuple.
    """
    tokenizer = DistilBertTokenizerFast.from_pretrained("distilbert-base-cased")
    model = DistilBertForSequenceClassification.from_pretrained("distilbert-base-cased", num_labels=8)
    # map_location='cpu' lets a checkpoint saved on a GPU load on a CPU-only host.
    # load_state_dict then copies the tensors into the model's existing parameters.
    # NOTE(review): torch.load unpickles the file, so only load trusted checkpoints
    # (consider weights_only=True on torch >= 1.13).
    state_dict = torch.load(weights_path, map_location='cpu')
    model.load_state_dict(state_dict)
    # Make inference mode explicit (disables dropout).
    model.eval()
    return model, tokenizer