|
import streamlit as st |
|
import torch |
|
import numpy as np |
|
|
|
# Page header followed by a one-line summary of the topics the classifier knows.
_PAGE_INTRO = (
    "### A dummy site for classifying article topics by title and abstract.",
    "It can predict the following topics: Computer Science, Economics, Electrical Engineering and Systems Science, Mathematics, Quantitative Biology, Quantitative Finance, Statistics, Physics",
)
for _intro_line in _PAGE_INTRO:
    st.markdown(_intro_line)
|
|
|
|
|
from transformers import pipeline |
|
from transformers import AutoTokenizer, AutoModelForSequenceClassification |
|
|
|
@st.cache(suppress_st_warning=True)
def model_tokenizer():
    """Build the DistilBERT topic classifier and load the fine-tuned weights.

    The architecture is created from the ``distilbert-base-cased`` checkpoint
    with an 8-logit multi-label classification head. All parameters are then
    overwritten by the state dict stored in ``model.pt`` (loaded onto CPU).

    Despite its name, this function returns only the model; the tokenizer is
    loaded separately.

    Returns:
        The classification model with the ``model.pt`` weights, switched to
        eval mode.
    """
    model_name = 'distilbert-base-cased'
    model = AutoModelForSequenceClassification.from_pretrained(
        model_name,
        # presumably one logit per entry of the 8-item topic list -- keep in sync
        num_labels=8,
        problem_type="multi_label_classification",
    )
    # NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
    weights = torch.load('model.pt', map_location=torch.device('cpu'))
    model.load_state_dict(weights)
    # Make inference mode explicit (disables dropout) independent of what
    # from_pretrained did before load_state_dict.
    model.eval()
    return model
|
|
|
def make_prediction(model, tokenizer, text, threshold=0.95, categories=None):
    """Classify ``text`` and return the most likely topics.

    The logits for the single input sequence are turned into a softmax
    distribution. Classes are visited from most to least probable and
    collected until their cumulative probability exceeds ``threshold``; the
    class that crosses the threshold is included. Each selected class is
    also printed to stdout.

    Args:
        model: sequence-classification model, called as ``model(input_ids)``.
            Its first output is taken as the logits.
        tokenizer: object exposing ``encode(text, truncation=...)``.
        text: input string to classify.
        threshold: cumulative-probability cut-off (default 0.95).
        categories: mapping or sequence from class index to label. Defaults
            to the module-level ``to_category``.

    Returns:
        A list of ``(percentage_string, label)`` tuples, most probable first,
        e.g. ``("87.12%", "Mathematics")``. The list is never empty, and it
        contains every class if ``threshold`` is never exceeded.
    """
    if categories is None:
        categories = to_category

    # Truncate to the tokenizer's maximum length. DistilBERT's position
    # embeddings cap the sequence length, so a long abstract would otherwise
    # make the forward pass fail.
    tokens = tokenizer.encode(text, truncation=True)

    with torch.no_grad():
        logits = model.cpu()(torch.as_tensor([tokens]))[0]

    # The batch has one sequence: take the last (only) row of logits.
    probs = np.array(torch.softmax(logits[-1, :], dim=-1))

    res = []
    cumulative = 0
    # Iterating over the class indices (descending probability) bounds the
    # loop, so it cannot run past the last class.
    for idx in np.argsort(probs)[::-1]:
        if cumulative > threshold:
            break
        label = categories[idx]
        pct = "{:.2f}%".format(100 * probs[idx])
        print(pct, label)
        res.append((pct, label))
        cumulative += probs[idx]
    return res
|
|
|
@st.cache(allow_output_mutation=True)
def _load_tokenizer():
    """Load the ``distilbert-base-cased`` tokenizer once per server process.

    Streamlit re-runs the whole script on every widget interaction, so an
    uncached call would reload the tokenizer from disk on each rerun.
    ``allow_output_mutation=True`` stops ``st.cache`` from hashing the
    returned tokenizer object on every cache hit.
    """
    # problem_type is a model-config option and is not passed to the tokenizer.
    return AutoTokenizer.from_pretrained("distilbert-base-cased")


model = model_tokenizer()
tokenizer = _load_tokenizer()
|
|
|
# Human-readable topic names. NOTE(review): the order presumably must match
# the label order used when training model.pt (index i <-> logit i).
categories_full = [
    'Computer Science',
    'Economics',
    'Electrical Engineering and Systems Science',
    'Mathematics',
    'Quantitative Biology',
    'Quantitative Finance',
    'Statistics',
    'Physics',
]

# Class index -> topic name, used to label model predictions.
to_category = dict(enumerate(categories_full))
|
|
|
|
|
|
|
title = st.text_area("Type the title of the article here")
abstract = st.text_area("Type the abstract of the article here")

if st.button('Analyse'):
    # Whitespace-only fields count as empty so blank input is not classified.
    if title.strip() or abstract.strip():
        # NOTE(review): presumably the same "[TITLE] ... [ABSTRACT] ..." layout
        # used when fine-tuning model.pt -- keep the two in sync.
        text = '[TITLE] ' + title + ' [ABSTRACT] ' + abstract
        res = make_prediction(model, tokenizer, text)
        # Each entry is a (percentage string, topic label) pair, most probable first.
        for prob, label in res:
            st.markdown(f"**{label}**: {prob}")
    else:
        st.error("Write title or abstract")
|
|
|
|
|
|