Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import time | |
| # model part | |
| import json | |
| import torch | |
| from torch import nn | |
| from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
# Category metadata loaded once at import time.
# - categories_from_model: indexed by the classifier's output position; its
#   length is also used as the models' num_labels.
# - cat_with_names: looked up with the entries of categories_from_model to get
#   the label shown to the user.
# JSON text is UTF-8 by specification, so the encoding is pinned rather than
# depending on the platform's locale default (e.g. cp1252 on Windows).
with open('categories_with_names.json', 'r', encoding='utf-8') as f:
    cat_with_names = json.load(f)
with open('categories_from_model.json', 'r', encoding='utf-8') as f:
    categories_from_model = json.load(f)
@st.cache_resource
def load_models_and_tokenizer():
    """Load the shared tokenizer and the title/abstract classifiers.

    Streamlit re-executes this whole script on every widget interaction.
    ``st.cache_resource`` keeps the loaded objects in memory across reruns
    and sessions, so they are not reloaded from disk each time.

    Returns:
        Tuple ``(model_titles, model_abstracts, tokenizer)``. Both models are
        multi-label sequence classifiers with one output per entry of
        ``categories_from_model`` and are switched to eval mode.
    """
    tokenizer = AutoTokenizer.from_pretrained("oracat/bert-paper-classifier-arxiv")
    num_labels = len(categories_from_model)

    def load_classifier(checkpoint):
        # Both checkpoints share the same label space and problem type.
        model = AutoModelForSequenceClassification.from_pretrained(
            checkpoint,
            num_labels=num_labels,
            problem_type="multi_label_classification",
        )
        model.eval()
        return model

    model_titles = load_classifier("powerful_model_titles/checkpoint-13472")
    model_abstracts = load_classifier("powerful_model_abstracts/checkpoint-13472")
    return model_titles, model_abstracts, tokenizer


model_titles, model_abstracts, tokenizer = load_models_and_tokenizer()
def _predict_proba(model, text):
    """Return ``model``'s per-category sigmoid probabilities for one text.

    The input is truncated to the tokenizer's maximum length. Without this,
    text longer than the model's maximum sequence length (presumably 512
    tokens for this BERT-based model; confirm) makes the forward pass fail.
    """
    input_tok = tokenizer(text, return_tensors='pt', truncation=True)
    with torch.no_grad():
        logits = model(**input_tok)['logits']
    # Multi-label head: an independent sigmoid per category, first batch row.
    return torch.sigmoid(logits)[0]


def categorize_text(title: str | None = None, abstract: str | None = None, progress_bar = None):
    """Classify a paper by its title and/or abstract.

    Runs the title model and/or the abstract model. When both inputs are
    given, their per-category probabilities are blended (10% title,
    90% abstract). The most probable categories are then returned until
    their probabilities sum to at least 0.95 (always at least one category,
    at most all of them).

    Args:
        title: Paper title, or None to skip the title model.
        abstract: Paper abstract, or None to skip the abstract model.
        progress_bar: Optional object with a ``progress(value, text=...)``
            method, such as the value returned by ``st.progress``. It
            receives progress updates; nothing is reported when it is None.

    Returns:
        List of ``(category_name, probability)`` tuples, most probable first.

    Raises:
        ValueError: If both ``title`` and ``abstract`` are None.
    """
    if title is None and abstract is None:
        raise ValueError('title is None and abstract is None')

    def report(value, text):
        # Progress reporting is optional and goes only to the caller's bar.
        if progress_bar is not None:
            progress_bar.progress(value, text=text)

    both_inputs = title is not None and abstract is not None

    proba_title = None
    if title is not None:
        start, end = (10, 30) if both_inputs else (20, 60)
        report(start, 'computing titles')
        proba_title = _predict_proba(model_titles, title)
        report(end, 'computed titles')

    proba_abstract = None
    if abstract is not None:
        start, end = (40, 70) if both_inputs else (20, 60)
        report(start, 'computing abstracts')
        proba_abstract = _predict_proba(model_abstracts, abstract)
        report(end, 'computed abstracts')

    if proba_title is None:
        proba = proba_abstract
    elif proba_abstract is None:
        proba = proba_title
    else:
        # Fixed blend weights: 10% title, 90% abstract.
        proba = proba_title * 0.1 + proba_abstract * 0.9

    combine_pct, sort_pct = (80, 90) if both_inputs else (70, 90)
    report(combine_pct, 'computed proba')
    sorted_proba, indices = torch.sort(proba, descending=True)
    report(sort_pct, 'sorted proba')

    # Smallest prefix of the ranking whose probabilities sum to >= 0.95.
    to_take = 1
    while sorted_proba[:to_take].sum() < 0.95 and to_take < len(categories_from_model):
        to_take += 1

    output = [
        (cat_with_names[categories_from_model[index]], p)
        for index, p in zip(indices[:to_take].tolist(), sorted_proba[:to_take].tolist())
    ]
    report(100, 'generated output')
    return output
