# File size: 5,845 Bytes
# c713558
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
import streamlit as st
import time

# model part

import json
import torch
from torch import nn
from transformers import AutoTokenizer, AutoModelForSequenceClassification

def _read_json(path):
    """Parse and return the JSON document stored at *path*.

    The file is decoded as UTF-8 explicitly. Without ``encoding`` Python uses
    the locale's preferred encoding, so the same file could decode differently
    (or fail) depending on the machine the app runs on.
    """
    with open(path, 'r', encoding='utf-8') as f:
        return json.load(f)


# Looked up with an entry of ``categories_from_model`` to get the name shown
# to the user.
cat_with_names = _read_json('categories_with_names.json')
# Indexed by model output position; its length is also used as ``num_labels``
# when the classifiers are loaded.
categories_from_model = _read_json('categories_from_model.json')

@st.cache_resource
def load_models_and_tokenizer():
    """Load the shared tokenizer and the two sequence classifiers.

    Both classifiers are configured for multi-label classification with one
    label per entry of ``categories_from_model`` and are switched to eval
    mode before being returned.

    Returns:
        The tuple ``(model_titles, model_abstracts, tokenizer)``.
    """
    tokenizer = AutoTokenizer.from_pretrained("oracat/bert-paper-classifier-arxiv")

    def load_classifier(checkpoint):
        # Same configuration for both checkpoints; only the path differs.
        classifier = AutoModelForSequenceClassification.from_pretrained(
            checkpoint,
            num_labels=len(categories_from_model),
            problem_type="multi_label_classification",
        )
        classifier.eval()
        return classifier

    model_titles = load_classifier("powerful_model_titles/checkpoint-13472")
    model_abstracts = load_classifier("powerful_model_abstracts/checkpoint-13472")

    return model_titles, model_abstracts, tokenizer

# Module-level handles used by categorize_text(); the loader is wrapped in
# st.cache_resource (see its definition).
model_titles, model_abstracts, tokenizer = load_models_and_tokenizer()

def categorize_text(title: str | None = None, abstract: str | None = None, progress_bar = None):
    """Score a paper's categories from its title and/or its abstract.

    Each supplied text is run through its own classifier (``model_titles`` for
    the title, ``model_abstracts`` for the abstract) and the logits are turned
    into per-category probabilities with a sigmoid. When both texts are given
    the two probability vectors are blended as ``0.1 * title + 0.9 * abstract``.

    Args:
        title: Paper title, or ``None`` to skip the title model.
        abstract: Paper abstract, or ``None`` to skip the abstract model.
        progress_bar: Optional object exposing ``progress(percent, text=...)``
            (for example the value returned by ``st.progress``). When ``None``
            no progress is reported.

    Returns:
        A list of ``(category name, probability)`` tuples in descending
        probability order. It is the shortest prefix of the ranking whose
        probabilities sum to at least 0.95, or every category when that sum
        is never reached.

    Raises:
        ValueError: If both ``title`` and ``abstract`` are ``None``.
    """
    if title is None and abstract is None:
        raise ValueError('title is None and abstract is None')

    def report(percent, text):
        # Drive the bar passed in by the caller rather than a module-level
        # global, and stay silent when the caller did not provide one.
        if progress_bar is not None:
            progress_bar.progress(percent, text=text)

    def predict(model, text):
        # truncation=True caps the input at the tokenizer's configured maximum
        # length. NOTE(review): presumably the checkpoints have a fixed input
        # limit that long abstracts could otherwise exceed - confirm.
        input_tok = tokenizer(text, return_tensors='pt', truncation=True)
        with torch.no_grad():
            logits = model(**input_tok)['logits']
            return torch.sigmoid(logits)[0]

    # The progress milestones depend on whether one or two models are run.
    both = title is not None and abstract is not None

    proba_title = None
    if title is not None:
        start, done = (10, 30) if both else (20, 60)
        report(start, 'computing titles')
        proba_title = predict(model_titles, title)
        report(done, 'computed titles')

    proba_abstract = None
    if abstract is not None:
        start, done = (40, 70) if both else (20, 60)
        report(start, 'computing abstracts')
        proba_abstract = predict(model_abstracts, abstract)
        # Report the end milestone; reusing the start value here would leave
        # the bar stuck while the text claims the step has finished.
        report(done, 'computed abstracts')

    if title is None:
        proba = proba_abstract
    elif abstract is None:
        proba = proba_title
    else:
        proba = proba_title * 0.1 + proba_abstract * 0.9

    report(80 if both else 70, 'computed proba')

    sorted_proba, indices = torch.sort(proba, descending=True)
    report(90, 'sorted proba')

    # Grow the prefix until it accumulates 0.95 of probability or the
    # categories run out.
    to_take = 1
    while sorted_proba[:to_take].sum() < 0.95 and to_take < len(categories_from_model):
        to_take += 1

    output = [(cat_with_names[categories_from_model[index]], proba[index].item())
              for index in indices[:to_take].tolist()]
    report(100, 'generated output')
    return output

