import streamlit as st
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import tokenizers
import transformers
from transformers import TextClassificationPipeline, AutoTokenizer, AutoModelForSequenceClassification

# Cumulative probability mass the reported set of categories must cover.
TOP_P = 0.95


# NOTE: st.cache is deprecated in newer Streamlit releases in favour of
# st.cache_resource / st.cache_data; it is kept here because the
# suppress_st_warning / hash_funcs arguments target the older API.
@st.cache(suppress_st_warning=True, hash_funcs={tokenizers.Tokenizer: lambda _: None})
def load_tok_and_model():
    """Load the tokenizer and the fine-tuned sequence classifier.

    The tokenizer is the stock ``distilbert-base-uncased`` one; the model
    weights are read from the current directory (".").

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    tokenizer = AutoTokenizer.from_pretrained('distilbert-base-uncased')
    model = AutoModelForSequenceClassification.from_pretrained(".")
    return tokenizer, model


# Category names in model-output order: tag[i] names output logit i.
tag = ['Cs', 'Econ', 'EESS', 'Math', 'Physics', 'Q-bio', 'Q-fin', 'Stat']
# Logit index -> category name, derived from ``tag`` so the two cannot drift.
inv_map = dict(enumerate(tag))


@st.cache(suppress_st_warning=True, hash_funcs={tokenizers.Tokenizer: lambda _: None})
def predict_label(title, summary, tokenizer, model, inv_map, threshold=TOP_P):
    """Classify an article and return the smallest top-p set of categories.

    Args:
        title: Article title (may be empty).
        summary: Article summary (may be empty).
        tokenizer: Tokenizer matching ``model``.
        model: Sequence-classification model; ``model(ids)[0]`` is used as
            the logits.
        inv_map: Mapping from logit index to category name.
        threshold: Cumulative probability the returned categories must reach.

    Returns:
        ``(prediction_classes, prediction_probs, probs)``:
        category names in descending probability order, their probabilities
        as truncated integer percentages, and the full softmax vector as a
        numpy array.
    """
    abstract = title.lower() + '. ' + summary.lower()
    # Truncate to the model's maximum input length; otherwise a long
    # summary produces more tokens than the model accepts and inference fails.
    token_text = tokenizer.encode(abstract, truncation=True)
    with torch.no_grad():
        logits = model(torch.as_tensor([token_text]))[0]
        # Batch of one: softmax over the single row of logits.
        probs = torch.softmax(logits[-1, :], dim=-1).numpy()

    prediction_classes = []
    prediction_probs = []
    sum_probs = 0.0
    # Take labels from most to least likely until their cumulative mass
    # reaches ``threshold``. Iterating over the finite index list guarantees
    # termination even if float rounding keeps the sum just below it.
    for idx in np.argsort(probs)[::-1]:
        cur_probs = probs[idx]
        prediction_classes.append(inv_map[idx])
        prediction_probs.append(int(100 * cur_probs))
        sum_probs += cur_probs
        if sum_probs >= threshold:
            break
    return prediction_classes, prediction_probs, probs


st.title("Classifier of possible topics of articles 📄")
st.markdown("Please insert the summary and/or title of the article below")

tokenizer, model = load_tok_and_model()

title = st.text_area(label='Title', height=50)
abstract = st.text_area(label='Summary', height=150)

if st.button('Start classifier'):
    # Either field is enough; whitespace-only input counts as empty.
    if not title.strip() and not abstract.strip():
        st.markdown("Please fill in the summary and/or the title in the text areas above")
    else:
        prediction_classes, prediction_probs, probs = predict_label(
            title, abstract, tokenizer, model, inv_map)
        data = pd.DataFrame({'Categories': tag, 'Probs': probs})
        data = data.sort_values(by='Probs', ascending=False)

        # Draw both series on the same percent scale: the full distribution,
        # with the top-p categories overlaid in the next colour.
        top = data[data['Categories'].isin(prediction_classes)]
        fig, ax = plt.subplots()
        ax.bar(data['Categories'], 100 * data['Probs'])
        ax.bar(top['Categories'], 100 * top['Probs'])
        ax.set_ylabel('Probability, %')
        st.pyplot(fig)
        # pyplot keeps figures alive until closed; release it so reruns
        # do not accumulate open figures.
        plt.close(fig)

        data_answer = pd.DataFrame({'Categories': prediction_classes,
                                    'Probs, %': prediction_probs})
        st.write(f'top-{TOP_P:.0%}')
        st.write(data_answer)