# Streamlit app: arXiv article subject-area classifier.
import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
MODEL_NAME = 'bert-base-cased'
MODEL_PATH = 'bert_model'

# Maps the model's output index to the arXiv subject-area name.
ID2CLS = {
    0: 'Computer Science',
    1: 'Economics',
    2: 'Electrical Engineering and Systems Science',
    3: 'Mathematics',
    4: 'Physics',
    5: 'Quantitative Biology',
    6: 'Quantitative Finance',
    7: 'Statistics'
}


def classify(text, tokenizer, model):
    """Predict arXiv subject areas for *text*.

    Returns class names with their probabilities ('Name: XX.X%'), in
    descending probability, stopping once the cumulative probability
    reaches 0.9 — so the single most likely class is always included.
    Returns [""] for empty input.
    """
    if not text:
        return [""]
    tokens = tokenizer([text], truncation=True, padding=True, max_length=256,
                       return_tensors="pt")['input_ids']
    # Inference only: no_grad skips building the autograd graph.
    with torch.no_grad():
        logits = model(tokens).logits
    probabilities = torch.softmax(logits, dim=1)[0].cpu().numpy()
    ans = []
    total = 0.0
    # Walk classes from most to least likely and accumulate until at least
    # 90% of the probability mass is covered.  (The previous
    # `if probabilities[p] + total < 0.9` test never broke out of the loop,
    # so a smaller later class could be appended after a larger one was
    # skipped, and a top class above 0.9 produced an empty answer.)
    for idx in probabilities.argsort()[::-1]:
        total += probabilities[idx]
        ans.append(f'{ID2CLS[idx]}: {round(probabilities[idx] * 100, 2)}%')
        if total >= 0.9:
            break
    return ans
# Load the classifier: base architecture/tokenizer from the Hugging Face hub,
# fine-tuned weights from the local checkpoint (CPU-only deployment).
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=8)
state_dict = torch.load(MODEL_PATH, map_location=torch.device('cpu'))
model.load_state_dict(state_dict)
model.eval()  # inference mode: disables dropout etc.
# --- Streamlit UI: collect title + abstract, show predicted subject areas ---
st.markdown("## Article classifier")
title = st.text_area("title")
text = st.text_area("article")

predictions = classify(title + text, tokenizer, model)
for prediction in predictions:
    st.markdown(prediction)