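# Streamlit app: suggests relevant arXiv categories for a paper
# from its title and (optionally) its abstract.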
import numpy as np
import pandas as pd
import torch
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification
)
import streamlit as st
DEPLOYMENT_PATH = '.'
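# Load the model, tokenizer and label mapping once and cache them across Streamlit reruns.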
@st.cache_resource
def setup():
    # Load the fine-tuned model from the local checkpoint and the
    # matching tokenizer from the base checkpoint it was trained from.
    model_name = 'distilbert-base-cased'
    model = AutoModelForSequenceClassification.from_pretrained(f'{DEPLOYMENT_PATH}/checkpoint')
    model.eval()
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # Mapping from class index to arXiv category name.
    idx2category = pd.read_csv(f'{DEPLOYMENT_PATH}/categories.csv').values.squeeze()
    return model, tokenizer, idx2category

@torch.no_grad()
def get_probas(title, abstract=None):
    # Treat an empty abstract as absent so the tokenizer does not
    # encode an empty second segment.
    abstract = abstract or None
    inputs = tokenizer(
        title,
        abstract,
        padding=True,
        truncation=True,
        return_tensors='pt'
    )
    logits = model(**inputs).logits
    # Multi-label setup: an independent sigmoid per category
    # rather than a softmax over all categories.
    probas = torch.sigmoid(logits).numpy().reshape(-1)
    return probas

model, tokenizer, idx2category = setup()
num_categories = len(idx2category)
def get_categories_by_threshold(probas, threshold=0.3):
    categories = [
        idx2category[idx] for idx in range(num_categories)
        if probas[idx] > threshold
    ]
    return categories

def get_top_categories(probas, num_predictions=5):
    categories = [
        idx2category[idx] for idx in np.argsort(probas)[::-1][:num_predictions]
    ]
    return categories

st.title('ArXiv Papers Categorization')
title_input = st.text_input('Enter the title of the paper:')
abstract_input = st.text_area('Enter the abstract (optional):')
IS_READY = len(title_input) > 0
if IS_READY and st.button('Categorize'):
    probas = get_probas(title_input, abstract_input)
    categories_predicted = get_categories_by_threshold(probas)
    # If no category clears the threshold, fall back to the top-ranked ones.
    if len(categories_predicted) == 0:
        categories_predicted = get_top_categories(probas)
    st.write('Relevant arXiv categories:')
    for category in categories_predicted:
        st.markdown(f'- `{category}`')