Spaces:
Sleeping
Sleeping
import streamlit as st | |
import time | |
# model part | |
import json | |
import torch | |
from torch import nn | |
from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
# Category metadata used to turn model output positions into readable names:
# ``categories_from_model`` is indexed by the classifier's output position and
# ``cat_with_names`` maps each of those category ids to the name shown to the user.
# JSON is UTF-8 by specification, so decode explicitly rather than relying on
# the platform's default encoding.
with open('categories_with_names.json', 'r', encoding='utf-8') as f:
    cat_with_names = json.load(f)
with open('categories_from_model.json', 'r', encoding='utf-8') as f:
    categories_from_model = json.load(f)
@st.cache_resource
def load_models_and_tokenizer():
    """Load the shared tokenizer and the two fine-tuned multi-label classifiers.

    Streamlit re-executes this whole script on every widget interaction, so the
    result is cached with ``st.cache_resource``: the checkpoints are loaded once
    and the same objects are reused across reruns instead of being re-read from
    disk on every click.

    Returns:
        ``(model_titles, model_abstracts, tokenizer)``: the title classifier and
        the abstract classifier (both switched to eval mode, with one output per
        entry of ``categories_from_model``), and the tokenizer used for both.
    """
    tokenizer = AutoTokenizer.from_pretrained("oracat/bert-paper-classifier-arxiv")
    num_labels = len(categories_from_model)

    model_titles = AutoModelForSequenceClassification.from_pretrained(
        "powerful_model_titles/checkpoint-13472",
        num_labels=num_labels,
        problem_type="multi_label_classification",
    )
    model_titles.eval()

    model_abstracts = AutoModelForSequenceClassification.from_pretrained(
        "powerful_model_abstracts/checkpoint-13472",
        num_labels=num_labels,
        problem_type="multi_label_classification",
    )
    model_abstracts.eval()

    return model_titles, model_abstracts, tokenizer


model_titles, model_abstracts, tokenizer = load_models_and_tokenizer()
def _predict_proba(model, text: str) -> torch.Tensor:
    """Run *model* on a single *text* and return its per-label sigmoid probabilities."""
    # truncation=True caps the input at the tokenizer's model_max_length; the
    # encoder is presumably BERT-based (see the tokenizer name), whose position
    # embeddings cannot handle longer sequences — verify for other checkpoints.
    encoded = tokenizer(text, return_tensors='pt', truncation=True)
    with torch.no_grad():
        logits = model(**encoded)['logits']
    # Batch of one: keep only the first (and only) row.
    return torch.sigmoid(logits)[0]


def categorize_text(title: str | None = None, abstract: str | None = None, progress_bar = None):
    """Classify a paper by its title and/or abstract.

    Args:
        title: paper title, scored by ``model_titles``; ``None`` to skip it.
        abstract: paper abstract, scored by ``model_abstracts``; ``None`` to skip it.
        progress_bar: optional object with a ``progress(value, text=...)`` method
            (e.g. the value returned by ``st.progress``); updated as work proceeds.
            When ``None``, no progress is reported.

    Returns:
        A list of ``(category_name, probability)`` tuples, most probable first.
        It holds the shortest prefix of the ranking whose probabilities sum to
        at least 0.95 (or every category if that sum is never reached), so it
        always has at least one entry. When both inputs are given, the
        probabilities are a fixed 10% title / 90% abstract blend.

    Raises:
        ValueError: if both ``title`` and ``abstract`` are ``None``.
    """
    if title is None and abstract is None:
        raise ValueError('title is None and abstract is None')

    def report(percent, text):
        # Progress reporting is optional so the function also works without a UI.
        if progress_bar is not None:
            progress_bar.progress(percent, text=text)

    both = title is not None and abstract is not None

    proba_title = None
    if title is not None:
        start, end = (10, 30) if both else (20, 60)
        report(start, 'computing titles')
        proba_title = _predict_proba(model_titles, title)
        report(end, 'computed titles')

    proba_abstract = None
    if abstract is not None:
        start, end = (40, 70) if both else (20, 60)
        report(start, 'computing abstracts')
        proba_abstract = _predict_proba(model_abstracts, abstract)
        report(end, 'computed abstracts')

    if title is None:
        proba = proba_abstract
    elif abstract is None:
        proba = proba_title
    else:
        # Fixed blend weighting the abstract model far more heavily.
        proba = proba_title * 0.1 + proba_abstract * 0.9

    start, end = (80, 90) if both else (70, 90)
    report(start, 'computed proba')
    sorted_proba, indices = torch.sort(proba, descending=True)
    report(end, 'sorted proba')

    # Grow the top-k prefix until its summed probability reaches 0.95
    # or every category has been taken.
    to_take = 1
    while sorted_proba[:to_take].sum() < 0.95 and to_take < len(categories_from_model):
        to_take += 1

    output = [
        (cat_with_names[categories_from_model[int(index)]], proba[index].item())
        for index in indices[:to_take]
    ]
    report(100, 'generated output')
    return output
