blunt-octopus committed on
Commit
e187312
1 Parent(s): e21e093

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +79 -1
app.py CHANGED
@@ -1,3 +1,81 @@
1
  import streamlit as st
2
 
3
- st.markdown('### Hello world')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
 
3
+ from transformers import AutoTokenizer, DistilBertForSequenceClassification
4
+ import torch
5
+ from torch.nn.functional import softmax
6
+
7
+ base_model_name = 'distilbert-base-uncased'
8
+
9
@st.cache
def load_tags_info():
    """Parse tags.txt into the three tag lookup tables.

    Each line of tags.txt is expected to look like "<tag> <description>":
    the tag up to the first space, the human-readable description after it.

    Returns:
        (tag_to_id, id_to_tag, id_to_description): dicts mapping the tag
        string to its 0-based line index, the index back to the tag, and
        the index to the description.
    """
    tag_to_id = {}
    id_to_tag = {}
    id_to_description = {}
    with open('tags.txt', 'r') as file:
        for i, line in enumerate(file):
            # rstrip('\n') instead of slicing [:-1]: the old slice silently
            # dropped the last character of a final line that has no trailing
            # newline. partition() splits on the first space only, and leaves
            # the description empty (rather than mangled) if no space exists.
            tag, _, description = line.rstrip('\n').partition(' ')
            tag_to_id[tag] = i
            id_to_tag[i] = tag
            id_to_description[i] = description
    return (tag_to_id, id_to_tag, id_to_description)
30
+
31
# Build the tag lookup tables once at import time (cached by Streamlit).
tag_to_id, id_to_tag, id_to_description = load_tags_info()
32
+
33
@st.cache
def load_model():
    """Load the fine-tuned DistilBERT classifier saved in the app directory.

    Falls back to CPU when no GPU is present: the original hard-coded
    'cuda', which crashes outright on CUDA-less hosts.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    return DistilBertForSequenceClassification.from_pretrained('./').to(device)
36
+
37
@st.cache
def load_tokenizer():
    """Load the base-model tokenizer.

    Cached with @st.cache for consistency with load_model / load_tags_info,
    so Streamlit does not rebuild the tokenizer on every rerun.
    """
    return AutoTokenizer.from_pretrained(base_model_name)
39
+
40
def top_xx(preds, xx=95):
    """Return the smallest prefix of top-probability tags covering xx percent.

    Args:
        preds: softmax probabilities of shape (1, num_tags).
        xx: coverage threshold in percent (0-100).

    Returns:
        List of {'tag': ..., 'description': ...} dicts, ordered by
        descending probability, whose probabilities sum to at least xx/100
        (the element that crosses the threshold is included).
    """
    tops = torch.argsort(preds, 1, descending=True)
    num_tags = preds.shape[1]
    total = 0.0
    result = []
    # Bounded for-loop instead of the old unbounded while: with xx=100 (the
    # slider's maximum), float rounding can keep the running sum below 1.0
    # after every class is consumed, which made the old loop index past the
    # end of `tops` and raise IndexError.
    for index in range(num_tags):
        if total >= xx / 100:
            break
        next_id = tops[0, index].item()
        total += preds[0, next_id].item()  # .item(): accumulate a float, not a 0-d tensor
        result.append({'tag': id_to_tag[next_id],
                       'description': id_to_description[next_id]})
    return result
51
+
52
model = load_model()
tokenizer = load_tokenizer()
temperature = 1/2  # <1 sharpens the softmax so top tags stand out more

st.title('ArXivTaxonomizer&copy; (original version)')
st.caption('If you are aware of any other public services which are illegally providing the ArXivTaxonomizer&copy; functionality, please consider informing us.')

with st.form("Taxonomizer"):

    title = st.text_area(label='Title', height=30)
    abstract = st.text_area(label='Abstract (optional)', height=200)
    xx = st.slider(label='Verbosity', min_value=0, max_value=100, value=95)
    st.caption('Lower values will generate a few best guesses. Higher values will lead to a comprehensive list of topics that our model considers relevant. \nEmpirically, values around 70 work best and generate a list of 3-5 guesses.')

    submitted = st.form_submit_button("Taxonomize")
    st.caption('We **do not** recommend using ArXivTaxonomizer&copy; to choose tags for your new paper.')
    if submitted:
        prompt = 'Title: ' + title + ' Abstract: ' + abstract
        tokens = tokenizer(prompt, truncation=True, padding='max_length', return_tensors='pt')['input_ids']
        # Send the inputs to wherever the model actually lives instead of
        # hard-coding 'cuda'; no_grad avoids building an unused autograd
        # graph during pure inference.
        with torch.no_grad():
            logits = model(tokens.reshape(1, -1).to(model.device)).logits
        preds = softmax(logits / temperature, dim=1)
        tags = top_xx(preds, xx)
        st.header('Inferred tags:')
        # First three guesses rendered prominently; the rest as captions.
        for i, tag_data in enumerate(tags):
            bullet = '* ' + tag_data['tag'] + ' (' + tag_data['description'] + ')'
            if i < 3:
                st.markdown(bullet)
                if i == 2:
                    st.subheader('Other possible tags:')
            else:
                st.caption(bullet)