# shad_ml2/app.py: Streamlit article classifier backed by a fine-tuned
# bert-base-cased model.
import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
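
# Base checkpoint used for the tokenizer/architecture, and the path to the
# fine-tuned classification weights.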
MODEL_NAME = 'bert-base-cased'
MODEL_PATH = 'bert_model'
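
# Class index -> human-readable subject area for the 8-way classifier head.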
ID2CLS = {
    0: 'Computer Science',
    1: 'Economics',
    2: 'Electrical Engineering and Systems Science',
    3: 'Mathematics',
    4: 'Physics',
    5: 'Quantitative Biology',
    6: 'Quantitative Finance',
    7: 'Statistics',
}
def classify(text, tokenizer, model):
    """Predict subject areas for `text` and render them until the displayed
    classes cover more than 95% of the probability mass."""
    if not text.strip():
        return
    batch = tokenizer([text], truncation=True, padding=True, max_length=256, return_tensors="pt")
    with torch.no_grad():
        outputs = model(batch['input_ids'], attention_mask=batch['attention_mask'])
    probabilities = torch.softmax(outputs.logits, dim=1).cpu().numpy()[0]
    # Show classes from most to least likely, stopping once the cumulative
    # probability exceeds 0.95.
    total = 0.0
    for p in probabilities.argsort()[::-1]:
        st.markdown(f'{ID2CLS[p]}: {round(probabilities[p] * 100, 2)} %')
        total += probabilities[p]
        if total > 0.95:
            break
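
# Load the tokenizer and base architecture, then restore the fine-tuned
# weights from disk; inference runs on CPU.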
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=8)
model.load_state_dict(torch.load(MODEL_PATH, map_location=torch.device('cpu')))
model.eval()
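
# Streamlit UI: collect a title and article text, then show the prediction.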
st.markdown("## Article classifier")
title = st.text_area("title")
text = st.text_area("article")
classify(title + text, tokenizer, model)