File size: 5,392 Bytes
82d81c8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Page metadata and the introductory text shown above the input form.
st.set_page_config(
    page_title="ArXiv Paper Classifier",
    page_icon="📚",
)

st.title("ArXiv Paper Classifier")
# The form below requires a title and treats the abstract as optional, so the
# description says so instead of claiming abstract-only classification.
st.markdown(
    """
This app classifies arXiv papers based on their title and, optionally, their abstract.
Enter the paper details and the model will predict the most likely topic categories.
"""
)


@st.cache_resource
def load_model_and_tokenizer(model_path: str = "goldov/arxiv-classifier-debertav3"):
    """Load the sequence-classification model, its tokenizer and its label map.

    The function is decorated with ``st.cache_resource`` so the loaded objects
    are reused rather than rebuilt on every script rerun; ``model_path`` takes
    part in the cache key, so different checkpoints are cached separately.

    Args:
        model_path: Checkpoint identifier (Hub id or local directory) passed to
            both ``from_pretrained`` calls. Defaults to the checkpoint the app
            was originally hard-wired to.

    Returns:
        A ``(model, tokenizer, id2label)`` tuple, where ``id2label`` is the
        model config's mapping from class index to category label.
    """
    # TODO: the default checkpoint is provisional ("change later"); callers can
    # already override it via ``model_path``.
    model = AutoModelForSequenceClassification.from_pretrained(model_path)
    tokenizer = AutoTokenizer.from_pretrained(model_path)

    return model, tokenizer, model.config.id2label


# Obtain the classifier from the cached loader; the spinner is displayed while
# the call runs. The three names bound here are read as globals by
# ``predict_topics``.
with st.spinner("Loading model... This may take a minute."):
    model, tokenizer, id2label = load_model_and_tokenizer()


# Input form: a required title, an optional abstract and the submit button.
# ``title``, ``abstract`` and ``submit_button`` are consumed further down.
st.subheader("Paper Information")
paper_form = st.form(key="paper_form")
title = paper_form.text_input(
    "Title",
    placeholder="Enter the paper title",
)
abstract = paper_form.text_area(
    "Abstract (optional)",
    placeholder="Enter the paper abstract (optional)",
)
submit_button = paper_form.form_submit_button(label="Classify Paper")


def predict_topics(title, abstract="", *, threshold=0.95):
    """Classify a paper and return its most likely categories.

    The title (plus the abstract, when one is given) is formatted as
    ``"Title: ... Abstract: ..."``, tokenized with truncation to 512 tokens and
    run through the module-level ``model``. The softmax probabilities are
    sorted in descending order and the shortest prefix whose cumulative
    probability reaches ``threshold`` is returned.

    Args:
        title: Paper title.
        abstract: Optional abstract. ``None``, empty or whitespace-only values
            mean "classify from the title alone".
        threshold: Probability mass the returned categories must cover.
            Defaults to 0.95 (the previously hard-coded value).

    Returns:
        A list of ``(label, probability)`` tuples sorted by probability,
        descending, with labels looked up in ``id2label``. It always holds at
        least the top category and never more entries than there are classes.
    """
    # Whitespace-only input should not count as a provided abstract.
    title = (title or "").strip()
    abstract = (abstract or "").strip()
    text = f"Title: {title} Abstract: {abstract}" if abstract else f"Title: {title}"

    tokens_info = tokenizer(text, padding=True, truncation=True, max_length=512, return_tensors="pt")

    # The tokenized inputs are CPU tensors, so keep the (shared, cached) model
    # on the CPU and in inference mode.
    model.eval()
    model.cpu()
    with torch.no_grad():
        logits = model(**tokens_info).logits
        probs = torch.nn.functional.softmax(logits, dim=-1).squeeze(0)

    sorted_probs, sorted_indices = torch.sort(probs, descending=True)
    cumulative_probs = torch.cumsum(sorted_probs, dim=0)

    # ``cumulative_probs`` is non-decreasing, so the number of entries still
    # below the threshold is the index of the first entry that reaches it;
    # that entry is kept too (+1). Clamping guards against the total mass
    # falling short of ``threshold`` (float round-off, threshold >= 1), where
    # the previous ``torch.where(...)[0][0]`` lookup raised IndexError.
    below = int((cumulative_probs < threshold).sum().item())
    cutoff = min(below + 1, sorted_probs.numel())

    return [
        (id2label[class_idx], prob)
        for class_idx, prob in zip(
            sorted_indices[:cutoff].tolist(),
            sorted_probs[:cutoff].tolist(),
        )
    ]


def _render_results(results):
    """Render ``(category, probability)`` pairs as a two-column table.

    Shows a "Top Categories" header, one row per category with a progress bar
    and a percentage, and finally the total probability the rows cover. Used
    by both the form submission and the example so they display identically.

    Args:
        results: List of ``(label, probability)`` tuples as returned by
            ``predict_topics``; probabilities are presumably in ``[0, 1]``
            (softmax outputs), as ``st.progress`` expects.
    """
    st.markdown("#### Top Categories")

    header_category, header_probability = st.columns([3, 1])
    with header_category:
        st.markdown("**Category**")
    with header_probability:
        st.markdown("**Probability**")

    for category, prob in results:
        category_col, prob_col = st.columns([3, 1])
        with category_col:
            st.markdown(f"{category}")
        with prob_col:
            st.progress(prob)
            st.markdown(f"{prob:.2%}")

    total_prob = sum(prob for _, prob in results)
    st.info(f"Total probability covered: {total_prob:.2%}")


if submit_button:
    # Treat whitespace-only input as empty, both for validation and for
    # reporting whether the abstract was used.
    title = title.strip()
    abstract = abstract.strip()
    if not title:
        st.error("Please enter a paper title.")
    else:
        with st.spinner("Classifying..."):
            results = predict_topics(title, abstract)

        st.subheader("Prediction Results")

        if abstract:
            st.text("Classification based on title and abstract")
        else:
            st.text("Classification based on title")

        _render_results(results)


# Add example section
if st.button("Try An Example!"):
    example_title = "Attention Is All You Need"
    example_abstract = """The dominant sequence transduction models are based on complex recurrent or convolutional neural networks in an encoder-decoder configuration.
    The best performing models also connect the encoder and decoder through an attention mechanism.
    We propose a new simple network architecture, the Transformer, based solely on attention mechanisms, dispensing with recurrence and convolutions entirely.
    Experiments on two machine translation tasks show these models to be superior in quality while being more parallelizable and requiring significantly less time to train.
    Our model achieves 28.4 BLEU on the WMT 2014 English-to-German translation task, improving over the existing best results, including ensembles by over 2 BLEU.
    On the WMT 2014 English-to-French translation task, our model establishes a new single-model state-of-the-art BLEU score of 41.8 after training for 3.5 days on eight GPUs, a small fraction of the training costs of the best models from the literature.
    We show that the Transformer generalizes well to other tasks by applying it successfully to English constituency parsing both with large and limited training data."""

    with st.spinner("Classifying example..."):
        results = predict_topics(example_title, example_abstract)
    st.subheader("Example Prediction Results")
    st.text(f"Title: {example_title}")
    st.text(f"Abstract: {example_abstract}")
    st.text("Classification based on title and abstract")

    _render_results(results)

st.markdown("---")
st.markdown("ArXiv Paper Classifier by Ivan Goldov")