articles_clf / app.py
Nikendolo's picture
Update app.py
474c845
import streamlit as st
from transformers import DistilBertModel, DistilBertTokenizer
import torch
# On-disk locations of the serialized classifier and the tokenizer vocabulary.
model_path = './models/pytorch_distilbert.bin'
vocab_path = './models/vocab_distilbert.bin'

# Device the model is loaded onto and run on.
device = torch.device('cpu')

# Token length every input is padded or truncated to before inference.
MAX_LEN = 512

# Class index -> human-readable topic name. The position of a name in the
# tuple below is its class index.
labels_description = dict(enumerate((
    'Computer Science',
    'Economics',
    'Electrical Engineering and Systems Science',
    'Mathematics',
    'Physics',
    'Quantitative Biology',
    'Quantitative Finance',
    'Statistics',
)))
class DistillBERTClass(torch.nn.Module):
    """DistilBERT encoder topped with a small feed-forward classification head.

    The head takes the 768-wide hidden vector at sequence position 0, passes
    it through Linear -> ReLU -> Dropout(0.3), and projects it to 8 scores.
    """

    def __init__(self):
        super().__init__()
        # NOTE: attribute names and their order are left untouched; the app
        # obtains its model through torch.load(model_path), so a serialized
        # instance presumably relies on exactly these names.
        self.l1 = DistilBertModel.from_pretrained("distilbert-base-cased")
        self.pre_classifier = torch.nn.Linear(768, 768)
        self.dropout = torch.nn.Dropout(0.3)
        self.classifier = torch.nn.Linear(768, 8)

    def forward(self, input_ids, attention_mask):
        """Return the final linear layer's output for a batch of encoded inputs.

        Args:
            input_ids: token ids, forwarded to the encoder as ``input_ids``.
            attention_mask: mask forwarded to the encoder as ``attention_mask``.

        Returns:
            The output of ``self.classifier`` (no softmax is applied here).
        """
        encoder_out = self.l1(input_ids=input_ids, attention_mask=attention_mask)
        # Element 0 of the encoder output is indexed as [batch, position];
        # only position 0 of every sequence feeds the head.
        first_token = encoder_out[0][:, 0]
        features = torch.relu(self.pre_classifier(first_token))
        return self.classifier(self.dropout(features))
def predict(text, model, human_readable=True, threshold=0.95):
    """Return the most probable topics for ``text``.

    The text is whitespace-normalized, tokenized with the module-level
    ``tokenizer`` (padded/truncated to ``MAX_LEN`` tokens) and run through
    ``model``. Classes are then taken in order of decreasing softmax
    probability until their cumulative probability reaches ``threshold``.

    Args:
        text: raw input text (e.g. title and abstract joined together).
        model: callable taking ``(input_ids, attention_mask)`` and returning
            one row of class scores per input; it is switched to eval mode.
        human_readable: when true, return strings of the form
            ``"<label> - <percent>%"``; otherwise return class indices.
        threshold: cumulative probability at which to stop adding classes.

    Returns:
        A list of labels (or class indices), most probable first.
    """
    model.eval()
    # Collapse newlines and repeated whitespace into single spaces.
    text = " ".join(text.split())

    # `padding='max_length'` replaces the deprecated `pad_to_max_length=True`;
    # `return_tensors='pt'` yields batch-shaped tensors directly.
    encoded = tokenizer(
        text,
        add_special_tokens=True,
        max_length=MAX_LEN,
        padding='max_length',
        truncation=True,
        return_tensors='pt',
    )

    with torch.no_grad():
        scores = model(encoded['input_ids'], encoded['attention_mask'])
        probs = torch.softmax(scores, dim=-1)[0].tolist()

    # (probability, class index) pairs, highest probability first.
    ranked = sorted(((p, i) for i, p in enumerate(probs)), reverse=True)

    answer = []
    covered = 0.0
    # Iterating over `ranked` bounds the loop by the number of classes, so it
    # cannot index past the end even if the running sum never reaches
    # `threshold`.
    for prob, label_id in ranked:
        if covered >= threshold:
            break
        covered += prob
        if human_readable:
            answer.append(f"{labels_description[label_id]} - {100 * prob:.2f}%")
        else:
            answer.append(label_id)
    return answer
@st.cache(show_spinner=False, allow_output_mutation=True)
def load_model_and_tokenizer():
    """Load the classifier and its tokenizer from disk.

    Returns:
        A ``(model, tokenizer)`` tuple: the object ``torch.load`` reads from
        ``model_path`` (mapped onto ``device``) and a ``DistilBertTokenizer``
        created from ``vocab_path``.
    """
    # NOTE(review): torch.load unpickles arbitrary Python objects; only point
    # model_path at checkpoint files from a trusted source.
    # NOTE(review): newer Streamlit releases appear to deprecate st.cache in
    # favour of st.cache_resource; confirm against the pinned version.
    classifier = torch.load(model_path, map_location=device)
    vocab = DistilBertTokenizer.from_pretrained(vocab_path)
    return (classifier, vocab)
# `tokenizer` has to stay a module-level name: predict() reads it as a global.
model, tokenizer = load_model_and_tokenizer()

st.markdown("### Hi! This is a service for determining the subject of an article.")
# The bullet list is generated from labels_description so the topics shown to
# the user always match the labels predict() reports.
st.markdown("It can predict the following topics:\n"
            + "".join(f"* {topic}\n" for topic in labels_description.values()))
st.markdown("#### Just write the title and abstract in the areas below and click the \"Analyze\" button.")

title = st.text_area("Title")
abstract = st.text_area("Abstract")

if st.button('Analyze'):
    with st.spinner("Wait..."):
        # Treat whitespace-only fields as empty so blank input is rejected
        # instead of being sent to the model.
        has_title = bool((title or "").strip())
        has_abstract = bool((abstract or "").strip())
        if not has_title and not has_abstract:
            st.error("You haven't written anything.")
        elif not has_title:
            st.error("You haven't written a title.")
        else:
            # The abstract is optional; the title alone is enough to classify.
            pred = predict(title + "\n" + abstract, model.to(device))
            st.success("\n\n".join(pred))