# front part
st.markdown("<h1 style='text-align: center;'>Classify your paper!</h1>", unsafe_allow_html=True)

# Seed the per-session state on the first run only; later reruns keep the
# values already stored. The *_input_key entries back the text-input widgets.
for _state_key, _initial_value in (
    ("title", ""),
    ("abstract", ""),
    ("title_input_key", ""),
    ("abstract_input_key", ""),
    ("model_type", []),
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _initial_value
def input_error():
    """Return a user-facing validation message for the current inputs.

    Everything is read from ``st.session_state``, so the check gives the same
    answer in the submit callback and in the script body. Whitespace-only
    text counts as empty.

    Returns:
        An error message, or ``''`` when the inputs are valid.
    """
    selected = st.session_state.model_type
    if not selected:
        return 'you have to select title or abstract'
    if 'Title' in selected and not st.session_state.title.strip():
        return 'Title is empty'
    if 'Abstract' in selected and not st.session_state.abstract.strip():
        return 'Abstract is empty'
    return ''
def clear_input():
    """Submit-button callback: snapshot the text inputs, then clear them.

    Streamlit runs ``on_click`` callbacks before the script body reruns, so
    the current widget values are read from ``st.session_state`` under the
    widgets' keys. The text is copied verbatim; the old code applied
    ``str.title()``, which title-cased every word of the user's input.
    If the submission is valid, the inputs that were used are emptied,
    ready for the next paper.
    """
    st.session_state.title = st.session_state.title_input_key
    st.session_state.abstract = st.session_state.abstract_input_key
    if not input_error():
        if "Title" in st.session_state.model_type:
            st.session_state.title_input_key = ""
        if "Abstract" in st.session_state.model_type:
            st.session_state.abstract_input_key = ""
# Input widgets; their values are mirrored into st.session_state via the keys.
title = st.text_input(r"$\textsf{\Large Title}$", key="title_input_key")
abstract = st.text_input(r"$\textsf{\Large Abstract}$", key="abstract_input_key")
model_type = st.multiselect(
    r"$\textsf{\large Classify by:}$",
    ['Title', 'Abstract'],
)
st.session_state.model_type = model_type

if st.button('Submit', on_click=clear_input):
    error = input_error()
    if error:
        st.error(error)
    else:
        # Pass only the inputs the user chose to classify by.
        model_input = {}
        if 'Title' in st.session_state.model_type:
            model_input['title'] = st.session_state.title
        if 'Abstract' in st.session_state.model_type:
            model_input['abstract'] = st.session_state.abstract
        my_bar = st.progress(0, text='starting model')
        model_result = categorize_text(**model_input, progress_bar=my_bar)
        st.markdown("<h1 style='text-align: center;'>Classification completed!</h1>", unsafe_allow_html=True)

        # The top category is always shown large. Other categories are shown
        # at a smaller size, and those under 10% are collapsed onto one line.
        (top_cat, top_proba), *other_results = model_result
        st.write(r"$\textsf{\Large " + f'{top_cat}: {round(100*top_proba)}' + r"\%}$")
        small_categories = []
        for cat, proba in other_results:
            if proba < 0.1:
                small_categories.append(f'{cat}: {round(100*proba, 1)}' + r"\%")
            else:
                st.write(r"$\textsf{\large " + f'{cat}: {round(100*proba)}' + r"\%}$")
        if small_categories:
            st.write(', '.join(small_categories))