|
import streamlit as st |
|
|
|
from transformers import AutoTokenizer, DistilBertForSequenceClassification |
|
import torch |
|
from torch.nn.functional import softmax |
|
|
|
|
|
|
|
# HuggingFace model id of the base DistilBERT checkpoint.
base_model_name = 'distilbert-base-uncased'
|
|
|
@st.cache
def load_tags_info():
    """Load tag metadata from 'tags.txt'.

    Each line is expected to look like '<tag> <description>'. Ids are
    assigned by line order.

    Returns:
        (tag_to_id, id_to_tag, id_to_description): dicts mapping tag name
        to integer id, id to tag name, and id to description text.
    """
    tag_to_id = {}
    id_to_tag = {}
    id_to_description = {}
    with open('tags.txt', 'r', encoding='utf-8') as file:
        for i, line in enumerate(file):
            # Split on the first space only; rstrip('\n') instead of the
            # previous [:-1] slice, which dropped the last character of the
            # final line whenever the file had no trailing newline (and
            # misparsed lines containing no space at all).
            tag, _, description = line.rstrip('\n').partition(' ')
            tag_to_id[tag] = i
            id_to_tag[i] = tag
            id_to_description[i] = description
    return (tag_to_id, id_to_tag, id_to_description)
|
|
|
# Build the tag lookup tables once at startup (memoized by @st.cache).
tag_to_id, id_to_tag, id_to_description = load_tags_info()
|
|
|
@st.cache
def load_model():
    """Load the fine-tuned DistilBERT classifier from the app directory."""
    classifier = DistilBertForSequenceClassification.from_pretrained('./')
    return classifier
|
|
|
def load_tokenizer():
    """Return the tokenizer for the base checkpoint.

    Uses the module-level ``base_model_name`` constant rather than
    repeating the 'distilbert-base-uncased' string literal, keeping the
    checkpoint id in one place.
    """
    # NOTE(review): unlike load_model(), this is not wrapped in @st.cache;
    # presumably deliberate, but caching it would be consistent — confirm.
    return AutoTokenizer.from_pretrained(base_model_name)
|
|
|
def top_xx(preds, xx=95, tag_map=None, description_map=None):
    """Return the smallest prefix of top-probability tags whose cumulative
    probability reaches ``xx`` percent.

    Args:
        preds: (1, num_tags) tensor of class probabilities (softmax output).
        xx: cumulative-probability threshold, in percent.
        tag_map: optional id -> tag mapping; defaults to the module-level
            ``id_to_tag`` (backward compatible with the previous signature).
        description_map: optional id -> description mapping; defaults to
            the module-level ``id_to_description``.

    Returns:
        List of ``{'tag': ..., 'description': ...}`` dicts, ordered by
        descending probability.
    """
    tags = id_to_tag if tag_map is None else tag_map
    descriptions = id_to_description if description_map is None else description_map
    order = torch.argsort(preds, 1, descending=True)
    threshold = xx / 100
    total = 0
    result = []
    # Also bound the loop by the class count: with xx near 100, float
    # rounding in the cumulative sum could otherwise run the index past
    # the last class and raise an IndexError.
    for index in range(order.size(1)):
        if total >= threshold:
            break
        next_id = order[0, index].item()
        total += preds[0, next_id]
        result.append({'tag': tags[next_id], 'description': descriptions[next_id]})
    return result
|
|
|
model = load_model()
tokenizer = load_tokenizer()
# Softmax temperature; 1 leaves the logits unscaled.
temperature = 1

st.title('ArXivTaxonomizer© (original version)')
st.caption('If you are aware of any other public services which are illegally providing the ArXivTaxonomizer© functionality, please consider informing us.')

with st.form("Taxonomizer"):

    title = st.text_area(label='Title', height=30)
    abstract = st.text_area(label='Abstract (optional)', height=200)
    # The caption below describes this control, but the slider itself was
    # missing from the form, so the threshold was stuck at the default.
    # value=95 preserves the previous effective behavior.
    xx = st.slider(label='Comprehensiveness', min_value=1, max_value=99, value=95)
    st.caption('Lower values will generate a few best guesses. Higher values will lead to a comprehensive list of topics that our model considers relevant. \nEmpirically, values around 70 work best and generate a list of 3-5 guesses.')

    submitted = st.form_submit_button("Taxonomize")
    st.caption('We **do not** recommend using ArXivTaxonomizer© to choose tags for your new paper.')
    if submitted:
        if title == '':
            st.markdown("You are most definitely abusing our service. Have the decency to at least enter a title.")
        else:
            prompt = 'Title: ' + title + ' Abstract: ' + abstract
            tokens = tokenizer(prompt, truncation=True, padding='max_length', return_tensors='pt')['input_ids']
            # Temperature-scaled softmax over the classifier logits.
            preds = softmax(model(tokens.reshape(1, -1)).logits / temperature, dim=1)
            tags = top_xx(preds, xx=xx)
            st.header('Inferred tags:')
            for tag_data in tags:
                st.markdown('* ' + tag_data['tag'] + ' (' + tag_data['description'] + ')')