import streamlit as st
from transformers import AutoTokenizer, DistilBertForSequenceClassification
import torch
from torch.nn.functional import softmax

# Name of the base checkpoint (not referenced elsewhere in this script).
base_model_name = 'distilbert-base-uncased'
def load_tags_info():
    # Build lookup tables from tags.txt: each line is "<tag> <description>".
    tag_to_id = {}
    id_to_tag = {}
    id_to_description = {}
    with open('tags.txt', 'r') as file:
        i = 0
        for line in file:
            space = line.find(' ')
            tag = line[:space]
            description = line[space + 1:-1]
            tag_to_id[tag] = i
            id_to_tag[i] = tag
            id_to_description[i] = description
            i += 1
    # Class 155 is the dedicated "no tag" label; use integer keys like the entries above.
    tag_to_id['None'] = 155
    id_to_tag[155] = 'None'
    id_to_description[155] = 'No tag'
    return (tag_to_id, id_to_tag, id_to_description)

tag_to_id, id_to_tag, id_to_description = load_tags_info()
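
load_tags_info expects each line of tags.txt to hold a tag, a single space, and a free-text description. The actual file ships with the Space; a hypothetical excerpt in that assumed format:

# Hypothetical excerpt of tags.txt (format inferred from the parser above, not the real file):
#   cs.LG Machine Learning
#   cs.CL Computation and Language
#   stat.ML Machine Learning (Statistics)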
def load_model():
    # The fine-tuned classifier and its tokenizer are loaded from the Space's root directory.
    return DistilBertForSequenceClassification.from_pretrained('./')

def load_tokenizer():
    return AutoTokenizer.from_pretrained('./')
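
Streamlit re-executes the whole script on every widget interaction, so as written the model and tokenizer are reloaded on each rerun. A minimal sketch of how the two loaders could be cached, assuming a Streamlit version that provides st.cache_resource (1.18 or later):

@st.cache_resource  # sketch only: cache the heavy model object across reruns (assumes Streamlit >= 1.18)
def load_model():
    return DistilBertForSequenceClassification.from_pretrained('./')

@st.cache_resource  # same idea for the tokenizer
def load_tokenizer():
    return AutoTokenizer.from_pretrained('./')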
def top_xx(preds, xx=95):
    # Return the highest-probability tags whose cumulative probability reaches xx percent,
    # skipping class 155 (the "None" tag).
    tops = torch.argsort(preds, 1, descending=True)
    total = 0
    index = 0
    result = []
    while total < xx / 100:
        next_id = tops[0, index].item()
        if next_id == 155:
            index += 1
            continue
        total += preds[0, next_id]
        index += 1
        result.append({'tag': id_to_tag[next_id], 'description': id_to_description[next_id]})
    return result
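
A quick, self-contained illustration of the cumulative-probability cutoff, run separately from the app (made-up probabilities, no "None" class in this toy example):

import torch

probs = torch.tensor([[0.50, 0.30, 0.15, 0.05]])  # toy distribution over 4 classes
order = torch.argsort(probs, 1, descending=True)

kept, total, xx = [], 0.0, 70
for idx in order[0].tolist():
    if total >= xx / 100:
        break
    total += probs[0, idx].item()
    kept.append(idx)

print(kept)  # [0, 1]: 0.50 alone is under 0.70, and adding 0.30 pushes the running sum past it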
model = load_model()
tokenizer = load_tokenizer()
temperature = 1  # softmax temperature; 1 leaves the logits unscaled

st.title('ArXivTaxonomizer© (original version)')
st.caption('If you are aware of any other public services which are illegally providing the ArXivTaxonomizer© functionality, please consider informing us.')
with st.form("Taxonomizer"):
    title = st.text_area(label='Title', height=30)
    abstract = st.text_area(label='Abstract (optional)', height=200)
    xx = st.slider(label='Verbosity', min_value=1, max_value=99, value=95)
    st.caption('Lower values will generate a few best guesses. Higher values will lead to a comprehensive list of topics that our model considers relevant. \nEmpirically, values around 70 work best and generate a list of 3-5 guesses.')
    submitted = st.form_submit_button("Taxonomize")
    st.caption('We **do not** recommend using ArXivTaxonomizer© to choose tags for your new paper.')
if submitted:
    if title == '':
        st.markdown("You are most definitely abusing our service. Have the decency to at least enter a title.")
    else:
        # Concatenate title and abstract into a single prompt, classify it, and pick tags.
        prompt = 'Title: ' + title + ' Abstract: ' + abstract
        tokens = tokenizer(prompt, truncation=True, padding='max_length', return_tensors='pt')['input_ids']
        preds = softmax(model(tokens.reshape(1, -1)).logits / temperature, dim=1)
        tags = top_xx(preds, xx)
        other_tags = []
        st.header('Inferred tags:')
        # The three most likely tags are shown prominently; the rest go under a separate heading.
        for i, tag_data in enumerate(tags):
            if i < 3:
                st.markdown('* ' + tag_data['tag'] + ' (' + tag_data['description'] + ')')
            else:
                if i == 3:
                    st.subheader('Other possible tags:')
                st.caption('* ' + tag_data['tag'] + ' (' + tag_data['description'] + ')')
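
Assuming the fine-tuned weights, tokenizer files, and tags.txt sit next to this script (which is what from_pretrained('./') and open('tags.txt') imply), the app can be started locally with the standard Streamlit CLI: streamlit run app.py (taking app.py as the script's file name).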