File size: 1,810 Bytes
74d1d12
6dd148a
74d1d12
6dd148a
 
 
 
 
 
 
 
 
 
cd5626d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6dd148a
 
 
 
 
 
 
 
 
 
 
cd5626d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import streamlit as st
from transformers import pipeline, AutoModelForSequenceClassification, AutoTokenizer

@st.cache_resource
def load_pipeline():
    """Build the text-classification pipeline used by this app.

    Both the sequence-classification model and its tokenizer are loaded from
    the same Hugging Face checkpoint. The function is wrapped in
    ``st.cache_resource``, so the returned pipeline object is meant to be
    created once and reused rather than rebuilt on every script rerun.

    Returns:
        A ``transformers`` "text-classification" pipeline.
    """
    checkpoint = "dmasloff/arxiv_distilbert"
    clf_model = AutoModelForSequenceClassification.from_pretrained(checkpoint)
    clf_tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    return pipeline(
        "text-classification",
        model=clf_model,
        tokenizer=clf_tokenizer,
    )

# Module-level pipeline handle used by the submit handler further down.
# load_pipeline is decorated with st.cache_resource, so this call is expected
# to be cheap after the first run of the script.
classifier = load_pipeline()

# Page heading shown at the top of the app.
st.title("ArXiV article classification via DistilBERT")

# Human-readable arXiv category for every class id. A name's position in the
# tuple is its integer id (0-19); the prediction lookup at the bottom of the
# script indexes this mapping with the number parsed from the model's label.
topic_names = dict(enumerate((
    'High Energy Physics - Phenomenology (hep-ph)',  # 0
    'Nuclear Experiment (nucl-ex)',  # 1
    'High Energy Physics - Experiment (hep-ex)',  # 2
    'Astrophysics (astro-ph)',  # 3
    'Quantum Physics (quant-ph)',  # 4
    'Mathematical Physics (math-ph)',  # 5
    'High Energy Physics - Theory (hep-th)',  # 6
    'Quantitative Biology (q-bio)',  # 7
    'Nonlinear Sciences (nlin)',  # 8
    'Computer Science (cs)',  # 9
    'Quantitative Finance (q-fin)',  # 10
    'Mathematics (math)',  # 11
    'Condensed Matter (cond-mat)',  # 12
    'High Energy Physics - Lattice (hep-lat)',  # 13
    'Electrical Engineering and Systems Science (eess)',  # 14
    'Physics (physics)',  # 15
    'Nuclear Theory (nucl-th)',  # 16
    'Statistics (stat)',  # 17
    'Economics (econ)',  # 18
    'General Relativity and Quantum Cosmology (gr-qc)',  # 19
)))

# Free-text inputs; either one may be left empty, but not both.
title = st.text_area("Enter article's title")
abstract = st.text_area("Enter article's abstract")

if st.button("Submit"):
    if not title.strip() and not abstract.strip():
        st.warning("Please fill in at least one field.")
    else:
        # Title and abstract are classified as one text; strip() removes the
        # dangling newline separator when one of the two fields is empty.
        full_text = f"{title}\n{abstract}".strip()
        with st.spinner("Classifying..."):
            # truncation=True is forwarded to the tokenizer so that inputs
            # longer than its configured maximum length are cut down instead
            # of making the model raise on an over-long sequence.
            result = classifier(full_text, truncation=True)

        # The lookup expects labels shaped like "LABEL_<id>". If the label has
        # another shape or an unknown id, show the raw label rather than
        # crashing the page with an IndexError/ValueError/KeyError.
        raw_label = result[0]["label"]
        try:
            topic = topic_names[int(raw_label.rsplit("_", 1)[-1])]
        except (KeyError, ValueError):
            topic = raw_label

        st.success("Classification Result:")
        st.text(topic)