import streamlit as st
from transformers import AutoTokenizer, DistilBertForSequenceClassification
import torch
from torch.nn.functional import softmax
# Name of the pretrained base checkpoint; appears unused in this file
# (the model/tokenizer load from './') — TODO confirm before removing.
base_model_name = 'distilbert-base-uncased'
@st.cache
def load_tags_info():
    """Load arXiv tag metadata from ``tags.txt`` (cached by Streamlit).

    Each line of ``tags.txt`` is expected to be ``"<tag> <description>"``.

    Returns:
        Tuple ``(tag_to_id, id_to_tag, id_to_description)`` where ids are
        consecutive ints in file order, plus a sentinel entry with id 155
        meaning "no tag".
    """
    tag_to_id = {}
    id_to_tag = {}
    id_to_description = {}
    with open('tags.txt', 'r') as file:
        for i, line in enumerate(file):
            # Split on the first space: tag first, description after.
            # rstrip('\n') is safer than slicing off the last character,
            # which would eat a real character on a final line with no newline.
            tag, _, description = line.rstrip('\n').partition(' ')
            tag_to_id[tag] = i
            id_to_tag[i] = tag
            id_to_description[i] = description
    # Sentinel "no tag" class. Bug fix: the original used the *string* key
    # '155' here while every other id key is an int, so integer lookups of
    # id 155 (as done elsewhere with .item() ids) would raise KeyError.
    tag_to_id['None'] = 155
    id_to_tag[155] = 'None'
    id_to_description[155] = 'No tag'
    return (tag_to_id, id_to_tag, id_to_description)
# Tag lookup tables, loaded once at startup (cached across reruns).
tag_to_id, id_to_tag, id_to_description = load_tags_info()
@st.cache
def load_model():
    """Load the fine-tuned classifier weights from the app directory (cached)."""
    classifier = DistilBertForSequenceClassification.from_pretrained('./')
    return classifier
@st.cache
def load_tokenizer():
    """Load the tokenizer from the app directory.

    Cached with ``@st.cache`` for consistency with ``load_model()``:
    without caching, Streamlit re-reads the tokenizer files from disk on
    every script rerun (i.e. on every widget interaction).
    """
    return AutoTokenizer.from_pretrained('./')
def top_xx(preds, xx=95, none_id=155, tag_map=None, desc_map=None):
    """Return the smallest set of top predictions covering xx% probability.

    Args:
        preds: ``(1, num_classes)`` tensor of class probabilities
            (softmax output).
        xx: coverage threshold in percent; tags are accumulated, most
            probable first, until their total probability reaches xx/100.
        none_id: class id of the "no tag" sentinel, which is never reported.
        tag_map / desc_map: id -> tag / description lookups; default to the
            module-level tables loaded from ``tags.txt``.

    Returns:
        List of ``{'tag': ..., 'description': ...}`` dicts, most probable
        first.
    """
    if tag_map is None:
        tag_map = id_to_tag
    if desc_map is None:
        desc_map = id_to_description
    ranked = torch.argsort(preds, 1, descending=True)
    threshold = xx / 100
    total = 0.0
    result = []
    # Bound the scan by the number of classes: the original `while` loop had
    # no bound on the index, so it could run past the end of the ranking and
    # raise IndexError when the skipped sentinel mass kept the total below
    # the threshold.
    for index in range(ranked.size(1)):
        if total >= threshold:
            break
        next_id = ranked[0, index].item()
        if next_id == none_id:
            continue  # never report the "no tag" sentinel
        total += preds[0, next_id].item()
        result.append({'tag': tag_map[next_id], 'description': desc_map[next_id]})
    return result
# Model and tokenizer shared by every request (both cached loaders).
model = load_model()
tokenizer = load_tokenizer()
# Softmax temperature; 1 leaves the logits unscaled.
temperature = 1
st.title('ArXivTaxonomizer© (original version)')
st.caption('If you are aware of any other public services which are illegally providing the ArXivTaxonomizer© functionality, please consider informing us.')

# Input form: the title is required, the abstract optional; the slider sets
# how much cumulative probability mass the reported tag list must cover.
with st.form("Taxonomizer"):
    title = st.text_area(label='Title', height=30)
    abstract = st.text_area(label='Abstract (optional)', height=200)
    xx = st.slider(label='Verbosity', min_value=1, max_value=99, value=95)
    st.caption('Lower values will generate a few best guesses. Higher values will lead to a comprehensive list of topics that our model considers relevant. \nEmpirically, values around 70 work best and generate a list of 3-5 guesses.')
    submitted = st.form_submit_button("Taxonomize")
st.caption('We **do not** recommend using ArXivTaxonomizer© to choose tags for your new paper.')

if submitted:
    if title == '':
        # Empty title: scold the user and skip inference entirely (the
        # original ran the model anyway and merely hid the output).
        st.markdown("You are most definitely abusing our service. No ticket today, but you better be careful.")
    else:
        prompt = 'Title: ' + title + ' Abstract: ' + abstract
        tokens = tokenizer(prompt, truncation=True, padding='max_length', return_tensors='pt')['input_ids']
        # Temperature-scaled softmax over the classifier logits.
        preds = softmax(model(tokens.reshape(1, -1)).logits / temperature, dim=1)
        tags = top_xx(preds, xx)
        st.header('Inferred tags:')
        for i, tag_data in enumerate(tags):
            entry = '* ' + tag_data['tag'] + ' (' + tag_data['description'] + ')'
            if i < 3:
                # Top three guesses are shown prominently ...
                st.markdown(entry)
            else:
                # ... everything after that goes under a smaller section.
                if i == 3:
                    st.subheader('Other possible tags:')
                st.caption(entry)