|
import streamlit as st |
|
import torch |
|
from transformers import AutoTokenizer, AutoModelForSequenceClassification |
|
|
|
# Hugging Face model id used for both the tokenizer and the classifier weights.
MODEL_NAME = 'bert-base-cased'

# NOTE(review): MODEL_PATH is defined but never used in this file — the model
# below is loaded from MODEL_NAME instead. Presumably this points at a
# directory of fine-tuned weights; confirm which source is intended.
MODEL_PATH = 'bert_model'
|
|
|
# Mapping from the classifier's output index to a human-readable label
# (8 labels, matching num_labels=8 used when the model is constructed).
ID2CLS = {
    0: 'Computer Science',
    1: 'Economics',
    2: 'Electrical Engineering and Systems Science',
    3: 'Mathematics',
    4: 'Physics',
    5: 'Quantitative Biology',
    6: 'Quantitative Finance',
    7: 'Statistics'
}


def classify(text, tokenizer, model):
    """Classify *text* and return labels covering >= 90% of probability mass.

    Args:
        text: the article text to classify (title + body).
        tokenizer: a Hugging Face tokenizer; called with a list of strings
            and ``return_tensors="pt"``, must yield ``'input_ids'``.
        model: a sequence-classification model whose output exposes
            ``.logits`` of shape (1, 8).

    Returns:
        A list of ``'<label>: <pct>%'`` strings, most probable first,
        containing the fewest top classes whose probabilities sum to at
        least 0.9.  For empty input, ``[""]`` is returned so the caller's
        Streamlit loop renders nothing.
    """
    if not text:
        return [""]
    tokens = tokenizer([text], truncation=True, padding=True, max_length=256, return_tensors="pt")['input_ids']
    # Inference only — no_grad avoids building the autograd graph.
    with torch.no_grad():
        logits = model(tokens).logits
    probabilities = torch.softmax(logits, dim=1).cpu().numpy()[0]

    total = 0.0
    ans = []
    # Walk classes from most to least probable, keeping each until the
    # accumulated mass reaches 90%.  (The previous `prob + total < 0.9`
    # pre-check excluded the top class whenever its probability alone was
    # >= 0.9 and then appended every low-probability tail class; accumulating
    # first and breaking on the threshold always includes the top class and
    # stops as soon as 90% is covered.)
    for idx in probabilities.argsort()[::-1]:
        total += probabilities[idx]
        ans.append(f'{ID2CLS[int(idx)]}: {round(probabilities[idx] * 100, 2)}%')
        if total >= 0.9:
            break

    return ans
|
|
|
|
|
@st.cache_resource
def _load_pipeline():
    """Load the tokenizer and classifier once per server process.

    Streamlit re-runs this whole script on every widget interaction;
    without caching, the BERT weights would be reloaded each time.
    NOTE(review): weights are loaded from MODEL_NAME (base BERT), so the
    8-way classification head is randomly initialised — confirm whether
    the fine-tuned weights at MODEL_PATH should be loaded instead.
    """
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=8)
    model.eval()  # disable dropout for deterministic inference
    return tokenizer, model


tokenizer, model = _load_pipeline()

st.markdown("## Article classifier")

title = st.text_area("title")
text = st.text_area("article")

# Join with a space so the last word of the title and the first word of the
# article are not fused into one token; skipping empty fields keeps two empty
# inputs producing "" (which triggers classify's empty-input guard).
query = " ".join(part for part in (title, text) if part)

for prediction in classify(query, tokenizer, model):
    st.markdown(prediction)
|
|