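# ysda_hw / app.py
# Author: planetearth79 - commit fbdbabe ("Update app.py"), 3.54 kB.
# (Copied from the Hugging Face Space file page; the "raw / history / blame /
# contribute / delete / No virus" entries there were page navigation, not code.)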
import streamlit as st
from transformers import AutoModelForSequenceClassification
from transformers import AutoTokenizer
import torch
import pandas as pd
import numpy as np
# --- Page header -------------------------------------------------------------
# Title, arXiv logo and a legend of the nine target classes. Only the logo is
# raw HTML, so only that element is rendered with unsafe_allow_html=True.
_ARXIV_LOGO_HTML = "<img width=200px src='https://blog.arxiv.org/files/2021/02/arxiv-logo.svg'>"

# Legend (Russian): for each class, which arXiv categories were merged into it,
# e.g. "cs - all Computer Science subgroups except cs.AI, cs.CV, cs.LG".
_CLASS_LEGEND_MD = """
1) ai - cs.AI (Artificial Intelligence)
2) cs - все подгруппы из класса Computer Science, кроме cs.AI, cs.CV, cs.LG
3) cv - cs.CV (Computer Vision and Pattern Recognition)
4) lg - cs.LG (Machine Learning)
5) math - все подгруппы из класса Mathematics
6) ml - stat.ML (Machine Learning)
7) phys - все подгруппы из класса Physics
8) q-bio - все подгруппы из класса Quantitative Biology
9) stat - все подгруппы из класса Statistics, кроме stat.ML
"""

st.markdown("# Arxiv Papers Classifier")
st.markdown(_ARXIV_LOGO_HTML, unsafe_allow_html=True)
# "After processing and filtering the dataset, each paper kept one or more of 9 classes:"
st.markdown("После обработки и фильтрации датасета у каждой статьи остался один или несколько классов из 9:")
st.markdown(_CLASS_LEGEND_MD)
# Model output index -> short class label, in the same order as the legend.
# Both classifiers below are decoded with this single mapping (assumes both
# checkpoints were trained with this label order - TODO confirm).
id2label = dict(
    enumerate(["ai", "cs", "cv", "lg", "math", "ml", "phys", "q-bio", "stat"])
)
# --- User input ---------------------------------------------------------------
title_text = st.text_input("ENTER TITLE HERE")
summary_text = st.text_area("ENTER SUMMARY HERE")
# Both models see the title and the summary joined by a single space.
text = title_text + " " + summary_text

# Nothing to classify yet: without this guard the models below would run on a
# blank string on every fresh page load and render meaningless predictions.
if not text.strip():
    st.info("Enter a title and/or a summary to see the predictions.")
    st.stop()
# 1
# Cache decorator for the heavyweight tokenizer/model objects. st.cache_resource
# (Streamlit >= 1.18) is the API meant for ML models. On older releases, fall
# back to the legacy st.cache with output-mutation checking disabled; otherwise
# it hashes the returned model on every rerun, which is slow and can fail.
_cache_model = getattr(st, "cache_resource", None) or st.cache(allow_output_mutation=True)


@_cache_model
def load_first_model():
    """Load the tokenizer and model of the multi-class classifier.

    Both are loaded via ``from_pretrained("multi_class_model")`` (presumably a
    checkpoint directory shipped with the app - confirm) and cached, so they
    are not reloaded on every Streamlit rerun.

    Returns:
        tuple: ``(tokenizer, model)``.
    """
    loaded_tokenizer = AutoTokenizer.from_pretrained("multi_class_model")
    loaded_model = AutoModelForSequenceClassification.from_pretrained("multi_class_model")
    return loaded_tokenizer, loaded_model


tokenizer_1, model_1 = load_first_model()
# --- Model 1: multi-class (softmax) classifier --------------------------------


def _top_p_class_ids(probs, threshold=0.95):
    """Return the ids of the most probable classes covering *threshold* mass.

    Classes are taken in decreasing probability order until their cumulative
    probability reaches ``threshold``. The class that crosses the threshold is
    included, and the top class is always included, even when it alone
    exceeds the threshold.

    Args:
        probs: 1-D numpy array of per-class probabilities.
        threshold: cumulative-probability cut-off.

    Returns:
        numpy array of class ids, most probable first.
    """
    order = np.argsort(probs)[::-1]
    sorted_probs = probs[order]
    # Mass accumulated *before* each class (exactly 0 for the first one). Keep
    # classes while it is still below the threshold, i.e. up to and including
    # the one that crosses it. The old ``cumsum <= threshold`` test returned an
    # empty list whenever the top class alone had probability > threshold.
    mass_before = np.cumsum(sorted_probs) - sorted_probs
    return order[mass_before < threshold]


st.markdown("## multi-class classification")
text_input = tokenizer_1(text, padding="max_length", truncation=True, return_tensors='pt')
with torch.no_grad():
    text_res = model_1(**text_input)
# Softmax over the class logits of the single input -> one probability per class.
text_probs = torch.softmax(text_res.logits, dim=1).cpu().numpy()[0]
idxs = _top_p_class_ids(text_probs, 0.95)
# "Top-95 classes:" - the most probable classes covering 95% of the mass.
st.markdown("Топ-95 классов: " + ", ".join([id2label[int(i)] for i in idxs]))
# Label each bar via id2label explicitly instead of relying on dict order.
chart_data = pd.DataFrame(
    {"class probability": text_probs},
    index=pd.Index([id2label[i] for i in range(len(text_probs))], name="index"),
)
st.bar_chart(chart_data)
# 2
# Cache decorator for the heavyweight tokenizer/model objects. st.cache_resource
# (Streamlit >= 1.18) is the API meant for ML models. On older releases, fall
# back to the legacy st.cache with output-mutation checking disabled; otherwise
# it hashes the returned model on every rerun, which is slow and can fail.
_cache_model = getattr(st, "cache_resource", None) or st.cache(allow_output_mutation=True)


# Previously also named ``load_first_model``, which silently shadowed the
# multi-class loader defined above.
@_cache_model
def load_second_model():
    """Load the tokenizer and model of the multi-label classifier.

    Both are loaded via ``from_pretrained("multi_label_model")`` (presumably a
    checkpoint directory shipped with the app - confirm) and cached, so they
    are not reloaded on every Streamlit rerun.

    Returns:
        tuple: ``(tokenizer, model)``.
    """
    loaded_tokenizer = AutoTokenizer.from_pretrained("multi_label_model")
    loaded_model = AutoModelForSequenceClassification.from_pretrained("multi_label_model")
    return loaded_tokenizer, loaded_model


tokenizer_2, model_2 = load_second_model()
# --- Model 2: multi-label (sigmoid) classifier --------------------------------
st.markdown("## multi-label classification")
text_input = tokenizer_2(text, padding="max_length", truncation=True, return_tensors='pt')
with torch.no_grad():
    text_res = model_2(**text_input)
# Independent sigmoid per class (logits are already a tensor, so no legacy
# torch.Tensor(...) re-wrap is needed). Each value is read as the probability
# that the paper belongs to that class; values need not sum to 1 across classes.
text_probs = torch.sigmoid(text_res.logits).cpu().numpy()[0]
# Two columns per class: "belong" and its complement "not belong" (sum to 1).
chart_data = pd.DataFrame(
    {"belong": text_probs, "not belong": 1 - text_probs},
    index=pd.Index([id2label[i] for i in range(len(text_probs))], name="index"),
)
st.markdown("Probabilities for each class")
st.bar_chart(chart_data)