Spaces:
Sleeping
Sleeping
| import numpy as np | |
| import pandas as pd | |
| import torch | |
| from transformers import ( | |
| AutoTokenizer, | |
| AutoModelForSequenceClassification | |
| ) | |
| import streamlit as st | |
# Base directory of the deployed artifacts: setup() reads the model
# checkpoint and categories.csv from underneath it.
DEPLOYMENT_PATH = '.'

# Streamlit re-executes the whole script on every widget interaction, so an
# uncached setup() would reload the model and tokenizer each time. Prefer the
# resource cache; fall back to the legacy decorator on Streamlit versions
# that predate it.
if hasattr(st, 'cache_resource'):
    _cache_resource = st.cache_resource
else:
    _cache_resource = st.cache(allow_output_mutation=True)


@_cache_resource
def setup():
    """Load the classifier, its tokenizer and the index-to-category table.

    The result is cached, so repeated calls within one server process reuse
    the already-loaded objects instead of reading them again.

    Returns:
        A ``(model, tokenizer, idx2category)`` tuple: the sequence
        classification model read from ``{DEPLOYMENT_PATH}/checkpoint`` and
        switched to eval mode, the ``distilbert-base-cased`` tokenizer, and
        the squeezed values of ``{DEPLOYMENT_PATH}/categories.csv``.
    """
    model_name = 'distilbert-base-cased'
    model = AutoModelForSequenceClassification.from_pretrained(
        f'{DEPLOYMENT_PATH}/checkpoint'
    )
    model.eval()
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    idx2category = (
        pd.read_csv(f'{DEPLOYMENT_PATH}/categories.csv').values.squeeze()
    )
    return model, tokenizer, idx2category
def get_probas(title, abstract=None):
    """Return per-category probabilities for a paper.

    Args:
        title: Paper title; first positional argument to the tokenizer.
        abstract: Optional abstract; second positional argument to the
            tokenizer. Defaults to ``None``.

    Returns:
        A 1-D NumPy array holding the element-wise sigmoid of the model's
        logits.
    """
    inputs = tokenizer(
        title,
        abstract,
        padding=True,
        truncation=True,
        return_tensors='pt',
    )
    # Inference only: disable autograd so no computation graph is recorded
    # for the forward pass (the result then needs no .detach()).
    with torch.no_grad():
        logits = model(**inputs).logits
        probas = torch.sigmoid(logits).numpy().reshape(-1)
    return probas
# Module-level inference artifacts shared by the helper functions in this
# file.
(model, tokenizer, idx2category) = setup()

# Number of entries in the category table; used to bound index loops.
num_categories = len(idx2category)
def get_categories_by_threshold(probas, threshold=0.3):
    """Return the categories whose probability is strictly above *threshold*.

    Categories are returned in index order. The list is empty when no
    probability exceeds the threshold.
    """
    selected = []
    for position in range(num_categories):
        if probas[position] > threshold:
            selected.append(idx2category[position])
    return selected
def get_top_categories(probas, num_predictions=5):
    """Return up to *num_predictions* categories with the highest probability.

    The result is ordered from most to least probable.
    """
    # Ascending argsort reversed gives indices from highest to lowest score.
    descending = np.argsort(probas)[::-1]
    best = descending[:num_predictions]

    top = []
    for position in best:
        top.append(idx2category[position])
    return top
# --- Page layout -----------------------------------------------------------
st.title('ArXiv Papers Categorization')

title_input = st.text_input('Enter the title of paper:')
abstract_input = st.text_area('Enter the abstract (optional):')

# st.button is only called once the title is non-empty.
IS_READY = bool(title_input)
if IS_READY:
    if st.button('Categorize'):
        # NOTE(review): abstract_input is passed through as-is. If an empty
        # text area yields '' rather than None, get_probas never sees its
        # abstract=None default -- confirm which one the model expects.
        probas = get_probas(title_input, abstract_input)

        categories_predicted = get_categories_by_threshold(probas)
        if not categories_predicted:
            # Nothing cleared the threshold: fall back to the top-ranked
            # categories so the user always gets an answer.
            categories_predicted = get_top_categories(probas)

        st.write('Relevant arXiv categories:')
        for category in categories_predicted:
            st.markdown(f'- `{category}`')