import streamlit as st
from transformers import AutoTokenizer, DistilBertForSequenceClassification
import torch
from torch.nn.functional import softmax

# Base checkpoint the fine-tuned classifier was derived from (kept for reference).
base_model_name = 'distilbert-base-uncased'


@st.cache
def load_tags_info():
    # Build tag <-> id lookup tables from tags.txt (one "<tag> <description>" per line).
    tag_to_id = {}
    id_to_tag = {}
    id_to_description = {}
    with open('tags.txt', 'r') as file:
        for i, line in enumerate(file):
            space = line.find(' ')
            tag = line[:space]
            description = line[space + 1:].strip()  # drop the trailing newline
            tag_to_id[tag] = i
            id_to_tag[i] = tag
            id_to_description[i] = description
    # Id 155 is reserved for the synthetic 'None' class; the keys must be ints
    # so that lookups by predicted class id work in top_xx().
    tag_to_id['None'] = 155
    id_to_tag[155] = 'None'
    id_to_description[155] = 'No tag'
    return tag_to_id, id_to_tag, id_to_description


tag_to_id, id_to_tag, id_to_description = load_tags_info()


@st.cache
def load_model():
    # The fine-tuned classifier weights are expected in the working directory.
    return DistilBertForSequenceClassification.from_pretrained('./')


def load_tokenizer():
    return AutoTokenizer.from_pretrained('./')


def top_xx(preds, xx=95):
    # Return the smallest set of tags whose cumulative probability exceeds xx percent,
    # skipping the synthetic 'None' class.
    tops = torch.argsort(preds, 1, descending=True)
    total = 0
    index = 0
    result = []
    while total < xx / 100 and index < tops.shape[1]:
        next_id = tops[0, index].item()
        index += 1
        if next_id == 155:
            continue
        total += preds[0, next_id]
        result.append({'tag': id_to_tag[next_id], 'description': id_to_description[next_id]})
    return result


model = load_model()
tokenizer = load_tokenizer()
temperature = 1

st.title('ArXivTaxonomizer© (original version)')
st.caption('If you are aware of any other public services which are illegally providing '
           'the ArXivTaxonomizer© functionality, please consider informing us.')

with st.form("Taxonomizer"):
    title = st.text_area(label='Title', height=30)
    abstract = st.text_area(label='Abstract (optional)', height=200)
    xx = st.slider(label='Verbosity', min_value=1, max_value=99, value=95)
    st.caption('Lower values will generate a few best guesses. Higher values will lead to a '
               'comprehensive list of topics that our model considers relevant. \n'
               'Empirically, values around 70 work best and generate a list of 3-5 guesses.')
    submitted = st.form_submit_button("Taxonomize")
    st.caption('We **do not** recommend using ArXivTaxonomizer© to choose tags for your new paper.')

if submitted:
    if title == '':
        st.markdown("You are most definitely abusing our service. Have the decency to at least enter a title.")
    else:
        prompt = 'Title: ' + title + ' Abstract: ' + abstract
        encoded = tokenizer(prompt, truncation=True, padding='max_length', return_tensors='pt')
        with torch.no_grad():  # inference only; pass the attention mask so padding tokens are ignored
            logits = model(input_ids=encoded['input_ids'], attention_mask=encoded['attention_mask']).logits
        preds = softmax(logits / temperature, dim=1)
        tags = top_xx(preds, xx)
        st.header('Inferred tags:')
        for i, tag_data in enumerate(tags):
            if i < 3:
                # The three most likely tags are shown prominently...
                st.markdown('* ' + tag_data['tag'] + ' (' + tag_data['description'] + ')')
            else:
                # ...the rest appear in a smaller "other possible tags" section.
                if i == 3:
                    st.subheader('Other possible tags:')
                st.caption('* ' + tag_data['tag'] + ' (' + tag_data['description'] + ')')
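# ---------------------------------------------------------------------------
# Notes (assumptions for running this sketch locally, not guarantees from the
# original setup):
# - tags.txt is assumed to be a plain-text file with one "<tag> <description>"
#   entry per line, e.g.
#       cs.LG Machine Learning
#       cs.CL Computation and Language
# - The fine-tuned model and tokenizer are loaded from './', so the saved
#   checkpoint files are expected to sit next to this script.
# - The app is started the usual Streamlit way, e.g. `streamlit run app.py`
#   (app.py is an illustrative filename; use whatever this script is saved as).
# ---------------------------------------------------------------------------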