Spaces:
Runtime error
Runtime error
Commit
·
e187312
1
Parent(s):
e21e093
Update app.py
Browse files
app.py
CHANGED
@@ -1,3 +1,81 @@
|
|
1 |
import streamlit as st
|
2 |
|
3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import streamlit as st
|
2 |
|
3 |
+
from transformers import AutoTokenizer, DistilBertForSequenceClassification
|
4 |
+
import torch
|
5 |
+
from torch.nn.functional import softmax
|
6 |
+
|
7 |
+
base_model_name = 'distilbert-base-uncased'
|
8 |
+
|
9 |
+
@st.cache
def load_tags_info():
    """Load arXiv tag metadata from tags.txt.

    Each line of tags.txt is expected to look like "<tag> <description>",
    where the description may itself contain spaces.

    Returns:
        (tag_to_id, id_to_tag, id_to_description): dicts mapping tag name
        to integer id, id back to tag name, and id to its description.
    """
    tag_to_id = {}
    id_to_tag = {}
    id_to_description = {}
    with open('tags.txt', 'r') as file:
        for i, line in enumerate(file):
            # Split on the first space only; rstrip('\n') (instead of the
            # old line[:-1]) avoids silently dropping the last character of
            # a final line that has no trailing newline.
            tag, _, description = line.rstrip('\n').partition(' ')
            tag_to_id[tag] = i
            id_to_tag[i] = tag
            id_to_description[i] = description
    return (tag_to_id, id_to_tag, id_to_description)
|
30 |
+
|
31 |
+
# Module-level lookup tables, built once at startup (memoized by @st.cache).
tag_to_id, id_to_tag, id_to_description = load_tags_info()
|
32 |
+
|
33 |
+
@st.cache
def load_model():
    """Load the fine-tuned DistilBERT classifier from the local checkpoint.

    Returns:
        A DistilBertForSequenceClassification moved to CUDA when available,
        otherwise left on CPU. (The original unconditionally called
        .to('cuda'), which crashes on CPU-only hosts.)
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    return DistilBertForSequenceClassification.from_pretrained('./').to(device)
|
36 |
+
|
37 |
+
def load_tokenizer():
    """Return the tokenizer matching the DistilBERT base checkpoint."""
    tokenizer = AutoTokenizer.from_pretrained(base_model_name)
    return tokenizer
|
39 |
+
|
40 |
+
def top_xx(preds, xx=95):
    """Return the smallest prefix of tags whose cumulative probability >= xx%.

    Args:
        preds: probability tensor of shape (1, num_classes) (softmax output).
        xx: cumulative-probability cutoff in percent (0-100).

    Returns:
        List of {'tag': ..., 'description': ...} dicts, ordered from most
        to least probable.
    """
    tops = torch.argsort(preds, 1, descending=True)
    threshold = xx / 100
    total = 0.0
    result = []
    # Bound the loop by the class count: with xx=100, float rounding could
    # leave `total` just below 1.0 after consuming every class, and the old
    # unbounded `while` would then index past the end and raise IndexError.
    for index in range(preds.shape[1]):
        if total >= threshold:
            break
        next_id = tops[0, index].item()
        total += preds[0, next_id].item()  # .item(): accumulate a float, not a 0-d tensor
        result.append({'tag': id_to_tag[next_id], 'description': id_to_description[next_id]})
    return result
|
51 |
+
|
52 |
+
# Heavy resources created once at import time; @st.cache makes the model
# load a no-op on subsequent Streamlit reruns.
model = load_model()
tokenizer = load_tokenizer()
# Softmax temperature < 1 sharpens the predicted distribution before the
# cumulative-probability cutoff in top_xx.
temperature = 1/2
|
55 |
+
|
56 |
+
# Page header and tongue-in-cheek disclaimer.
st.title('ArXivTaxonomizer© (original version)')
st.caption('If you are aware of any other public services which are illegally providing the ArXivTaxonomizer© functionality, please consider informing us.')
|
58 |
+
|
59 |
+
with st.form("Taxonomizer"):

    # Input widgets: title is required in spirit, abstract optional; the
    # slider controls the cumulative-probability cutoff passed to top_xx.
    title = st.text_area(label='Title', height=30)
    abstract = st.text_area(label='Abstract (optional)', height=200)
    xx = st.slider(label='Verbosity', min_value=0, max_value=100, value=95)
    st.caption('Lower values will generate a few best guesses. Higher values will lead to a comprehensive list of topics that our model considers relevant. \nEmpirically, values around 70 work best and generate a list of 3-5 guesses.')

    submitted = st.form_submit_button("Taxonomize")
    st.caption('We **do not** recommend using ArXivTaxonomizer© to choose tags for your new paper.')
    if submitted:
        prompt = 'Title: ' + title + ' Abstract: ' + abstract
        tokens = tokenizer(prompt, truncation=True, padding='max_length', return_tensors='pt')['input_ids']
        # Run on whatever device the model actually lives on (the original
        # hard-coded 'cuda', which crashes on CPU-only hosts), and skip
        # gradient tracking during inference.
        device = next(model.parameters()).device
        with torch.no_grad():
            preds = softmax(model(tokens.reshape(1, -1).to(device)).logits / temperature, dim=1)
        tags = top_xx(preds, xx)
        st.header('Inferred tags:')
        # First three tags are shown prominently; any remainder is demoted
        # to a captioned "other possible tags" list.
        for i, tag_data in enumerate(tags):
            if i < 3:
                st.markdown('* ' + tag_data['tag'] + ' (' + tag_data['description'] + ')')
                # Only announce the "other tags" section when there actually
                # are more than three tags to put under it.
                if i == 2 and len(tags) > 3:
                    st.subheader('Other possible tags:')
            else:
                st.caption('* ' + tag_data['tag'] + ' (' + tag_data['description'] + ')')
|