# Page banner captured together with this file (apparently a hosted "Spaces" app page): status "Runtime error".
import streamlit as st | |
from transformers import DistilBertModel, DistilBertTokenizer | |
import torch | |
# Locations of the serialized classifier and of the tokenizer vocabulary.
model_path = "./models/pytorch_distilbert.bin"
vocab_path = "./models/vocab_distilbert.bin"

# Device the model is mapped to on load and moved to before prediction.
device = torch.device("cpu")

# Length every tokenized input is padded or truncated to.
MAX_LEN = 512

# Class index -> human-readable category name (index order matters: it is
# the order of the classifier's 8 output logits).
labels_description = dict(enumerate((
    "Computer Science",
    "Economics",
    "Electrical Engineering and Systems Science",
    "Mathematics",
    "Physics",
    "Quantitative Biology",
    "Quantitative Finance",
    "Statistics",
)))
class DistillBERTClass(torch.nn.Module):
    """DistilBERT encoder with a two-layer classification head (8 logits).

    NOTE: the app obtains its model object from ``torch.load`` and then calls
    it directly, so the attribute names used in ``forward`` (``l1``,
    ``pre_classifier``, ``dropout``, ``classifier``) must stay as they are for
    previously saved models to keep working.
    """

    def __init__(self):
        super().__init__()
        self.l1 = DistilBertModel.from_pretrained("distilbert-base-cased")
        self.pre_classifier = torch.nn.Linear(768, 768)
        self.dropout = torch.nn.Dropout(0.3)
        self.classifier = torch.nn.Linear(768, 8)

    def forward(self, input_ids, attention_mask):
        """Return the classifier logits for a batch of encoded inputs.

        Args:
            input_ids: token id tensor passed to the encoder.
            attention_mask: mask tensor passed to the encoder.

        Returns:
            Output of the final linear layer (last dimension of size 8).
        """
        encoder_out = self.l1(input_ids=input_ids, attention_mask=attention_mask)
        # First element of the encoder output, taken at index 0 along the
        # second axis (presumably the leading special-token position).
        first_token = encoder_out[0][:, 0]
        hidden = torch.relu(self.pre_classifier(first_token))
        return self.classifier(self.dropout(hidden))
def predict(text, model, human_readable=True, *, threshold=0.95):
    """Classify ``text`` and return the most probable topics.

    The text is whitespace-normalized, tokenized with the module-level
    ``tokenizer`` (padded/truncated to ``MAX_LEN``), and run through ``model``.
    Classes are then taken in order of decreasing probability until their
    cumulative probability reaches ``threshold``.

    Args:
        text: raw input text (title and abstract).
        model: callable module taking ``(input_ids, attention_mask)`` and
            returning logits; it is switched to eval mode as a side effect.
        human_readable: when true, return strings such as
            ``"Physics - 97.12%"``; otherwise return integer class indices.
        threshold: cumulative probability at which to stop adding classes.
            Defaults to 0.95, the previously hard-coded value.

    Returns:
        A list of formatted label strings or of class indices, ordered from
        most to least probable.
    """
    model.eval()
    text = " ".join(text.split())

    # ``padding="max_length"`` replaces the deprecated ``pad_to_max_length``
    # flag; token type ids are not requested because they were never used.
    encoded = tokenizer(
        text,
        add_special_tokens=True,
        max_length=MAX_LEN,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )

    with torch.no_grad():
        logits = model(encoded["input_ids"], encoded["attention_mask"])
        probabilities = torch.softmax(logits, dim=-1)[0].tolist()

    # Sort (probability, index) pairs so ties resolve exactly as before.
    ranked = sorted(
        ((probability, label_id) for label_id, probability in enumerate(probabilities)),
        reverse=True,
    )

    answer = []
    cumulative = 0.0
    # Iterating over ``ranked`` bounds the loop by the number of classes, so
    # it can never index past the end of the list.
    for probability, label_id in ranked:
        # Written as ``not <`` so a NaN total stops the loop, as ``while`` did.
        if not cumulative < threshold:
            break
        cumulative += probability
        if human_readable:
            answer.append(f"{labels_description[label_id]} - {100 * probability:.2f}%")
        else:
            answer.append(label_id)
    return answer
def load_model_and_tokenizer():
    """Load the classifier from ``model_path`` and the tokenizer from ``vocab_path``.

    Returns:
        A ``(model, tokenizer)`` tuple; the model is mapped onto ``device``.
    """
    # NOTE(review): torch.load unpickles the file, so it must come from a
    # trusted source. Recent PyTorch releases default to weights_only=True,
    # which may reject a fully pickled module -- confirm against the pinned
    # torch version.
    classifier = torch.load(model_path, map_location=device)
    vocab = DistilBertTokenizer.from_pretrained(vocab_path)
    return classifier, vocab
# Module-level globals: ``predict`` reads ``tokenizer`` from here.
# NOTE(review): Streamlit is understood to re-execute the script on every
# interaction, which would repeat this load each time -- consider caching.
model, tokenizer = load_model_and_tokenizer()

# --- Static page content -------------------------------------------------
st.markdown("### Hi! This is a service for determining the subject of an article.")
st.markdown(
    "It can predict the following topics:\n"
    "* Computer Science\n"
    "* Economics\n"
    "* Electrical Engineering and Systems Science\n"
    "* Mathematics\n"
    "* Physics\n"
    "* Quantitative Biology\n"
    "* Quantitative Finance\n"
    "* Statistics\n"
)
st.markdown(
    '#### Just write the title and abstract in the areas below and click the "Analyze" button.'
)

# --- User input ----------------------------------------------------------
title = st.text_area("Title")
abstract = st.text_area("Abstract")

# --- Prediction ----------------------------------------------------------
if st.button("Analyze"):
    with st.spinner("Wait..."):
        if title:
            # A title is enough to classify; the abstract may be empty.
            pred = predict(title + "\n" + abstract, model.to(device))
            st.success("\n\n".join(pred))
        elif abstract:
            st.error("You haven't written a title.")
        else:
            st.error("You haven't written anything.")