File size: 1,946 Bytes
7120cc1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import numpy as np
import pandas as pd
import torch
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification
)
import streamlit as st

DEPLOYMENT_PATH = '.'


@st.cache_resource
def setup():
    """Load the fine-tuned classifier, its tokenizer and the label table.

    The model weights come from ``<DEPLOYMENT_PATH>/checkpoint`` and are put
    in evaluation mode; the tokenizer is loaded by name
    (``distilbert-base-cased``) rather than from the checkpoint directory.
    The label table is read from ``<DEPLOYMENT_PATH>/categories.csv`` and
    squeezed into a numpy array (assumes a single column of category names,
    one row per model output — TODO confirm against the training script).

    Decorated with ``st.cache_resource`` so the loaded objects are cached by
    Streamlit instead of being rebuilt on every call.

    Returns:
        tuple: ``(model, tokenizer, idx2category)``.
    """
    checkpoint_dir = f'{DEPLOYMENT_PATH}/checkpoint'
    categories_file = f'{DEPLOYMENT_PATH}/categories.csv'

    classifier = AutoModelForSequenceClassification.from_pretrained(checkpoint_dir)
    # Inference only: disable dropout and other train-time behaviour.
    classifier.eval()

    text_tokenizer = AutoTokenizer.from_pretrained('distilbert-base-cased')

    labels = pd.read_csv(categories_file).values.squeeze()
    return classifier, text_tokenizer, labels

@torch.no_grad()
def get_probas(title, abstract=None):
    """Score a paper title (and optional abstract) against every category.

    The inputs are encoded by the module-level ``tokenizer``. This is a
    single sequence when there is no abstract, or a (title, abstract)
    sequence pair otherwise. The result is run through the module-level
    ``model``, with gradients disabled by the ``torch.no_grad`` decorator.

    Args:
        title: Paper title.
        abstract: Optional paper abstract. ``None``, ``''`` and
            whitespace-only strings are all treated as "no abstract".

    Returns:
        numpy.ndarray: 1-D array with one sigmoid probability per output
        logit.
    """
    # UI callers hand over the raw text-area string, which is '' (or just
    # whitespace) when the user leaves it blank. Whether a blank text pair is
    # silently dropped depends on the tokenizer implementation, and
    # whitespace-only text always produces an empty second segment. Normalize
    # explicitly so a blank abstract always encodes the title alone.
    if abstract is not None and not abstract.strip():
        abstract = None

    inputs = tokenizer(
        title,
        abstract,
        padding=True,
        truncation=True,
        return_tensors='pt'
    )
    logits = model(**inputs).logits
    # Sigmoid is applied per logit, independently: each category gets its
    # own probability and the values need not sum to 1.
    probas = (
        torch.sigmoid(logits)
        .detach().numpy().reshape(-1)
    )
    return probas

# Bind the loaded artifacts as module globals. get_probas reads `model` and
# `tokenizer`, and the category helpers below read `idx2category` /
# `num_categories` directly. setup() is wrapped in st.cache_resource, so later
# executions of this line return the cached objects instead of reloading them.
model, tokenizer, idx2category = setup()
# Number of labels in the category table; get_categories_by_threshold walks
# this many probability indices.
num_categories = len(idx2category)

def get_categories_by_threshold(probas, threshold=0.3):
    """Return every category whose probability is strictly above *threshold*.

    Walks indices ``0 .. num_categories - 1`` of *probas* and maps each
    selected index through the module-level ``idx2category`` table.

    Args:
        probas: Indexable per-category probabilities, aligned with
            ``idx2category``.
        threshold: Exclusive lower bound a probability must exceed.

    Returns:
        list: Selected category labels, in index order.
    """
    def _clears_threshold(position):
        return probas[position] > threshold

    return [
        idx2category[position]
        for position in filter(_clears_threshold, range(num_categories))
    ]

def get_top_categories(probas, num_predictions=5):
    """Return the *num_predictions* highest-probability categories, best first.

    Args:
        probas: Per-category probabilities, aligned with the module-level
            ``idx2category`` table.
        num_predictions: Maximum number of categories to return; fewer are
            returned if *probas* is shorter.

    Returns:
        list: Category labels ordered by descending probability.
    """
    # argsort ranks ascending; reverse it to put the largest probability first.
    descending_order = np.argsort(probas)[::-1]
    chosen = descending_order[:num_predictions]
    return [idx2category[position] for position in chosen]

st.title('ArXiv Papers Categorization')
title_input = st.text_input('Enter the title of paper:')
abstract_input = st.text_area('Enter the abstract (optional):')

# A title made only of whitespace carries no signal for the classifier, so
# require at least one non-whitespace character before offering the action.
IS_READY = len(title_input.strip()) > 0

# `and` short-circuits: the 'Categorize' button is only rendered once a
# usable title has been entered.
if IS_READY and st.button('Categorize'):
    probas = get_probas(title_input, abstract_input)
    categories_predicted = get_categories_by_threshold(probas)

    # No category cleared the threshold: fall back to the highest-scoring
    # categories so the user still gets suggestions.
    if not categories_predicted:
        categories_predicted = get_top_categories(probas)

    st.write('Relevant arXiv categories:')
    for category in categories_predicted:
        st.markdown(f'- `{category}`')