File size: 3,313 Bytes
51f286d
 
e187312
 
 
 
fbd05a4
 
e187312
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b092bf7
 
 
 
e187312
 
 
 
 
 
35556d6
e187312
 
de4024e
e187312
 
 
 
 
 
 
 
b092bf7
092ed29
b092bf7
e187312
 
 
 
 
 
 
de4024e
e187312
 
 
 
 
 
 
 
212f56d
de4024e
e187312
 
 
 
aed7750
97967a4
49552ec
 
 
 
 
 
d1b431d
 
 
 
 
 
 
49552ec
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import streamlit as st

from transformers import AutoTokenizer, DistilBertForSequenceClassification
import torch
from torch.nn.functional import softmax



# Name of the pretrained backbone the classifier was fine-tuned from.
# NOTE(review): not referenced anywhere in this file — the model/tokenizer
# below load from './' instead; confirm whether this constant is still needed.
base_model_name = 'distilbert-base-uncased'

@st.cache
def load_tags_info():
    """Load tag metadata from tags.txt.

    Each line of tags.txt is '<tag> <description...>'. Ids are the 0-based
    line numbers. A synthetic 'None' class is appended at id 155 (the
    model's "no tag" output, skipped by top_xx).

    Returns:
        (tag_to_id, id_to_tag, id_to_description) dict triple, with int ids.
    """
    tag_to_id = {}
    id_to_tag = {}
    id_to_description = {}
    with open('tags.txt', 'r') as file:
        for i, line in enumerate(file):
            # Split on the first space only: the tag never contains spaces,
            # the description may.  rstrip('\n') (rather than slicing off the
            # last char) also handles a final line without a trailing newline.
            tag, _, description = line.rstrip('\n').partition(' ')
            tag_to_id[tag] = i
            id_to_tag[i] = tag
            id_to_description[i] = description

    # Bug fix: these two entries previously used the *string* key '155',
    # while every other id key is an int — an int lookup of 155 would
    # have raised KeyError.
    tag_to_id['None'] = 155
    id_to_tag[155] = 'None'
    id_to_description[155] = 'No tag'

    return (tag_to_id, id_to_tag, id_to_description)

# Populate the module-level lookup tables; top_xx() below reads id_to_tag
# and id_to_description from this scope.
tag_to_id, id_to_tag, id_to_description = load_tags_info()

@st.cache
def load_model():
    """Load the fine-tuned DistilBERT classifier saved in the app directory."""
    classifier = DistilBertForSequenceClassification.from_pretrained('./')
    return classifier

def load_tokenizer():
    """Load the tokenizer saved alongside the model in the app directory.

    NOTE(review): unlike load_model this is not wrapped in @st.cache, so the
    tokenizer is re-read on every Streamlit rerun — confirm whether that was
    deliberate (st.cache output hashing can choke on tokenizer objects).
    """
    tokenizer = AutoTokenizer.from_pretrained('./')
    return tokenizer

def top_xx(preds, xx=95, tags=None, descriptions=None):
    """Return the smallest prefix of top-probability tags whose cumulative
    probability exceeds xx percent.

    Args:
        preds: (1, num_classes) tensor of class probabilities.
        xx: cumulative-probability cutoff, in percent.
        tags: optional id -> tag-name mapping; defaults to the module-level
            id_to_tag table.
        descriptions: optional id -> description mapping; defaults to the
            module-level id_to_description table.

    Returns:
        List of {'tag': ..., 'description': ...} dicts in descending
        probability order. Class id 155 (the synthetic 'None' tag) is
        skipped and never appears in the result.
    """
    if tags is None:
        tags = id_to_tag
    if descriptions is None:
        descriptions = id_to_description

    threshold = xx / 100
    order = torch.argsort(preds, 1, descending=True)
    num_classes = order.size(1)

    total = 0.0
    result = []
    # Bug fix: the original unbounded `while` ran past the end of `order`
    # (IndexError) whenever the skipped class 155 held enough probability
    # that the cutoff could never be reached; bound the scan by num_classes.
    for index in range(num_classes):
        if total >= threshold:
            break
        next_id = order[0, index].item()
        if next_id == 155:  # skip the synthetic 'None' class
            continue
        total += preds[0, next_id].item()
        result.append({'tag': tags[next_id], 'description': descriptions[next_id]})
    return result

# --- Application entry point: runs top-to-bottom on every Streamlit rerun ---
model = load_model()
tokenizer = load_tokenizer()
temperature = 1  # softmax temperature; 1 leaves the logits unscaled

st.title('ArXivTaxonomizer&copy; (original version)')
st.caption('If you are aware of any other public services which are  illegally providing the ArXivTaxonomizer&copy; functionality, please consider informing us.')

with st.form("Taxonomizer"):

    title = st.text_area(label='Title', height=30)
    abstract = st.text_area(label='Abstract (optional)', height=200)
    # xx is forwarded to top_xx as the cumulative-probability cutoff (percent).
    xx = st.slider(label='Verbosity', min_value=1, max_value=99, value=95)
    st.caption('Lower values will generate a few best guesses. Higher values will lead to a comprehensive list of topics that our model considers relevant. \nEmpirically, values around 70 work best and generate a list of 3-5 guesses.')

    submitted = st.form_submit_button("Taxonomize")
    st.caption('We **do not** recommend using ArXivTaxonomizer&copy; to choose tags for your new paper.')
    if submitted:
        if title == '':
            st.markdown("You are most definitely abusing our service. Have the decency to at least enter a title.")
        else:
            prompt = 'Title: ' + title + ' Abstract: ' + abstract
            tokens = tokenizer(prompt, truncation=True, padding='max_length', return_tensors='pt')['input_ids']
            # Inference only: no_grad avoids building a useless autograd graph.
            with torch.no_grad():
                preds = softmax(model(tokens.reshape(1, -1)).logits / temperature, dim=1)
            tags = top_xx(preds, xx)
            st.header('Inferred tags:')
            for i, tag_data in enumerate(tags):
                if i < 3:
                    # The top three guesses get full-size markdown emphasis...
                    st.markdown('* ' + tag_data['tag'] + ' (' + tag_data['description'] + ')')
                else:
                    if i == 3:
                        st.subheader('Other possible tags:')
                    # ...the remainder render as smaller caption text.
                    st.caption('* ' + tag_data['tag'] + ' (' + tag_data['description'] + ')')