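# Streamlit demo: suggest arXiv subject tags for a paper from its title and an
# optional abstract, using a fine-tuned DistilBERT sequence classifier.
# The classifier weights are expected in the app directory and the tag list in
# tags.txt (one "<tag> <description>" pair per line).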
import streamlit as st
import torch
from torch.nn.functional import softmax
from transformers import AutoTokenizer, DistilBertForSequenceClassification


base_model_name = 'distilbert-base-uncased'

@st.cache
def load_tags_info():
    # Build tag <-> id lookup tables from tags.txt, where each line is
    # expected to hold "<tag> <description>".
    tag_to_id = {}
    id_to_tag = {}
    id_to_description = {}
    with open('tags.txt', 'r') as file:
        for i, line in enumerate(file):
            space = line.find(' ')

            tag = line[:space]
            # Strip only the trailing newline so a last line without one
            # keeps its full description.
            description = line[space+1:].rstrip('\n')

            tag_to_id[tag] = i
            id_to_tag[i] = tag
            id_to_description[i] = description
    
    return (tag_to_id, id_to_tag, id_to_description)

tag_to_id, id_to_tag, id_to_description = load_tags_info()

# allow_output_mutation keeps Streamlit from re-hashing the heavy model and
# tokenizer objects on every rerun.
@st.cache(allow_output_mutation=True)
def load_model():
    # Fine-tuned classification head; weights are expected in the app directory.
    return DistilBertForSequenceClassification.from_pretrained('./')

@st.cache(allow_output_mutation=True)
def load_tokenizer():
    return AutoTokenizer.from_pretrained(base_model_name)

def top_xx(preds, xx=95):
    # Return the highest-probability tags whose cumulative probability mass
    # reaches xx percent of the predicted distribution.
    tops = torch.argsort(preds, 1, descending=True)
    total = 0.0
    index = 0
    result = []
    while total < xx / 100 and index < tops.size(1):
        next_id = tops[0, index].item()
        total += preds[0, next_id].item()
        index += 1
        result.append({'tag': id_to_tag[next_id], 'description': id_to_description[next_id]})
    return result

model = load_model()
tokenizer = load_tokenizer()
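# Temperature below 1 sharpens the softmax so the top predictions stand out.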
temperature = 1/2

st.title('ArXivTaxonomizer&copy; (original version)')
st.caption('If you are aware of any other public services which are illegally providing the ArXivTaxonomizer&copy; functionality, please consider informing us.')

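# Input form: a title plus an optional abstract; the slider sets how much of
# the predicted probability mass the returned tag list should cover.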
with st.form("Taxonomizer"):
    
    title = st.text_area(label='Title', height=30)
    abstract = st.text_area(label='Abstract (optional)', height=200)
    xx = st.slider(label='Verbosity', min_value=1, max_value=99, value=95)
    st.caption('Lower values will generate a few best guesses. Higher values will lead to a comprehensive list of topics that our model considers relevant. \nEmpirically, values around 95 work best and generate a list of 3-5 guesses.')

    submitted = st.form_submit_button("Taxonomize")
    st.caption('We **do not** recommend using ArXivTaxonomizer&copy; to choose tags for your new paper.')
    if submitted:
        prompt = 'Title: ' + title + ' Abstract: ' + abstract
        encoding = tokenizer(prompt, truncation=True, padding='max_length', return_tensors='pt')
        # Pass the attention mask so padding tokens are ignored, and skip
        # gradient tracking since this is inference only.
        with torch.no_grad():
            logits = model(input_ids=encoding['input_ids'],
                           attention_mask=encoding['attention_mask']).logits
        preds = softmax(logits / temperature, dim=1)
        tags = top_xx(preds, xx)
        st.header('Inferred tags:')
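        # Show the first three tags prominently; any remaining ones go under
        # "Other possible tags".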
        for i, tag_data in enumerate(tags):
            if i < 3:
                st.markdown('* ' + tag_data['tag'] + ' (' + tag_data['description'] + ')')
            else:
                if i == 3:
                    st.subheader('Other possible tags:')
                st.caption('* ' + tag_data['tag'] + ' (' + tag_data['description'] + ')')