# Front-end part: page header, then per-session defaults.
st.markdown("<h1 style='text-align: center;'>Classify your paper!</h1>", unsafe_allow_html=True)

# Seed st.session_state only when a key is missing (first run of a session);
# later reruns keep whatever the widgets and callbacks stored there.
for _state_key, _initial_value in (
    ("title", ""),
    ("abstract", ""),
    ("title_input_key", ""),
    ("abstract_input_key", ""),
    ("model_type", []),
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _initial_value
def input_error():
    """Return a validation message for the current form state, or '' if valid.

    All inputs are read from ``st.session_state``, so the result does not depend
    on module globals defined later in the script. It is therefore consistent
    whether it is called from the script body or from a button callback.
    """
    selected = st.session_state.model_type
    if not selected:
        return 'you have to select title or abstract'
    if 'Title' in selected and not st.session_state.title:
        return 'Title is empty'
    if 'Abstract' in selected and not st.session_state.abstract:
        return 'Abstract is empty'
    return ''
def clear_input():
    """Submit-button callback: stash the entered text, then reset the used inputs.

    Streamlit runs ``on_click`` callbacks before re-executing the script, so the
    values are read from the widgets' own session-state keys, which already hold
    the latest input. They are copied unchanged (no case conversion) into
    ``st.session_state.title`` / ``.abstract`` for the submit handler. The text
    boxes of the selected fields are cleared only when the form is valid, so
    the user keeps their text after a validation error.
    """
    st.session_state.title = st.session_state.title_input_key
    st.session_state.abstract = st.session_state.abstract_input_key
    if input_error():
        return
    # Writing a widget's key from inside a callback resets that widget.
    if "Title" in st.session_state.model_type:
        st.session_state.title_input_key = ""
    if "Abstract" in st.session_state.model_type:
        st.session_state.abstract_input_key = ""
# Input widgets. The text boxes are bound to session-state keys so that the
# Submit callback can reset them; the selected sources are mirrored into
# st.session_state.model_type for validation.
title = st.text_input(label=r"$\textsf{\Large Title}$", key="title_input_key")
abstract = st.text_input(label=r"$\textsf{\Large Abstract}$", key="abstract_input_key")
st.session_state.model_type = model_type = st.multiselect(
    label=r"$\textsf{\large Classify by:}$",
    options=['Title', 'Abstract'],
)
# Submit handler. The on_click callback (clear_input) runs first and stores the
# text to classify in st.session_state.title / st.session_state.abstract.
if st.button('Submit', on_click=clear_input):
    error_message = input_error()
    if error_message:
        st.error(error_message)
    else:
        # Only pass the fields the user chose to classify by.
        model_input = {}
        if 'Title' in st.session_state.model_type:
            model_input['title'] = st.session_state.title
        if 'Abstract' in st.session_state.model_type:
            model_input['abstract'] = st.session_state.abstract

        my_bar = st.progress(0, text='starting model')
        model_result = categorize_text(**model_input, progress_bar=my_bar)

        st.markdown("<h1 style='text-align: center;'>Classification completed!</h1>", unsafe_allow_html=True)
        # categorize_text returns (name, probability) pairs, most probable first.
        # The top category is shown largest. Other categories at or above 10% get
        # their own line, and the ones below 10% are grouped onto one line.
        cat, proba = model_result[0]
        st.write(r"$\textsf{\Large " + f'{cat}: {round(100*proba)}' + r"\%}$")
        small_categories = []
        for cat, proba in model_result[1:]:
            if proba < 0.1:
                small_categories.append(f'{cat}: {round(100*proba, 1)}' + r"\%")
            else:
                st.write(r"$\textsf{\large " + f'{cat}: {round(100*proba)}' + r"\%}$")
        if small_categories:
            st.write(', '.join(small_categories))