Spaces:
Sleeping
Sleeping
File size: 5,845 Bytes
c713558 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 |
import streamlit as st
import time
# model part
import json
import torch
from torch import nn
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# Lookup tables read from JSON files in the working directory.
# `categories_from_model` is indexed by model output position and its entries
# are used as keys into `cat_with_names` (see categorize_text).
with open('categories_with_names.json', 'r') as names_file:
    cat_with_names = json.load(names_file)

with open('categories_from_model.json', 'r') as labels_file:
    categories_from_model = json.load(labels_file)
@st.cache_resource
def load_models_and_tokenizer():
    """Load the tokenizer and both sequence-classification checkpoints.

    The tokenizer comes from ``oracat/bert-paper-classifier-arxiv``. Each
    model is loaded from its local checkpoint directory with one output
    label per entry of ``categories_from_model`` and
    ``problem_type="multi_label_classification"``, then put in eval mode.

    Returns:
        tuple: ``(model_titles, model_abstracts, tokenizer)``.
    """

    def _load_classifier(checkpoint_path):
        # Same loading recipe for both checkpoints; only the path differs.
        classifier = AutoModelForSequenceClassification.from_pretrained(
            checkpoint_path,
            num_labels=len(categories_from_model),
            problem_type="multi_label_classification",
        )
        classifier.eval()
        return classifier

    tokenizer = AutoTokenizer.from_pretrained("oracat/bert-paper-classifier-arxiv")
    model_titles = _load_classifier("powerful_model_titles/checkpoint-13472")
    model_abstracts = _load_classifier("powerful_model_abstracts/checkpoint-13472")
    return model_titles, model_abstracts, tokenizer
# Module-level handles to the two classifiers and the shared tokenizer;
# categorize_text() reads these names directly.
model_titles, model_abstracts, tokenizer = load_models_and_tokenizer()
def categorize_text(title: str | None = None, abstract: str | None = None, progress_bar = None):
if title is None and abstract is None:
raise ValueError('title is None and abstract is None')
models_to_run = 2 if (title is not None and abstract is not None) else 1
proba_title = None
if title is not None:
progresses = (10, 30) if models_to_run == 2 else (20, 60)
my_bar.progress(progresses[0], text='computing titles')
input_tok = tokenizer(title, return_tensors='pt')
with torch.no_grad():
logits = model_titles(**input_tok)['logits']
proba_title = torch.sigmoid(logits)[0]
my_bar.progress(progresses[1], text='computed titles')
proba_abstract = None
if abstract is not None:
progresses = (40, 70) if models_to_run == 2 else (20, 60)
my_bar.progress(progresses[0], text='computing abstracts')
input_tok = tokenizer(abstract, return_tensors='pt')
with torch.no_grad():
logits = model_abstracts(**input_tok)['logits']
proba_abstract = torch.sigmoid(logits)[0]
my_bar.progress(progresses[0], text='computed abstracts')
if title is None:
proba = proba_abstract
elif abstract is None:
proba = proba_title
else:
proba = proba_title * 0.1 + proba_abstract * 0.9
progresses = (80, 90) if models_to_run == 2 else (70, 90)
my_bar.progress(progresses[0], text='computed proba')
sorted_proba, indices = torch.sort(proba, descending=True)
my_bar.progress(progresses[1], text='sorted proba')
to_take = 1
while sorted_proba[:to_take].sum() < 0.95 and to_take < len(categories_from_model):
to_take += 1
output = [(cat_with_names[categories_from_model[index]], proba[index].item())
for index in indices[:to_take]]
my_bar.progress(100, text='generated output')
return output
# front part
st.markdown("<h1 style='text-align: center;'>Classify your paper!</h1>", unsafe_allow_html=True)

# Seed every session-state slot the page uses, leaving existing values alone.
# Each default is built by calling the factory (str() -> "", list() -> []).
for _state_key, _make_default in (
    ("title", str),
    ("abstract", str),
    ("title_input_key", str),
    ("abstract_input_key", str),
    ("model_type", list),
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _make_default()
def input_error():
    """Validate the current selection against the stored inputs.

    Everything is read from ``st.session_state`` so the check does not depend
    on module-level names that are only assigned further down the script.

    Returns:
        str: A user-facing error message, or ``''`` when the input is valid.
    """
    selected = st.session_state.model_type
    if not selected:
        return 'you have to select title or abstract'
    if 'Title' in selected and not st.session_state.title:
        return 'Title is empty'
    if 'Abstract' in selected and not st.session_state.abstract:
        return 'Abstract is empty'
    return ''
def clear_input():
    """``on_click`` callback for the Submit button.

    Copies the text-input values into ``st.session_state.title`` and
    ``st.session_state.abstract`` and, when the input validates, empties the
    text inputs for the selected fields.
    """
    # Read the widgets through their session-state keys instead of the
    # module-level `title` / `abstract` variables: those are assigned while
    # the script body runs, so inside a callback they may hold values from an
    # earlier run.
    # NOTE(review): str.title() title-cases the text before it is stored (and
    # later sent to the models) - confirm this is intended, especially for
    # abstracts.
    st.session_state.title = st.session_state.title_input_key.title()
    st.session_state.abstract = st.session_state.abstract_input_key.title()

    if input_error():
        # Keep what the user typed so they can correct it.
        return

    if "Title" in st.session_state.model_type:
        st.session_state.title_input_key = ""
    if "Abstract" in st.session_state.model_type:
        st.session_state.abstract_input_key = ""
# Input widgets. The two text inputs are bound to session-state keys; the
# multiselect has no key, so its value is copied into session state by hand.
title = st.text_input(
    label=r"$\textsf{\Large Title}$",
    key="title_input_key",
)
abstract = st.text_input(
    label=r"$\textsf{\Large Abstract}$",
    key="abstract_input_key",
)
model_type = st.multiselect(
    label=r"$\textsf{\large Classify by:}$",
    options=['Title', 'Abstract'],
)
st.session_state.model_type = model_type
# Submit handler: validate the input, run the selected model(s) and render
# the resulting categories.
if st.button('Submit', on_click=clear_input):
    # Validate once and reuse the message instead of calling input_error() twice.
    error_message = input_error()
    if error_message:
        st.error(error_message)
    else:
        # Only pass the fields the user asked to classify by.
        model_input = {}
        if 'Title' in st.session_state.model_type:
            model_input['title'] = st.session_state.title
        if 'Abstract' in st.session_state.model_type:
            model_input['abstract'] = st.session_state.abstract

        # The bar stays bound to the module-level name `my_bar` and is also
        # passed to the classifier explicitly.
        my_bar = st.progress(0, text='starting model')
        model_result = categorize_text(**model_input, progress_bar=my_bar)

        st.markdown("<h1 style='text-align: center;'>Classification completed!</h1>", unsafe_allow_html=True)

        def _percent_label(size, cat, proba):
            # LaTeX-styled "<category>: <rounded percent>%" line.
            return r"$\textsf{" + "\\" + size + " " + f'{cat}: {round(100*proba)}' + r"\%}$"

        # The first (highest-probability) category is shown in a larger font.
        top_cat, top_proba = model_result[0]
        st.write(_percent_label("Large", top_cat, top_proba))

        # Categories below 10% are collected and shown together on one line.
        small_categories = []
        for cat, proba in model_result[1:]:
            if proba < 0.1:
                small_categories.append(f'{cat}: {round(100*proba, 1)}' + r"\%")
            else:
                st.write(_percent_label("large", cat, proba))
        if small_categories:
            st.write(', '.join(small_categories))
|