# shad_ml2 / app.py
# Last commit: 6f6d6c2 "Update app.py" (author: Darkhan), 1.32 kB.
import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# Hugging Face checkpoint used for both the tokenizer and the base model.
MODEL_NAME = "bert-base-cased"

# Location of the fine-tuned classifier weights (assumed to be a saved
# torch state dict — confirm against the training script).
MODEL_PATH = "bert_model"

# Class index -> subject-area name. The index order must match the order of
# the classifier's output logits, since predictions are looked up by index.
ID2CLS = dict(
    enumerate(
        (
            "Computer Science",
            "Economics",
            "Electrical Engineering and Systems Science",
            "Mathematics",
            "Physics",
            "Quantitative Biology",
            "Quantitative Finance",
            "Statistics",
        )
    )
)
def classify(text, tokenizer, model, threshold=0.9, id2cls=None):
    """Return the most likely subject areas for ``text``.

    Classes are reported from most to least probable and accumulated until
    their combined probability reaches ``threshold``; the top class is
    therefore always included, even when it alone exceeds the threshold.

    Args:
        text: Raw input string (title and/or article body).
        tokenizer: Callable tokenizer (e.g. a Hugging Face tokenizer) that
            returns a mapping of PyTorch tensors accepted by ``model``.
        model: Sequence-classification model whose output exposes
            ``.logits`` with one row for the single input text.
        threshold: Cumulative-probability cut-off (default 0.9).
        id2cls: Mapping from class index to display name; defaults to the
            module-level ``ID2CLS``.

    Returns:
        A list of ``"<class>: <percent>%"`` strings in descending order of
        probability, or ``[""]`` when ``text`` is empty.
    """
    if not text:
        return [""]
    labels = ID2CLS if id2cls is None else id2cls

    encoded = tokenizer(
        [text], truncation=True, padding=True, max_length=256, return_tensors="pt"
    )
    # Inference only: no_grad avoids building an autograd graph. The full
    # encoding (input_ids, attention_mask, ...) is passed, not input_ids alone.
    with torch.no_grad():
        logits = model(**encoded).logits
    probabilities = torch.softmax(logits, dim=1).cpu().numpy()[0]

    ans = []
    total = 0.0
    for idx in probabilities.argsort()[::-1]:
        prob = float(probabilities[idx])
        ans.append(f"{labels[int(idx)]}: {round(prob * 100, 2)}%")
        total += prob
        # Stop as soon as the reported classes cover the requested mass;
        # checking *after* appending guarantees the top class is shown.
        if total >= threshold:
            break
    return ans
@st.cache_resource
def _load_classifier(model_name=MODEL_NAME, model_path=MODEL_PATH):
    """Build the tokenizer and classifier once per server process.

    Streamlit re-executes this whole script on every widget interaction, so
    without caching the BERT checkpoint would be reloaded on each rerun.

    Returns:
        ``(tokenizer, model, weights_loaded)``; ``weights_loaded`` tells
        whether the fine-tuned state dict at ``model_path`` was applied.
    """
    tok = AutoTokenizer.from_pretrained(model_name)
    clf = AutoModelForSequenceClassification.from_pretrained(
        model_name, num_labels=len(ID2CLS)
    )
    try:
        # map_location lets weights saved on a GPU load on a CPU-only host.
        state_dict = torch.load(model_path, map_location="cpu")
    except FileNotFoundError:
        weights_loaded = False
    else:
        clf.load_state_dict(state_dict)
        weights_loaded = True
    clf.eval()
    return tok, clf, weights_loaded


tokenizer, model, weights_loaded = _load_classifier()
if not weights_loaded:
    # Without the fine-tuned state dict the model is just the base checkpoint
    # plus a classification head not trained for these labels.
    st.warning(
        f"Fine-tuned weights not found at '{MODEL_PATH}'; "
        "predictions may be meaningless."
    )

st.markdown("## Article classifier")
title = st.text_area("title")
text = st.text_area("article")
# Separate title and body with a space so the last title word and the first
# article word are not fused into one token (assumes the model was not
# trained on glued-together text — confirm against training).
query = " ".join(part for part in (title, text) if part)
for prediction in classify(query, tokenizer, model):
    st.markdown(prediction)