Spaces:
Sleeping
Sleeping
import numpy as np | |
import pandas as pd | |
import torch | |
from transformers import ( | |
AutoTokenizer, | |
AutoModelForSequenceClassification | |
) | |
import streamlit as st | |
DEPLOYMENT_PATH = '.'


def setup():
    """Load the fine-tuned classifier, its tokenizer and the label table.

    The model weights come from ``<DEPLOYMENT_PATH>/checkpoint``. The tokenizer
    is fetched by name ('distilbert-base-cased'), not from the checkpoint
    directory. The label table is read from ``<DEPLOYMENT_PATH>/categories.csv``.

    Returns:
        tuple ``(model, tokenizer, idx2category)`` where ``idx2category`` is
        the CSV contents as a squeezed numpy array. Presumably the CSV holds
        one category name per row, so this is 1-D and indexed by output
        position. TODO confirm against the training script.
    """
    checkpoint_dir = f'{DEPLOYMENT_PATH}/checkpoint'
    classifier = AutoModelForSequenceClassification.from_pretrained(checkpoint_dir)
    # Switch to inference mode (e.g. dropout disabled) before serving.
    classifier.eval()

    base_model_name = 'distilbert-base-cased'
    text_tokenizer = AutoTokenizer.from_pretrained(base_model_name)

    labels_frame = pd.read_csv(f'{DEPLOYMENT_PATH}/categories.csv')
    labels = labels_frame.values.squeeze()
    return classifier, text_tokenizer, labels
def get_probas(title, abstract=None):
    """Score a paper against each output label of the classifier.

    Uses the module-level ``model`` and ``tokenizer`` globals.

    Args:
        title: paper title text.
        abstract: optional abstract text, encoded as the second segment of a
            sentence pair. ``None`` or a blank/whitespace-only string both
            mean "no abstract".

    Returns:
        numpy array of sigmoid probabilities flattened to 1-D. For a single
        title this holds one independent probability per model output label.
    """
    # The Streamlit text area yields '' when the optional abstract is left
    # empty. Passing '' as the tokenizer's text pair still encodes an (empty)
    # second segment, typically an extra separator token. Map blank
    # abstracts to None so the title is encoded on its own, matching this
    # function's abstract=None default.
    if isinstance(abstract, str) and not abstract.strip():
        abstract = None
    inputs = tokenizer(
        title,
        abstract,
        padding=True,
        truncation=True,
        return_tensors='pt',
    )
    # Pure inference: disable autograd so no computation graph is recorded.
    with torch.no_grad():
        logits = model(**inputs).logits
        probas = torch.sigmoid(logits)
    return probas.numpy().reshape(-1)
# Streamlit re-executes this entire script on every widget interaction. An
# uncached setup() would therefore reload the checkpoint, tokenizer and
# category table from disk on every rerun. Cache the result once per server
# process instead.
# st.cache_resource needs Streamlit >= 1.18, and st.experimental_singleton is
# its predecessor. On versions with neither, fall back to the uncached call.
_cache_resource = (
    getattr(st, 'cache_resource', None)
    or getattr(st, 'experimental_singleton', None)
    or (lambda func: func)
)
model, tokenizer, idx2category = _cache_resource(setup)()
num_categories = len(idx2category)
def get_categories_by_threshold(probas, threshold=0.3):
    """Collect every category whose probability strictly exceeds ``threshold``.

    Reads the module-level ``num_categories`` and ``idx2category`` globals.

    Args:
        probas: indexable per-category probabilities, positionally aligned
            with ``idx2category`` (e.g. the array from ``get_probas``).
        threshold: exclusive lower bound a probability must beat to be kept.

    Returns:
        List of category labels in index order. May be empty.
    """
    passing = [pos for pos in range(num_categories) if probas[pos] > threshold]
    return [idx2category[pos] for pos in passing]
def get_top_categories(probas, num_predictions=5):
    """Return the ``num_predictions`` most probable category labels.

    Reads the module-level ``idx2category`` global.

    Args:
        probas: per-category probabilities accepted by ``np.argsort``,
            positionally aligned with ``idx2category``.
        num_predictions: how many labels to return (fewer if there are fewer
            categories).

    Returns:
        List of labels ordered from highest to lowest probability. The order
        among equal probabilities is not guaranteed.
    """
    ranking = np.argsort(probas)[::-1]  # indices by descending probability
    best = ranking[:num_predictions]
    return [idx2category[pos] for pos in best]
# --- Streamlit UI: collect a title (and optional abstract) and show the
# predicted arXiv categories.
st.title('ArXiv Papers Categorization')
title_input = st.text_input('Enter the title of paper:')
abstract_input = st.text_area('Enter the abstract (optional):')
# A whitespace-only title carries no information, so only offer
# classification once there is real text to classify.
IS_READY = bool(title_input.strip())
if IS_READY and st.button('Categorize'):
    probas = get_probas(title_input, abstract_input)
    categories_predicted = get_categories_by_threshold(probas)
    if not categories_predicted:
        # Nothing cleared the threshold: fall back to the most probable
        # labels so the user always gets an answer.
        categories_predicted = get_top_categories(probas)
    st.write('Relevant arXiv categories:')
    for category in categories_predicted:
        st.markdown(f'- `{category}`')