Pyotr Lisov committed
Commit: 70b2ea0
Parent(s): ae1c8b3
Add article classifier app
Browse files:
- .streamlit/config.toml +7 -0
- README.md +114 -13
- app.py +204 -0
- artifacts/large_model/best_model/config.json +52 -0
- artifacts/large_model/best_model/model.safetensors +3 -0
- artifacts/large_model/best_model/tokenizer.json +0 -0
- artifacts/large_model/best_model/tokenizer_config.json +14 -0
- artifacts/large_model/best_model/training_args.bin +3 -0
- artifacts/large_model/metrics.json +22 -0
- configs/app_config.json +11 -0
- data/processed_large/label_mapping.json +12 -0
- inference.py +156 -0
- requirements.txt +5 -3
.streamlit/config.toml
ADDED
@@ -0,0 +1,7 @@
+[theme]
+base = "light"
+primaryColor = "#1D4ED8"
+backgroundColor = "#F7FAFC"
+secondaryBackgroundColor = "#E8F1F8"
+textColor = "#102A43"
+font = "sans serif"
README.md
CHANGED
@@ -1,20 +1,121 @@
 ---
-title:
-emoji:
-colorFrom:
-colorTo:
-sdk:
-
-
-- streamlit
+title: arXiv Topic Classifier
+emoji: 📚
+colorFrom: blue
+colorTo: green
+sdk: streamlit
+sdk_version: 1.33.0
+app_file: app.py
 pinned: false
-short_description: A small bert-based model for classifying articles from Arxiv
 license: mit
+short_description: Transformer-powered topic classification for arXiv papers
 ---
 
-#
-
-
-
+# arXiv Topic Classifier
+
+`arXiv Topic Classifier` is a Streamlit app for classifying research papers into arXiv-style topic categories from the paper title and abstract. The interface accepts the two fields separately, supports title-only inference, and returns the smallest prefix of labels whose cumulative probability exceeds 95%.
+
+The project is designed as a lightweight end-to-end ML application: collect data, fine-tune a transformer classifier, package the trained model with local inference code, and expose the result through a public web interface.
+
+## Features
+
+- topic prediction from `title` and `abstract`
+- inference from `title` only when abstract is missing
+- top-95% cumulative probability output
+- full ranked list of class probabilities
+- cached model loading for faster repeated requests
+- self-contained deployment with local model weights
+
+## Categories
+
+The current model predicts 10 categories:
+
+- `astro-ph.GA`
+- `cond-mat.mtrl-sci`
+- `cs.CL`
+- `cs.CV`
+- `cs.RO`
+- `econ.EM`
+- `math.PR`
+- `physics.optics`
+- `q-bio.BM`
+- `quant-ph`
+
+## Model
+
+The production model is based on `distilbert-base-uncased` fine-tuned for multi-class text classification.
+
+Configuration:
+
+- max sequence length: `256`
+- epochs: `3`
+- learning rate: `2e-5`
+
+The model consumes a single formatted text built from the input fields:
+
+```text
+title: <paper title> abstract: <paper abstract>
+```
+
+If the abstract is missing, inference falls back to:
+
+```text
+title: <paper title>
+```
+
+## Dataset
+
+The dataset was collected from the arXiv API and processed into train, validation, and test splits.
+
+Prepared split sizes:
+
+- train: `3120`
+- validation: `391`
+- test: `388`
+
+## Metrics
+
+Evaluation metrics from the bundled model artifact:
+
+- validation accuracy: `0.8696`
+- validation macro-F1: `0.8696`
+- test accuracy: `0.8789`
+- test macro-F1: `0.8769`
+
+## Local Run
+
+Install dependencies:
+
+```bash
+python3 -m pip install -r requirements.txt
+```
+
+Start the app:
+
+```bash
+streamlit run app.py --server.port 8080
+```
+
+## Repository Layout
+
+- `app.py` - Streamlit UI
+- `inference.py` - model loading and inference pipeline
+- `configs/app_config.json` - runtime configuration
+- `artifacts/large_model/best_model/` - trained model weights and tokenizer
+- `artifacts/large_model/metrics.json` - evaluation metrics
+- `data/processed_large/label_mapping.json` - label mapping used by inference
+
+## Deployment
+
+This repository is prepared for Hugging Face Spaces with `sdk: streamlit`. The app runs directly from local artifacts and does not require downloading model weights at runtime.
+
+## Example Use Cases
+
+- quick topic tagging for arXiv drafts
+- sanity-checking paper metadata before submission
+- exploring how transformer classifiers separate neighboring scientific fields
+
+## Notes
+
+- Predictions are limited by the training taxonomy and dataset coverage.
+- The model is intended as a lightweight demo application, not a substitute for expert annotation.
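The top-95% output rule described in the README can be sketched in a few lines. This is a standalone illustration with a hypothetical function name and made-up probabilities, not the app's exact code (the real implementation is `select_top_k_by_probability_mass` in `inference.py`):

```python
def select_by_probability_mass(predictions, threshold=0.95):
    """Take (label, probability) pairs and return the smallest prefix of the
    descending-probability ranking whose cumulative probability reaches threshold."""
    ranked = sorted(predictions, key=lambda item: item[1], reverse=True)
    selected, cumulative = [], 0.0
    for label, probability in ranked:
        selected.append((label, probability))
        cumulative += probability
        if cumulative >= threshold:
            break
    return selected

# The first three labels already cover 0.97 >= 0.95, so the fourth is dropped.
print(select_by_probability_mass(
    [("cs.CV", 0.60), ("cs.RO", 0.25), ("cs.CL", 0.12), ("quant-ph", 0.03)]
))
```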
app.py
ADDED
@@ -0,0 +1,204 @@
+from __future__ import annotations
+
+import json
+from pathlib import Path
+from typing import Any
+
+import streamlit as st
+
+from inference import ArticleClassifier, ClassifierError
+
+
+PROJECT_DIR = Path(__file__).resolve().parent
+CONFIG_PATH = PROJECT_DIR / "configs" / "app_config.json"
+METRICS_PATH = PROJECT_DIR / "artifacts" / "large_model" / "metrics.json"
+DEFAULT_APP_CONFIG = {
+    "model_dir": "artifacts/large_model/best_model",
+    "labels_path": "data/processed_large/label_mapping.json",
+    "max_length": 256,
+    "coverage_threshold": 0.95,
+    "model_name": "distilbert-base-uncased",
+    "page_title": "arXiv Topic Classifier",
+    "page_icon": "📚",
+    "example_title": "Learning-based Visual Navigation for Mobile Robots",
+    "example_abstract": (
+        "We present a transformer-based navigation system that uses camera observations "
+        "and scene understanding to plan robust trajectories for indoor mobile robots."
+    ),
+}
+
+
+def load_app_config() -> dict[str, Any]:
+    if not CONFIG_PATH.exists():
+        return DEFAULT_APP_CONFIG.copy()
+
+    with CONFIG_PATH.open("r", encoding="utf-8") as fh:
+        config = json.load(fh)
+
+    merged_config = DEFAULT_APP_CONFIG.copy()
+    merged_config.update(config)
+    return merged_config
+
+
+APP_CONFIG = load_app_config()
+MODEL_DIR = PROJECT_DIR / str(APP_CONFIG["model_dir"])
+LABELS_PATH = PROJECT_DIR / str(APP_CONFIG["labels_path"])
+MAX_LENGTH = int(APP_CONFIG["max_length"])
+COVERAGE_THRESHOLD = float(APP_CONFIG["coverage_threshold"])
+
+
+st.set_page_config(
+    page_title=str(APP_CONFIG["page_title"]),
+    page_icon=str(APP_CONFIG["page_icon"]),
+    layout="centered",
+)
+
+
+@st.cache_resource
+def load_classifier() -> ArticleClassifier:
+    return ArticleClassifier(
+        model_dir=MODEL_DIR,
+        labels_path=LABELS_PATH,
+        max_length=MAX_LENGTH,
+    )
+
+
+@st.cache_data
+def load_metrics() -> dict | None:
+    if not METRICS_PATH.exists():
+        return None
+    import json
+
+    with METRICS_PATH.open("r", encoding="utf-8") as fh:
+        return json.load(fh)
+
+
+def format_probability(probability: float) -> str:
+    return f"{probability * 100:.2f}%"
+
+
+def format_threshold(threshold: float) -> str:
+    return f"{threshold * 100:.0f}%"
+
+
+def render_prediction_rows(predictions: list[dict[str, float | str]]) -> None:
+    for index, item in enumerate(predictions, start=1):
+        label = str(item["label"])
+        probability = float(item["probability"])
+        st.write(f"{index}. `{label}`")
+        st.progress(min(max(probability, 0.0), 1.0), text=format_probability(probability))
+
+
+def main() -> None:
+    coverage_label = format_threshold(COVERAGE_THRESHOLD)
+
+    st.title(str(APP_CONFIG["page_title"]))
+    st.write(
+        "This demo predicts arXiv paper topics from the title and abstract using a transformer classifier."
+    )
+    st.caption(
+        "For homework evaluation, the app returns the smallest prefix of categories whose cumulative "
+        f"probability reaches {coverage_label}."
+    )
+    st.info(
+        "How to test: paste a real or synthetic paper title, optionally add an abstract, and press "
+        "`Predict categories`. If the abstract is empty, the model will classify from the title only."
+    )
+
+    classifier: ArticleClassifier | None = None
+    classifier_load_error: str | None = None
+
+    with st.sidebar:
+        try:
+            classifier = load_classifier()
+        except Exception as exc:
+            classifier_load_error = f"Model initialization error in load_classifier: {exc}"
+
+        metrics = load_metrics()
+        st.subheader("Evaluation Summary")
+        st.write(f"Model: `{APP_CONFIG['model_name']}`")
+        if classifier is not None:
+            st.write(f"Number of classes: `{len(classifier.labels)}`")
+            st.write("Classes: " + ", ".join(f"`{label}`" for label in classifier.labels))
+        else:
+            st.error(classifier_load_error or "Model initialization error: unknown error")
+        if metrics is not None:
+            validation_accuracy = metrics.get("validation", {}).get("eval_accuracy")
+            validation_f1 = metrics.get("validation", {}).get("eval_macro_f1")
+            test_accuracy = metrics.get("test", {}).get("test_accuracy")
+            test_f1 = metrics.get("test", {}).get("test_macro_f1")
+            if validation_accuracy is not None:
+                st.write(f"Validation accuracy: `{validation_accuracy:.4f}`")
+            if validation_f1 is not None:
+                st.write(f"Validation macro-F1: `{validation_f1:.4f}`")
+            if test_accuracy is not None:
+                st.write(f"Test accuracy: `{test_accuracy:.4f}`")
+            if test_f1 is not None:
+                st.write(f"Test macro-F1: `{test_f1:.4f}`")
+        st.write(
+            "Output rule: return categories until cumulative probability reaches "
+            f"{coverage_label}"
+        )
+
+    with st.expander("Example Input For Quick Check"):
+        st.markdown(
+            f"**Title:** {APP_CONFIG['example_title']}\n\n"
+            f"**Abstract:** {APP_CONFIG['example_abstract']}"
+        )
+
+    with st.form("prediction_form"):
+        title = st.text_input(
+            "Article title",
+            placeholder="Enter the article title",
+        )
+        abstract = st.text_area(
+            "Abstract",
+            placeholder="Enter the abstract (optional, but recommended)",
+            height=220,
+        )
+        predict_button = st.form_submit_button("Predict categories", type="primary")
+
+    if predict_button:
+        if classifier is None:
+            st.error(classifier_load_error or "Model initialization error: classifier is unavailable.")
+            return
+        if not title.strip() and not abstract.strip():
+            st.error("Input validation error in app: please enter at least a title or an abstract.")
+            return
+
+        with st.spinner("Running inference..."):
+            try:
+                full_predictions = classifier.predict(title=title, abstract=abstract)
+                predictions = classifier.select_top_k_by_probability_mass(
+                    full_predictions,
+                    threshold=COVERAGE_THRESHOLD,
+                )
+            except ValueError as exc:
+                st.error(str(exc))
+                return
+            except ClassifierError as exc:
+                st.error(f"Classifier error in prediction flow: {exc}")
+                return
+            except Exception as exc:
+                st.error(f"Unexpected inference error in app.main: {exc}")
+                return
+
+        best_prediction = predictions[0]
+        covered_probability = sum(float(item["probability"]) for item in predictions)
+        col1, col2, col3 = st.columns(3)
+        col1.metric("Top class", str(best_prediction["label"]))
+        col2.metric("Top probability", format_probability(float(best_prediction["probability"])))
+        col3.metric("Top-95% coverage", format_probability(covered_probability))
+
+        st.subheader("Top categories")
+        st.caption(
+            f"These are the categories returned by the assignment top-{coverage_label} rule."
+        )
+        render_prediction_rows(predictions)
+
+        with st.expander("Show Full Ranking"):
+            render_prediction_rows(full_predictions)
+
+
+if __name__ == "__main__":
+    main()
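The defaults-plus-overrides pattern used by `load_app_config` above can be isolated as a small sketch. The `DEFAULTS` dict and file name here are hypothetical stand-ins; the app's real defaults live in `DEFAULT_APP_CONFIG`:

```python
import json
from pathlib import Path

# Hypothetical defaults, mirroring two keys from the app's DEFAULT_APP_CONFIG.
DEFAULTS = {"max_length": 256, "coverage_threshold": 0.95}


def load_config(path: Path) -> dict:
    # Start from the defaults and overlay only the keys present in the file,
    # so a partial or missing config file never drops required settings.
    merged = DEFAULTS.copy()
    if path.exists():
        merged.update(json.loads(path.read_text(encoding="utf-8")))
    return merged


# A missing file falls back to the defaults unchanged.
print(load_config(Path("missing_config.json")))
```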
artifacts/large_model/best_model/config.json
ADDED
@@ -0,0 +1,52 @@
+{
+  "activation": "gelu",
+  "architectures": [
+    "DistilBertForSequenceClassification"
+  ],
+  "attention_dropout": 0.1,
+  "bos_token_id": null,
+  "dim": 768,
+  "dropout": 0.1,
+  "dtype": "float32",
+  "eos_token_id": null,
+  "hidden_dim": 3072,
+  "id2label": {
+    "0": "astro-ph.GA",
+    "1": "cond-mat.mtrl-sci",
+    "2": "cs.CL",
+    "3": "cs.CV",
+    "4": "cs.RO",
+    "5": "econ.EM",
+    "6": "math.PR",
+    "7": "physics.optics",
+    "8": "q-bio.BM",
+    "9": "quant-ph"
+  },
+  "initializer_range": 0.02,
+  "label2id": {
+    "astro-ph.GA": 0,
+    "cond-mat.mtrl-sci": 1,
+    "cs.CL": 2,
+    "cs.CV": 3,
+    "cs.RO": 4,
+    "econ.EM": 5,
+    "math.PR": 6,
+    "physics.optics": 7,
+    "q-bio.BM": 8,
+    "quant-ph": 9
+  },
+  "max_position_embeddings": 512,
+  "model_type": "distilbert",
+  "n_heads": 12,
+  "n_layers": 6,
+  "pad_token_id": 0,
+  "problem_type": "single_label_classification",
+  "qa_dropout": 0.1,
+  "seq_classif_dropout": 0.2,
+  "sinusoidal_pos_embds": false,
+  "tie_weights_": true,
+  "tie_word_embeddings": true,
+  "transformers_version": "5.5.0",
+  "use_cache": false,
+  "vocab_size": 30522
+}
artifacts/large_model/best_model/model.safetensors
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5d67b810b6d752ad27b2d7ec1d3621e75366add609dfe8ef71a32fc3157f0b36
+size 267857176
artifacts/large_model/best_model/tokenizer.json
ADDED
The diff for this file is too large to render. See raw diff.
artifacts/large_model/best_model/tokenizer_config.json
ADDED
@@ -0,0 +1,14 @@
+{
+  "backend": "tokenizers",
+  "cls_token": "[CLS]",
+  "do_lower_case": true,
+  "is_local": false,
+  "mask_token": "[MASK]",
+  "model_max_length": 512,
+  "pad_token": "[PAD]",
+  "sep_token": "[SEP]",
+  "strip_accents": null,
+  "tokenize_chinese_chars": true,
+  "tokenizer_class": "BertTokenizer",
+  "unk_token": "[UNK]"
+}
artifacts/large_model/best_model/training_args.bin
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1e12cd276b62258c10cf25085394da027012384e6c5100025957fde123d3c1fa
+size 5265
artifacts/large_model/metrics.json
ADDED
@@ -0,0 +1,22 @@
+{
+  "validation": {
+    "eval_loss": 0.4496897757053375,
+    "eval_accuracy": 0.8695652173913043,
+    "eval_macro_f1": 0.8695864631399942,
+    "eval_runtime": 7.4472,
+    "eval_samples_per_second": 52.503,
+    "eval_steps_per_second": 3.357,
+    "epoch": 3.0
+  },
+  "test": {
+    "test_loss": 0.4383482336997986,
+    "test_accuracy": 0.8788659793814433,
+    "test_macro_f1": 0.8768819420923114,
+    "test_runtime": 7.7828,
+    "test_samples_per_second": 49.853,
+    "test_steps_per_second": 3.212,
+    "epoch": 3.0
+  },
+  "model_name": "distilbert-base-uncased",
+  "num_classes": 10
+}
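The macro-F1 values stored in `metrics.json` are unweighted means of per-class F1 scores. A minimal pure-Python sketch of the metric, for reference only (the training pipeline presumably computes it with a library such as scikit-learn):

```python
def macro_f1(y_true, y_pred):
    """Unweighted mean of per-class F1 over all classes seen in either list."""
    classes = sorted(set(y_true) | set(y_pred))
    scores = []
    for c in classes:
        tp = sum(t == c and p == c for t, p in zip(y_true, y_pred))
        fp = sum(t != c and p == c for t, p in zip(y_true, y_pred))
        fn = sum(t == c and p != c for t, p in zip(y_true, y_pred))
        denom = 2 * tp + fp + fn
        # F1 = 2*TP / (2*TP + FP + FN); define F1 = 0 for an empty class.
        scores.append(2 * tp / denom if denom else 0.0)
    return sum(scores) / len(scores)
```

Because every class contributes equally, a rare class with poor F1 drags macro-F1 down even when overall accuracy stays high.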
configs/app_config.json
ADDED
@@ -0,0 +1,11 @@
+{
+  "model_dir": "artifacts/large_model/best_model",
+  "labels_path": "data/processed_large/label_mapping.json",
+  "max_length": 256,
+  "coverage_threshold": 0.95,
+  "model_name": "distilbert-base-uncased",
+  "page_title": "arXiv Topic Classifier",
+  "page_icon": "📚",
+  "example_title": "Learning-based Visual Navigation for Mobile Robots",
+  "example_abstract": "We present a transformer-based navigation system that uses camera observations and scene understanding to plan robust trajectories for indoor mobile robots."
+}
data/processed_large/label_mapping.json
ADDED
@@ -0,0 +1,12 @@
+{
+  "astro-ph.GA": 0,
+  "cond-mat.mtrl-sci": 1,
+  "cs.CL": 2,
+  "cs.CV": 3,
+  "cs.RO": 4,
+  "econ.EM": 5,
+  "math.PR": 6,
+  "physics.optics": 7,
+  "q-bio.BM": 8,
+  "quant-ph": 9
+}
inference.py
ADDED
@@ -0,0 +1,156 @@
+from __future__ import annotations
+
+import json
+from pathlib import Path
+
+import torch
+from transformers import AutoModelForSequenceClassification, AutoTokenizer
+
+
+PROJECT_DIR = Path(__file__).resolve().parent
+DEFAULT_MODEL_DIR = PROJECT_DIR / "artifacts" / "large_model" / "best_model"
+DEFAULT_LABELS_PATH = PROJECT_DIR / "data" / "processed_large" / "label_mapping.json"
+
+
+class ClassifierError(RuntimeError):
+    pass
+
+
+class ArticleClassifier:
+    def __init__(
+        self,
+        model_dir: Path = DEFAULT_MODEL_DIR,
+        labels_path: Path = DEFAULT_LABELS_PATH,
+        max_length: int = 256,
+    ) -> None:
+        self.model_dir = Path(model_dir)
+        self.labels_path = Path(labels_path)
+        self.max_length = max_length
+        self.device = torch.device(
+            "mps"
+            if torch.backends.mps.is_available()
+            else "cuda"
+            if torch.cuda.is_available()
+            else "cpu"
+        )
+
+        if not self.labels_path.exists():
+            raise ClassifierError(
+                f"Failed to initialize classifier at labels loading stage: labels file not found at {self.labels_path}"
+            )
+        if not self.model_dir.exists():
+            raise ClassifierError(
+                f"Failed to initialize classifier at model loading stage: model directory not found at {self.model_dir}"
+            )
+
+        try:
+            with self.labels_path.open("r", encoding="utf-8") as fh:
+                self.label2id = json.load(fh)
+        except Exception as exc:
+            raise ClassifierError(
+                f"Failed to initialize classifier at labels loading stage: {exc}"
+            ) from exc
+
+        if not isinstance(self.label2id, dict) or not self.label2id:
+            raise ClassifierError(
+                "Failed to initialize classifier at labels loading stage: label mapping is empty or invalid"
+            )
+        self.id2label = {idx: label for label, idx in self.label2id.items()}
+
+        try:
+            self.tokenizer = AutoTokenizer.from_pretrained(self.model_dir)
+            self.model = AutoModelForSequenceClassification.from_pretrained(self.model_dir)
+            self.model.to(self.device)
+            self.model.eval()
+        except Exception as exc:
+            raise ClassifierError(
+                f"Failed to initialize classifier at model loading stage: {exc}"
+            ) from exc
+
+    @property
+    def labels(self) -> list[str]:
+        return [self.id2label[idx] for idx in sorted(self.id2label)]
+
+    @staticmethod
+    def build_input_text(title: str, abstract: str) -> str:
+        clean_title = " ".join(title.split()).strip()
+        clean_abstract = " ".join(abstract.split()).strip()
+        if clean_abstract:
+            return f"title: {clean_title} abstract: {clean_abstract}"
+        return f"title: {clean_title}"
+
+    def predict(self, title: str, abstract: str = "") -> list[dict[str, float | str]]:
+        if not isinstance(title, str):
+            raise ValueError("Input validation error in predict: title must be a string.")
+        if not isinstance(abstract, str):
+            raise ValueError("Input validation error in predict: abstract must be a string.")
+        if not title.strip() and not abstract.strip():
+            raise ValueError(
+                "Input validation error in predict: please provide at least a title or an abstract."
+            )
+
+        text = self.build_input_text(title=title, abstract=abstract)
+        try:
+            encoded = self.tokenizer(
+                text,
+                return_tensors="pt",
+                truncation=True,
+                max_length=self.max_length,
+            )
+            encoded = {key: value.to(self.device) for key, value in encoded.items()}
+        except Exception as exc:
+            raise ClassifierError(f"Failed during tokenization stage: {exc}") from exc
+
+        try:
+            with torch.inference_mode():
+                logits = self.model(**encoded).logits
+                probabilities = torch.softmax(logits, dim=-1).squeeze(0).detach().cpu()
+        except Exception as exc:
+            raise ClassifierError(f"Failed during model inference stage: {exc}") from exc
+
+        results: list[dict[str, float | str]] = []
+        try:
+            for class_id, probability in enumerate(probabilities.tolist()):
+                results.append(
+                    {
+                        "label": self.id2label[class_id],
+                        "probability": float(probability),
+                    }
+                )
+        except Exception as exc:
+            raise ClassifierError(f"Failed during prediction formatting stage: {exc}") from exc
+
+        results.sort(key=lambda item: item["probability"], reverse=True)
+        return results
+
+    @staticmethod
+    def select_top_95(
+        predictions: list[dict[str, float | str]],
+    ) -> list[dict[str, float | str]]:
+        return ArticleClassifier.select_top_k_by_probability_mass(
+            predictions=predictions,
+            threshold=0.95,
+        )
+
+    @staticmethod
+    def select_top_k_by_probability_mass(
+        predictions: list[dict[str, float | str]],
+        threshold: float = 0.95,
+    ) -> list[dict[str, float | str]]:
+        if not 0 < threshold <= 1:
+            raise ValueError("Probability mass threshold must be in the interval (0, 1].")
+
+        cumulative_probability = 0.0
+        top_predictions: list[dict[str, float | str]] = []
+
+        for item in predictions:
+            top_predictions.append(item)
+            cumulative_probability += float(item["probability"])
+            if cumulative_probability >= threshold:
+                break
+
+        return top_predictions
+
+    def predict_top_95(self, title: str, abstract: str = "") -> list[dict[str, float | str]]:
+        predictions = self.predict(title=title, abstract=abstract)
+        return self.select_top_95(predictions)
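The input formatting contract from `inference.py` can be exercised in isolation; this standalone copy of `build_input_text` mirrors the whitespace normalization in the class above:

```python
def build_input_text(title: str, abstract: str) -> str:
    # Collapse runs of whitespace (including newlines) to single spaces,
    # then join the fields in the "title: ... abstract: ..." format the
    # model was trained on; title-only input drops the abstract segment.
    clean_title = " ".join(title.split())
    clean_abstract = " ".join(abstract.split())
    if clean_abstract:
        return f"title: {clean_title} abstract: {clean_abstract}"
    return f"title: {clean_title}"


print(build_input_text("Deep  Learning\nfor Robots", ""))
# → title: Deep Learning for Robots
```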
requirements.txt
CHANGED
@@ -1,3 +1,5 @@
-
-
-
+numpy>=1.26
+torch>=2.2,<3.0
+transformers>=4.41
+streamlit>=1.33,<2.0
+safetensors>=0.4