File size: 1,354 Bytes
fbb2b39 6f6d6c2 fbb2b39 6f6d6c2 45c713e 6f6d6c2 45c713e 6f6d6c2 45c713e 6f6d6c2 5c59066 6f6d6c2 45c713e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 |
import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# Base checkpoint used for the tokenizer / architecture.
MODEL_NAME = 'bert-base-cased'
# Path to the fine-tuned state dict saved on disk.
MODEL_PATH = 'bert_model'

# arXiv top-level subject areas, keyed by the model's output label id (0-7).
ID2CLS = dict(enumerate([
    'Computer Science',
    'Economics',
    'Electrical Engineering and Systems Science',
    'Mathematics',
    'Physics',
    'Quantitative Biology',
    'Quantitative Finance',
    'Statistics',
]))
def classify(text, tokenizer, model):
    """Classify *text* and render the top class probabilities via Streamlit.

    Tokenizes the text (truncated to 256 tokens), runs the sequence
    classifier, and writes one markdown line per class in descending
    probability order, stopping once the reported classes cover more
    than 95% of the probability mass.

    Args:
        text: Article text to classify; empty/blank input is a no-op.
        tokenizer: Hugging Face tokenizer matching *model*.
        model: Sequence-classification model (expected to be in eval mode).

    Returns:
        ``[""]`` for empty input (kept for backward compatibility),
        otherwise ``None``.
    """
    if not text:
        return [""]
    batch = tokenizer([text], truncation=True, padding=True, max_length=256, return_tensors="pt")
    # Inference only: disable autograd so no gradient graph is built.
    with torch.no_grad():
        outputs = model(batch['input_ids'], attention_mask=batch['attention_mask'])
    probabilities = torch.softmax(outputs.logits, dim=1).cpu().numpy()[0]
    total = 0.0
    # Walk class indices from most to least likely until 95% of the
    # probability mass has been reported.
    for p in probabilities.argsort()[::-1]:
        st.markdown(f'{ID2CLS[p]}: {round(probabilities[p] * 100, 2)} %')
        total += probabilities[p]
        if total > 0.95:
            break
@st.cache_resource
def _load_artifacts():
    """Load the tokenizer and fine-tuned model once per server process.

    Without caching, Streamlit re-executes the whole script on every
    widget interaction, reloading BERT from disk each time.
    """
    tok = AutoTokenizer.from_pretrained(MODEL_NAME)
    mdl = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=8)
    # NOTE(review): torch.load unpickles the checkpoint — only load
    # trusted files (consider weights_only=True on torch >= 1.13).
    # map_location='cpu' lets the app run without a GPU.
    mdl.load_state_dict(torch.load(MODEL_PATH, map_location=torch.device('cpu')))
    mdl.eval()
    return tok, mdl

tokenizer, model = _load_artifacts()
st.markdown("## Article classifier")
title = st.text_area("title")
text = st.text_area("article")
# Join title and body with a space so the last title word and the first
# article word are not fused into one token; strip so all-empty input
# still hits classify's early return.
classify(f"{title} {text}".strip(), tokenizer, model)
|