# front part

st.markdown("<h1 style='text-align: center;'>Classify your paper!</h1>", unsafe_allow_html=True)

# Seed every session-state slot the page relies on, but only when it is
# missing, so values already stored in the session are left untouched.
for _state_key, _make_default in (
    ("title", str),
    ("abstract", str),
    ("title_input_key", str),
    ("abstract_input_key", str),
    ("model_type", list),
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _make_default()

def input_error():
    """Describe what is wrong with the submitted input.

    All values are read from ``st.session_state`` (``model_type``, ``title``
    and ``abstract``) so the result does not depend on module-level variables
    that are only assigned further down the script.

    Returns:
        A human-readable error message, or ``''`` when the input is valid.
    """
    selected = st.session_state.model_type
    if not selected:
        return 'you have to select title or abstract'
    if 'Title' in selected and not st.session_state.title:
        return 'Title is empty'
    if 'Abstract' in selected and not st.session_state.abstract:
        return 'Abstract is empty'
    return ''


def clear_input():
    """Snapshot the text inputs on submit and clear the ones that were used.

    Registered as the ``on_click`` callback of the Submit button. The current
    widget values are copied into ``st.session_state.title`` and
    ``st.session_state.abstract``; if the input is valid, the text boxes for
    the selected model types are then emptied.
    """
    # Read the widget values through their session-state keys and store them
    # unchanged: str.title() would re-capitalise every word of the user's
    # text before it reaches the classifier.
    st.session_state.title = st.session_state.title_input_key
    st.session_state.abstract = st.session_state.abstract_input_key

    # Keep the text in place when validation fails so the user can fix it.
    if input_error():
        return

    if "Title" in st.session_state.model_type:
        st.session_state.title_input_key = ""
    if "Abstract" in st.session_state.model_type:
        st.session_state.abstract_input_key = ""

# Free-text inputs; each widget is bound to the session-state key it names.
title = st.text_input(label=r"$\textsf{\Large Title}$", key="title_input_key")
abstract = st.text_input(label=r"$\textsf{\Large Abstract}$", key="abstract_input_key")

# Which text(s) the classification should be based on.
model_type = st.multiselect(
    label=r"$\textsf{\large Classify by:}$",
    options=['Title', 'Abstract'],
)

# Mirror the selection into session state, where the validation and submit
# logic read it.
st.session_state.model_type = model_type

# Submit flow: clear_input() runs as the button callback and snapshots the
# text into session state; this branch then validates it, runs the models
# and renders the ranked categories.
if st.button('Submit', on_click=clear_input):
    # Validate once and reuse the message instead of calling input_error()
    # a second time just to display it.
    error_message = input_error()
    if error_message:
        st.error(error_message)
    else:
        # Pass only the texts the user asked to classify by.
        model_input = {}
        if 'Title' in st.session_state.model_type:
            model_input['title'] = st.session_state.title
        if 'Abstract' in st.session_state.model_type:
            model_input['abstract'] = st.session_state.abstract

        my_bar = st.progress(0, text='starting model')
        model_result = categorize_text(**model_input, progress_bar=my_bar)
        st.markdown("<h1 style='text-align: center;'>Classification completed!</h1>", unsafe_allow_html=True)

        # The top-ranked category is always shown in the largest font.
        cat, proba = model_result[0]
        st.write(r"$\textsf{\Large " + f'{cat}: {round(100*proba)}' + r"\%}$")

        # Remaining categories: those under 10% are gathered into a single
        # comma-separated line, the rest get a line of their own.
        small_categories = []
        for cat, proba in model_result[1:]:
            if proba < 0.1:
                small_categories.append(f'{cat}: {round(100*proba, 1)}' + r"\%")
            else:
                st.write(r"$\textsf{\large " + f'{cat}: {round(100*proba)}' + r"\%}$")
        if small_categories:
            st.write(', '.join(small_categories))