"""Streamlit app that suggests arXiv categories for a paper title/abstract.

A fine-tuned sequence-classification checkpoint is loaded from
``DEPLOYMENT_PATH/checkpoint`` and category names from
``DEPLOYMENT_PATH/categories.csv``. Each category is scored independently
with a sigmoid, so several categories can be suggested at once.
"""
import numpy as np
import pandas as pd
import torch
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
)
import streamlit as st

DEPLOYMENT_PATH = '.'


@st.cache_resource
def setup():
    """Load the model, tokenizer and index-to-category lookup table.

    Decorated with ``st.cache_resource`` so the loaded objects are reused
    across Streamlit reruns instead of being reloaded every time.

    Returns:
        ``(model, tokenizer, idx2category)`` where ``idx2category`` is a
        1-D array whose i-th entry names the category of output logit i.
    """
    model_name = 'distilbert-base-cased'
    model = AutoModelForSequenceClassification.from_pretrained(
        f'{DEPLOYMENT_PATH}/checkpoint'
    )
    model.eval()  # inference mode (e.g. disables dropout)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # NOTE(review): assumes categories.csv holds a single column of
    # category names (first row read as header), ordered like the model's
    # logits -- confirm against the training export.
    # reshape(-1) rather than squeeze(): squeeze() turns a one-category
    # table into a 0-d array, on which len() raises.
    idx2category = (
        pd.read_csv(f'{DEPLOYMENT_PATH}/categories.csv')
        .to_numpy()
        .reshape(-1)
    )
    return model, tokenizer, idx2category


@torch.no_grad()
def get_probas(title, abstract=None):
    """Return per-category probabilities for a paper.

    Args:
        title: Paper title.
        abstract: Optional abstract, encoded as the second segment of a
            text pair. ``None`` or a blank/whitespace-only string means
            "title only".

    Returns:
        1-D NumPy array with one sigmoid probability per model output.
    """
    # st.text_area returns '' when left empty; passing '' as the text pair
    # would still encode an (empty) second segment, so treat it as absent.
    if isinstance(abstract, str) and not abstract.strip():
        abstract = None
    inputs = tokenizer(
        title,
        abstract,
        padding=True,
        truncation=True,
        return_tensors='pt',
    )
    logits = model(**inputs).logits
    # Sigmoid rather than softmax: categories are scored independently.
    # No .detach() needed -- autograd is disabled by @torch.no_grad().
    return torch.sigmoid(logits).numpy().reshape(-1)


model, tokenizer, idx2category = setup()
num_categories = len(idx2category)


def get_categories_by_threshold(probas, threshold=0.3):
    """Return categories whose probability is strictly above ``threshold``.

    Only the first ``num_categories`` probabilities are considered, and
    results come back in category-index order, not sorted by probability.
    """
    probas = np.asarray(probas)[:num_categories]
    return [idx2category[idx] for idx in np.flatnonzero(probas > threshold)]


def get_top_categories(probas, num_predictions=5):
    """Return up to ``num_predictions`` most probable categories, best first.

    A non-positive ``num_predictions`` yields an empty list.
    """
    # Clamp: a negative slice bound would drop items from the end instead.
    num_predictions = max(num_predictions, 0)
    top_indices = np.argsort(probas)[::-1][:num_predictions]
    return [idx2category[idx] for idx in top_indices]


st.title('ArXiv Papers Categorization')

title_input = st.text_input('Enter the title of paper:')
abstract_input = st.text_area('Enter the abstract (optional):')

# A whitespace-only title gives the model nothing to classify.
IS_READY = bool(title_input.strip())

if IS_READY and st.button('Categorize'):
    probas = get_probas(title_input, abstract_input)
    categories_predicted = get_categories_by_threshold(probas)
    # Fall back to the top-k categories so the user always gets an answer
    # even when no probability clears the threshold.
    if not categories_predicted:
        categories_predicted = get_top_categories(probas)
    st.write('Relevant arXiv categories:')
    for category in categories_predicted:
        st.markdown(f'- `{category}`')