# Author: TovaHasi
# Commit: "Update app.py" (43a8fb9)
import streamlit as st
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import tokenizers
import transformers
from transformers import TextClassificationPipeline, AutoTokenizer, AutoModelForSequenceClassification
@st.cache(suppress_st_warning=True, hash_funcs={tokenizers.Tokenizer: lambda _: None})
def load_tok_and_model():
    """Load the tokenizer and the classification model (cached via ``st.cache``).

    Returns:
        tuple: ``(tokenizer, model)`` where the tokenizer is the pretrained
        ``distilbert-base-uncased`` tokenizer and the model is the
        sequence-classification model loaded from the current directory
        (``"."``).
    """
    # Tuple elements are evaluated left to right: tokenizer first, then model.
    return (
        AutoTokenizer.from_pretrained('distilbert-base-uncased'),
        AutoModelForSequenceClassification.from_pretrained("."),
    )
# Class names in model-output order: position i is the label paired with
# probability i when results are tabulated and plotted.
tag = ['Cs', 'Econ', 'EESS', 'Math', 'Physics', 'Q-bio', 'Q-fin', 'Stat']
# Index -> class name lookup. Derived from `tag` so the list and the mapping
# cannot drift apart when a category is added or reordered.
inv_map = dict(enumerate(tag))
@st.cache(suppress_st_warning=True, hash_funcs={tokenizers.Tokenizer: lambda _: None})
def predict_label(title, summary, tokenizer, model, inv_map):
    """Classify an article and return the classes covering 95% of the probability.

    The title and summary are lower-cased and joined as ``"<title>. <summary>"``,
    tokenized, and passed through ``model`` as a batch of one. The resulting
    logits are soft-maxed into per-class probabilities, and classes are taken
    in descending probability order until their cumulative probability
    reaches 0.95.

    Args:
        title: Article title (may be an empty string).
        summary: Article summary (may be an empty string).
        tokenizer: Object whose ``encode(text, ...)`` returns a list of token ids.
        model: Callable that takes a tensor of token ids and returns a sequence
            whose first element is the logits tensor.
        inv_map: Mapping from class index to class name.

    Returns:
        tuple: ``(prediction_classes, prediction_probs, probs)`` where
        ``prediction_classes`` is the list of selected class names,
        ``prediction_probs`` their probabilities as truncated integer
        percentages, and ``probs`` the NumPy array of all class probabilities.
    """
    abstract = title.lower() + '. ' + summary.lower()

    # Truncate long inputs: without this, a text longer than the model's
    # position-embedding table makes the forward pass fail. When the model
    # config does not expose a limit, max_length=None lets the tokenizer fall
    # back to its own configured maximum.
    max_len = getattr(getattr(model, 'config', None), 'max_position_embeddings', None)
    token_text = tokenizer.encode(abstract, truncation=True, max_length=max_len)

    with torch.no_grad():
        logits = model(torch.as_tensor([token_text]))[0]
        # The batch holds a single sequence, so the last row is the only row.
        # .cpu() keeps the NumPy conversion valid if the model is not on CPU.
        probs = torch.softmax(logits[-1, :], dim=-1).detach().cpu().numpy()

    idx_label = np.argsort(probs)[::-1]  # class indices, most probable first

    sum_probs = 0
    prediction_probs = []
    prediction_classes = []
    # Iterating over idx_label bounds the loop by the number of classes, so a
    # cumulative sum that rounds to just under 0.95 cannot index past the end.
    for class_idx in idx_label:
        if sum_probs >= 0.95:
            break
        cur_probs = probs[class_idx]
        sum_probs += cur_probs
        prediction_probs.append(int(100 * cur_probs))
        prediction_classes.append(inv_map[int(class_idx)])

    return prediction_classes, prediction_probs, probs
# Streamlit page: collects a title and/or summary, runs the classifier on
# button press, and shows a bar chart plus a table of the top-95% classes.

# \U0001F4C4 is the "page facing up" emoji, written as an escape so the source
# stays safe under any file encoding.
st.title("Classifier of possible topics of articles \U0001F4C4")
st.markdown("Please insert the summary and/or title of the article below")

tokenizer, model = load_tok_and_model()

title = st.text_area(label='Title', height=50)
abstract = st.text_area(label='Summary', height=150)

if st.button('Start classifier'):
    # Treat whitespace-only input as empty so the model is never run on blanks.
    if not title.strip() and not abstract.strip():
        st.markdown("Summary and title should be filled in in the text area above")
    else:
        prediction_classes, prediction_probs, probs = predict_label(title, abstract, tokenizer, model, inv_map)

        # All classes, most probable first.
        data = pd.DataFrame({'Categories' : tag, 'Probs' : probs})
        data = data.sort_values(by='Probs', ascending=False)

        fig, ax = plt.subplots()
        # Plot every class in percent so both bar series share one scale;
        # the second call overlays the top-95% classes in a different colour.
        ax.bar(data['Categories'], data['Probs'] * 100)
        ax.bar(prediction_classes, prediction_probs)
        ax.set_ylabel('Probs, %')

        data_answer = pd.DataFrame({'Categories' : prediction_classes, 'Probs, %' : prediction_probs})

        st.pyplot(fig)
        # Streamlit re-runs this script on every interaction; close the figure
        # so pyplot does not keep accumulating open figures in memory.
        plt.close(fig)

        st.write('top-95%')
        st.write(data_answer)