HW4_YSDA / app.py
MariaUDmitrieva's picture
Upload 7 files
c713558 verified
import streamlit as st
import time
# model part
import json
import torch
from torch import nn
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# Category metadata shipped next to the app. JSON text is UTF-8 by spec, so
# decode it explicitly instead of relying on the platform's locale encoding.
# cat_with_names is indexed by the entries of categories_from_model to get a
# human-readable category name.
with open('categories_with_names.json', 'r', encoding='utf-8') as f:
    cat_with_names = json.load(f)
# Category codes in classifier-output order: entry i labels output i of the
# models (presumably the label order used at training time — verify).
with open('categories_from_model.json', 'r', encoding='utf-8') as f:
    categories_from_model = json.load(f)
@st.cache_resource
def load_models_and_tokenizer():
    """Load the shared tokenizer and the title / abstract classifiers.

    Decorated with ``st.cache_resource`` so Streamlit reruns reuse the loaded
    objects rather than reading the checkpoints again.

    Returns:
        ``(model_titles, model_abstracts, tokenizer)``; both models are put in
        evaluation mode before being returned.
    """
    n_labels = len(categories_from_model)

    def from_checkpoint(path):
        # Multi-label head with one output per known category.
        classifier = AutoModelForSequenceClassification.from_pretrained(
            path,
            num_labels=n_labels,
            problem_type="multi_label_classification",
        )
        classifier.eval()
        return classifier

    tokenizer = AutoTokenizer.from_pretrained("oracat/bert-paper-classifier-arxiv")
    titles_classifier = from_checkpoint("powerful_model_titles/checkpoint-13472")
    abstracts_classifier = from_checkpoint("powerful_model_abstracts/checkpoint-13472")
    return titles_classifier, abstracts_classifier, tokenizer
# Module-level handles used for inference; the loader is wrapped in
# st.cache_resource, so reruns get the already-loaded objects back.
(model_titles, model_abstracts, tokenizer) = load_models_and_tokenizer()
def categorize_text(title: str | None = None, abstract: str | None = None, progress_bar = None):
if title is None and abstract is None:
raise ValueError('title is None and abstract is None')
models_to_run = 2 if (title is not None and abstract is not None) else 1
proba_title = None
if title is not None:
progresses = (10, 30) if models_to_run == 2 else (20, 60)
my_bar.progress(progresses[0], text='computing titles')
input_tok = tokenizer(title, return_tensors='pt')
with torch.no_grad():
logits = model_titles(**input_tok)['logits']
proba_title = torch.sigmoid(logits)[0]
my_bar.progress(progresses[1], text='computed titles')
proba_abstract = None
if abstract is not None:
progresses = (40, 70) if models_to_run == 2 else (20, 60)
my_bar.progress(progresses[0], text='computing abstracts')
input_tok = tokenizer(abstract, return_tensors='pt')
with torch.no_grad():
logits = model_abstracts(**input_tok)['logits']
proba_abstract = torch.sigmoid(logits)[0]
my_bar.progress(progresses[0], text='computed abstracts')
if title is None:
proba = proba_abstract
elif abstract is None:
proba = proba_title
else:
proba = proba_title * 0.1 + proba_abstract * 0.9
progresses = (80, 90) if models_to_run == 2 else (70, 90)
my_bar.progress(progresses[0], text='computed proba')
sorted_proba, indices = torch.sort(proba, descending=True)
my_bar.progress(progresses[1], text='sorted proba')
to_take = 1
while sorted_proba[:to_take].sum() < 0.95 and to_take < len(categories_from_model):
to_take += 1
output = [(cat_with_names[categories_from_model[index]], proba[index].item())
for index in indices[:to_take]]
my_bar.progress(100, text='generated output')
return output
# front part
st.markdown("<h1 style='text-align: center;'>Classify your paper!</h1>", unsafe_allow_html=True)

# Seed session state only on the first run; later reruns keep existing values.
# "title"/"abstract" hold the submitted texts, the *_input_key entries back the
# text_input widgets, and "model_type" holds the inputs selected for classifying.
for _state_key, _initial in (
    ("title", ""),
    ("abstract", ""),
    ("title_input_key", ""),
    ("abstract_input_key", ""),
    ("model_type", []),
):
    if _state_key not in st.session_state:
        setattr(st.session_state, _state_key, _initial)
def input_error():
    """Describe why the current input cannot be submitted.

    Reads the selected input types and the submitted texts from
    ``st.session_state`` only, so it behaves the same whether it is called from
    the submit-button callback or from the main script.

    Returns:
        An error message, or an empty string when the input is valid.
    """
    selected = st.session_state.model_type
    if not selected:
        return 'you have to select title or abstract'
    if 'Title' in selected and not st.session_state.title:
        return 'Title is empty'
    if 'Abstract' in selected and not st.session_state.abstract:
        return 'Abstract is empty'
    return ''
def clear_input():
    """Submit-button callback: store the typed texts and clear the used boxes.

    Copies the current text-box contents into ``st.session_state.title`` and
    ``st.session_state.abstract`` (read later by the submit handler). If
    ``input_error`` reports no problem, it then empties the text boxes that
    were selected for classification.
    """
    # Read the widgets through their session-state keys. Callbacks run before
    # the script body, so module-level names assigned from the widgets would
    # still hold the previous run's values. The text is stored verbatim;
    # str.title() would re-capitalise every word and change the model input.
    st.session_state.title = st.session_state.title_input_key
    st.session_state.abstract = st.session_state.abstract_input_key
    if not input_error():
        if "Title" in st.session_state.model_type:
            st.session_state.title_input_key = ""
        if "Abstract" in st.session_state.model_type:
            st.session_state.abstract_input_key = ""
title = st.text_input(r"$\textsf{\Large Title}$", key="title_input_key")
abstract = st.text_input(r"$\textsf{\Large Abstract}$", key="abstract_input_key")
model_type = st.multiselect(
    r"$\textsf{\large Classify by:}$",
    ['Title', 'Abstract'],
)
st.session_state.model_type = model_type

if st.button('Submit', on_click=clear_input):
    error = input_error()
    if error:
        st.error(error)
    else:
        # Pass only the selected inputs; an omitted argument defaults to None,
        # and categorize_text skips the corresponding model.
        model_input = {}
        if 'Title' in st.session_state.model_type:
            model_input['title'] = st.session_state.title
        if 'Abstract' in st.session_state.model_type:
            model_input['abstract'] = st.session_state.abstract
        # my_bar stays a module-level name: categorize_text updates this bar.
        my_bar = st.progress(0, text='starting model')
        model_result = categorize_text(**model_input, progress_bar=my_bar)
        st.markdown("<h1 style='text-align: center;'>Classification completed!</h1>", unsafe_allow_html=True)

        # The top category is always shown prominently. The others are shown
        # individually at >= 10 % and collected into one line below that.
        cat, proba = model_result[0]
        st.write(r"$\textsf{\Large " + f'{cat}: {round(100*proba)}' + r"\%}$")
        small_categories = []
        for cat, proba in model_result[1:]:
            if proba < 0.1:
                small_categories.append(f'{cat}: {round(100*proba, 1)}' + r"\%")
            else:
                st.write(r"$\textsf{\large " + f'{cat}: {round(100*proba)}' + r"\%}$")
        if small_categories:
            st.write(', '.join(small_categories))