Pyotr Lisov committed
Commit 70b2ea0 · 1 Parent(s): ae1c8b3

Add article classifier app
.streamlit/config.toml ADDED
@@ -0,0 +1,7 @@
+ [theme]
+ base = "light"
+ primaryColor = "#1D4ED8"
+ backgroundColor = "#F7FAFC"
+ secondaryBackgroundColor = "#E8F1F8"
+ textColor = "#102A43"
+ font = "sans serif"
README.md CHANGED
@@ -1,20 +1,121 @@
  ---
- title: Article Classifier
- emoji: 🚀
- colorFrom: red
- colorTo: red
- sdk: docker
- app_port: 8501
- tags:
- - streamlit
  pinned: false
- short_description: A small bert-based model for classifying articles from Arxiv
  license: mit
  ---

- # Welcome to Streamlit!

- Edit `/src/streamlit_app.py` to customize this app to your heart's desire. :heart:

- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
- forums](https://discuss.streamlit.io).

  ---
+ title: arXiv Topic Classifier
+ emoji: 📚
+ colorFrom: blue
+ colorTo: green
+ sdk: streamlit
+ sdk_version: 1.33.0
+ app_file: app.py
  pinned: false
  license: mit
+ short_description: Transformer-powered topic classification for arXiv papers
  ---

+ # arXiv Topic Classifier

+ `arXiv Topic Classifier` is a Streamlit app for classifying research papers into arXiv-style topic categories from the paper title and abstract. The interface accepts the two fields separately, supports title-only inference, and returns the smallest prefix of labels whose cumulative probability reaches at least 95%.

+ The project is designed as a lightweight end-to-end ML application: collect data, fine-tune a transformer classifier, package the trained model with local inference code, and expose the result through a public web interface.
+
+ ## Features
+
+ - topic prediction from `title` and `abstract`
+ - inference from `title` only when the abstract is missing
+ - top-95% cumulative probability output
+ - full ranked list of class probabilities
+ - cached model loading for faster repeated requests
+ - self-contained deployment with local model weights
+
+ ## Categories
+
+ The current model predicts 10 categories:
+
+ - `astro-ph.GA`
+ - `cond-mat.mtrl-sci`
+ - `cs.CL`
+ - `cs.CV`
+ - `cs.RO`
+ - `econ.EM`
+ - `math.PR`
+ - `physics.optics`
+ - `q-bio.BM`
+ - `quant-ph`
+
+ ## Model
+
+ The production model is based on `distilbert-base-uncased`, fine-tuned for multi-class text classification.
+
+ Configuration:
+
+ - max sequence length: `256`
+ - epochs: `3`
+ - learning rate: `2e-5`
+
+ The model consumes a single formatted text built from the input fields:
+
+ ```text
+ title: <paper title> abstract: <paper abstract>
+ ```
+
+ If the abstract is missing, inference falls back to:
+
+ ```text
+ title: <paper title>
+ ```
+
+ ## Dataset
+
+ The dataset was collected from the arXiv API and processed into train, validation, and test splits.
+
+ Prepared split sizes:
+
+ - train: `3120`
+ - validation: `391`
+ - test: `388`
+
+ ## Metrics
+
+ Evaluation metrics from the bundled model artifact:
+
+ - validation accuracy: `0.8696`
+ - validation macro-F1: `0.8696`
+ - test accuracy: `0.8789`
+ - test macro-F1: `0.8769`
+
+ ## Local Run
+
+ Install dependencies:
+
+ ```bash
+ python3 -m pip install -r requirements.txt
+ ```
+
+ Start the app:
+
+ ```bash
+ streamlit run app.py --server.port 8080
+ ```
+
+ ## Repository Layout
+
+ - `app.py` - Streamlit UI
+ - `inference.py` - model loading and inference pipeline
+ - `configs/app_config.json` - runtime configuration
+ - `artifacts/large_model/best_model/` - trained model weights and tokenizer
+ - `artifacts/large_model/metrics.json` - evaluation metrics
+ - `data/processed_large/label_mapping.json` - label mapping used by inference
+
+ ## Deployment
+
+ This repository is prepared for Hugging Face Spaces with `sdk: streamlit`. The app runs directly from local artifacts and does not require downloading model weights at runtime.
+
+ ## Example Use Cases
+
+ - quick topic tagging for arXiv drafts
+ - sanity-checking paper metadata before submission
+ - exploring how transformer classifiers separate neighboring scientific fields
+
+ ## Notes
+
+ - Predictions are limited by the training taxonomy and dataset coverage.
+ - The model is intended as a lightweight demo application, not a substitute for expert annotation.
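To make the README's selection rule concrete, here is a minimal, model-free sketch of the top-95% prefix rule; the function name `top_mass_prefix` and the toy probabilities are illustrative, not part of the repository:

```python
def top_mass_prefix(predictions, threshold=0.95):
    """Smallest prefix of (label, probability) pairs, ranked by descending
    probability, whose cumulative probability reaches the threshold."""
    ranked = sorted(predictions, key=lambda item: item[1], reverse=True)
    prefix, cumulative = [], 0.0
    for label, probability in ranked:
        prefix.append((label, probability))
        cumulative += probability
        if cumulative >= threshold:
            break
    return prefix


# Toy distribution over four of the ten categories:
scores = [("cs.RO", 0.62), ("cs.CV", 0.30), ("quant-ph", 0.03), ("cs.CL", 0.05)]
print(top_mass_prefix(scores))
# [('cs.RO', 0.62), ('cs.CV', 0.30), ('cs.CL', 0.05)] -- cumulative 0.97 >= 0.95
```

Note that the first two labels cover only 0.92 of the mass, so a third label is included before the loop stops.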
app.py ADDED
@@ -0,0 +1,204 @@
+ from __future__ import annotations
+
+ import json
+ from pathlib import Path
+ from typing import Any
+
+ import streamlit as st
+
+ from inference import ArticleClassifier, ClassifierError
+
+
+ PROJECT_DIR = Path(__file__).resolve().parent
+ CONFIG_PATH = PROJECT_DIR / "configs" / "app_config.json"
+ METRICS_PATH = PROJECT_DIR / "artifacts" / "large_model" / "metrics.json"
+ DEFAULT_APP_CONFIG = {
+     "model_dir": "artifacts/large_model/best_model",
+     "labels_path": "data/processed_large/label_mapping.json",
+     "max_length": 256,
+     "coverage_threshold": 0.95,
+     "model_name": "distilbert-base-uncased",
+     "page_title": "arXiv Topic Classifier",
+     "page_icon": "📚",
+     "example_title": "Learning-based Visual Navigation for Mobile Robots",
+     "example_abstract": (
+         "We present a transformer-based navigation system that uses camera observations "
+         "and scene understanding to plan robust trajectories for indoor mobile robots."
+     ),
+ }
+
+
+ def load_app_config() -> dict[str, Any]:
+     if not CONFIG_PATH.exists():
+         return DEFAULT_APP_CONFIG.copy()
+
+     with CONFIG_PATH.open("r", encoding="utf-8") as fh:
+         config = json.load(fh)
+
+     merged_config = DEFAULT_APP_CONFIG.copy()
+     merged_config.update(config)
+     return merged_config
+
+
+ APP_CONFIG = load_app_config()
+ MODEL_DIR = PROJECT_DIR / str(APP_CONFIG["model_dir"])
+ LABELS_PATH = PROJECT_DIR / str(APP_CONFIG["labels_path"])
+ MAX_LENGTH = int(APP_CONFIG["max_length"])
+ COVERAGE_THRESHOLD = float(APP_CONFIG["coverage_threshold"])
+
+
+ st.set_page_config(
+     page_title=str(APP_CONFIG["page_title"]),
+     page_icon=str(APP_CONFIG["page_icon"]),
+     layout="centered",
+ )
+
+
+ @st.cache_resource
+ def load_classifier() -> ArticleClassifier:
+     return ArticleClassifier(
+         model_dir=MODEL_DIR,
+         labels_path=LABELS_PATH,
+         max_length=MAX_LENGTH,
+     )
+
+
+ @st.cache_data
+ def load_metrics() -> dict | None:
+     if not METRICS_PATH.exists():
+         return None
+     with METRICS_PATH.open("r", encoding="utf-8") as fh:
+         return json.load(fh)
+
+
+ def format_probability(probability: float) -> str:
+     return f"{probability * 100:.2f}%"
+
+
+ def format_threshold(threshold: float) -> str:
+     return f"{threshold * 100:.0f}%"
+
+
+ def render_prediction_rows(predictions: list[dict[str, float | str]]) -> None:
+     for index, item in enumerate(predictions, start=1):
+         label = str(item["label"])
+         probability = float(item["probability"])
+         st.write(f"{index}. `{label}`")
+         st.progress(min(max(probability, 0.0), 1.0), text=format_probability(probability))
+
+
+ def main() -> None:
+     coverage_label = format_threshold(COVERAGE_THRESHOLD)
+
+     st.title(str(APP_CONFIG["page_title"]))
+     st.write(
+         "This demo predicts arXiv paper topics from the title and abstract using a transformer classifier."
+     )
+     st.caption(
+         "For homework evaluation, the app returns the smallest prefix of categories whose cumulative "
+         f"probability reaches {coverage_label}."
+     )
+     st.info(
+         "How to test: paste a real or synthetic paper title, optionally add an abstract, and press "
+         "`Predict categories`. If the abstract is empty, the model will classify from the title only."
+     )
+
+     classifier: ArticleClassifier | None = None
+     classifier_load_error: str | None = None
+
+     with st.sidebar:
+         try:
+             classifier = load_classifier()
+         except Exception as exc:
+             classifier_load_error = f"Model initialization error in load_classifier: {exc}"
+
+         metrics = load_metrics()
+         st.subheader("Evaluation Summary")
+         st.write(f"Model: `{APP_CONFIG['model_name']}`")
+         if classifier is not None:
+             st.write(f"Number of classes: `{len(classifier.labels)}`")
+             st.write("Classes: " + ", ".join(f"`{label}`" for label in classifier.labels))
+         else:
+             st.error(classifier_load_error or "Model initialization error: unknown error")
+         if metrics is not None:
+             validation_accuracy = metrics.get("validation", {}).get("eval_accuracy")
+             validation_f1 = metrics.get("validation", {}).get("eval_macro_f1")
+             test_accuracy = metrics.get("test", {}).get("test_accuracy")
+             test_f1 = metrics.get("test", {}).get("test_macro_f1")
+             if validation_accuracy is not None:
+                 st.write(f"Validation accuracy: `{validation_accuracy:.4f}`")
+             if validation_f1 is not None:
+                 st.write(f"Validation macro-F1: `{validation_f1:.4f}`")
+             if test_accuracy is not None:
+                 st.write(f"Test accuracy: `{test_accuracy:.4f}`")
+             if test_f1 is not None:
+                 st.write(f"Test macro-F1: `{test_f1:.4f}`")
+         st.write(
+             "Output rule: return categories until cumulative probability reaches "
+             f"{coverage_label}"
+         )
+
+     with st.expander("Example Input For Quick Check"):
+         st.markdown(
+             f"**Title:** {APP_CONFIG['example_title']}\n\n"
+             f"**Abstract:** {APP_CONFIG['example_abstract']}"
+         )
+
+     with st.form("prediction_form"):
+         title = st.text_input(
+             "Article title",
+             placeholder="Enter the article title",
+         )
+         abstract = st.text_area(
+             "Abstract",
+             placeholder="Enter the abstract (optional, but recommended)",
+             height=220,
+         )
+         predict_button = st.form_submit_button("Predict categories", type="primary")
+
+     if predict_button:
+         if classifier is None:
+             st.error(classifier_load_error or "Model initialization error: classifier is unavailable.")
+             return
+         if not title.strip() and not abstract.strip():
+             st.error("Input validation error in app: please enter at least a title or an abstract.")
+             return
+
+         with st.spinner("Running inference..."):
+             try:
+                 full_predictions = classifier.predict(title=title, abstract=abstract)
+                 predictions = classifier.select_top_k_by_probability_mass(
+                     full_predictions,
+                     threshold=COVERAGE_THRESHOLD,
+                 )
+             except ValueError as exc:
+                 st.error(str(exc))
+                 return
+             except ClassifierError as exc:
+                 st.error(f"Classifier error in prediction flow: {exc}")
+                 return
+             except Exception as exc:
+                 st.error(f"Unexpected inference error in app.main: {exc}")
+                 return
+
+         best_prediction = predictions[0]
+         covered_probability = sum(float(item["probability"]) for item in predictions)
+         col1, col2, col3 = st.columns(3)
+         col1.metric("Top class", str(best_prediction["label"]))
+         col2.metric("Top probability", format_probability(float(best_prediction["probability"])))
+         col3.metric(f"Top-{coverage_label} coverage", format_probability(covered_probability))
+
+         st.subheader("Top categories")
+         st.caption(
+             f"These are the categories returned by the assignment top-{coverage_label} rule."
+         )
+         render_prediction_rows(predictions)
+
+         with st.expander("Show Full Ranking"):
+             render_prediction_rows(full_predictions)
+
+
+ if __name__ == "__main__":
+     main()
artifacts/large_model/best_model/config.json ADDED
@@ -0,0 +1,52 @@
+ {
+   "activation": "gelu",
+   "architectures": [
+     "DistilBertForSequenceClassification"
+   ],
+   "attention_dropout": 0.1,
+   "bos_token_id": null,
+   "dim": 768,
+   "dropout": 0.1,
+   "dtype": "float32",
+   "eos_token_id": null,
+   "hidden_dim": 3072,
+   "id2label": {
+     "0": "astro-ph.GA",
+     "1": "cond-mat.mtrl-sci",
+     "2": "cs.CL",
+     "3": "cs.CV",
+     "4": "cs.RO",
+     "5": "econ.EM",
+     "6": "math.PR",
+     "7": "physics.optics",
+     "8": "q-bio.BM",
+     "9": "quant-ph"
+   },
+   "initializer_range": 0.02,
+   "label2id": {
+     "astro-ph.GA": 0,
+     "cond-mat.mtrl-sci": 1,
+     "cs.CL": 2,
+     "cs.CV": 3,
+     "cs.RO": 4,
+     "econ.EM": 5,
+     "math.PR": 6,
+     "physics.optics": 7,
+     "q-bio.BM": 8,
+     "quant-ph": 9
+   },
+   "max_position_embeddings": 512,
+   "model_type": "distilbert",
+   "n_heads": 12,
+   "n_layers": 6,
+   "pad_token_id": 0,
+   "problem_type": "single_label_classification",
+   "qa_dropout": 0.1,
+   "seq_classif_dropout": 0.2,
+   "sinusoidal_pos_embds": false,
+   "tie_weights_": true,
+   "tie_word_embeddings": true,
+   "transformers_version": "5.5.0",
+   "use_cache": false,
+   "vocab_size": 30522
+ }
artifacts/large_model/best_model/model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:5d67b810b6d752ad27b2d7ec1d3621e75366add609dfe8ef71a32fc3157f0b36
+ size 267857176
artifacts/large_model/best_model/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
artifacts/large_model/best_model/tokenizer_config.json ADDED
@@ -0,0 +1,14 @@
+ {
+   "backend": "tokenizers",
+   "cls_token": "[CLS]",
+   "do_lower_case": true,
+   "is_local": false,
+   "mask_token": "[MASK]",
+   "model_max_length": 512,
+   "pad_token": "[PAD]",
+   "sep_token": "[SEP]",
+   "strip_accents": null,
+   "tokenize_chinese_chars": true,
+   "tokenizer_class": "BertTokenizer",
+   "unk_token": "[UNK]"
+ }
artifacts/large_model/best_model/training_args.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:1e12cd276b62258c10cf25085394da027012384e6c5100025957fde123d3c1fa
+ size 5265
artifacts/large_model/metrics.json ADDED
@@ -0,0 +1,22 @@
+ {
+   "validation": {
+     "eval_loss": 0.4496897757053375,
+     "eval_accuracy": 0.8695652173913043,
+     "eval_macro_f1": 0.8695864631399942,
+     "eval_runtime": 7.4472,
+     "eval_samples_per_second": 52.503,
+     "eval_steps_per_second": 3.357,
+     "epoch": 3.0
+   },
+   "test": {
+     "test_loss": 0.4383482336997986,
+     "test_accuracy": 0.8788659793814433,
+     "test_macro_f1": 0.8768819420923114,
+     "test_runtime": 7.7828,
+     "test_samples_per_second": 49.853,
+     "test_steps_per_second": 3.212,
+     "epoch": 3.0
+   },
+   "model_name": "distilbert-base-uncased",
+   "num_classes": 10
+ }
configs/app_config.json ADDED
@@ -0,0 +1,11 @@
+ {
+   "model_dir": "artifacts/large_model/best_model",
+   "labels_path": "data/processed_large/label_mapping.json",
+   "max_length": 256,
+   "coverage_threshold": 0.95,
+   "model_name": "distilbert-base-uncased",
+   "page_title": "arXiv Topic Classifier",
+   "page_icon": "📚",
+   "example_title": "Learning-based Visual Navigation for Mobile Robots",
+   "example_abstract": "We present a transformer-based navigation system that uses camera observations and scene understanding to plan robust trajectories for indoor mobile robots."
+ }
data/processed_large/label_mapping.json ADDED
@@ -0,0 +1,12 @@
+ {
+   "astro-ph.GA": 0,
+   "cond-mat.mtrl-sci": 1,
+   "cs.CL": 2,
+   "cs.CV": 3,
+   "cs.RO": 4,
+   "econ.EM": 5,
+   "math.PR": 6,
+   "physics.optics": 7,
+   "q-bio.BM": 8,
+   "quant-ph": 9
+ }
inference.py ADDED
@@ -0,0 +1,156 @@
+ from __future__ import annotations
+
+ import json
+ from pathlib import Path
+
+ import torch
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer
+
+
+ PROJECT_DIR = Path(__file__).resolve().parent
+ DEFAULT_MODEL_DIR = PROJECT_DIR / "artifacts" / "large_model" / "best_model"
+ DEFAULT_LABELS_PATH = PROJECT_DIR / "data" / "processed_large" / "label_mapping.json"
+
+
+ class ClassifierError(RuntimeError):
+     pass
+
+
+ class ArticleClassifier:
+     def __init__(
+         self,
+         model_dir: Path = DEFAULT_MODEL_DIR,
+         labels_path: Path = DEFAULT_LABELS_PATH,
+         max_length: int = 256,
+     ) -> None:
+         self.model_dir = Path(model_dir)
+         self.labels_path = Path(labels_path)
+         self.max_length = max_length
+         self.device = torch.device(
+             "mps"
+             if torch.backends.mps.is_available()
+             else "cuda"
+             if torch.cuda.is_available()
+             else "cpu"
+         )
+
+         if not self.labels_path.exists():
+             raise ClassifierError(
+                 f"Failed to initialize classifier at labels loading stage: labels file not found at {self.labels_path}"
+             )
+         if not self.model_dir.exists():
+             raise ClassifierError(
+                 f"Failed to initialize classifier at model loading stage: model directory not found at {self.model_dir}"
+             )
+
+         try:
+             with self.labels_path.open("r", encoding="utf-8") as fh:
+                 self.label2id = json.load(fh)
+         except Exception as exc:
+             raise ClassifierError(
+                 f"Failed to initialize classifier at labels loading stage: {exc}"
+             ) from exc
+
+         if not isinstance(self.label2id, dict) or not self.label2id:
+             raise ClassifierError(
+                 "Failed to initialize classifier at labels loading stage: label mapping is empty or invalid"
+             )
+         self.id2label = {idx: label for label, idx in self.label2id.items()}
+
+         try:
+             self.tokenizer = AutoTokenizer.from_pretrained(self.model_dir)
+             self.model = AutoModelForSequenceClassification.from_pretrained(self.model_dir)
+             self.model.to(self.device)
+             self.model.eval()
+         except Exception as exc:
+             raise ClassifierError(
+                 f"Failed to initialize classifier at model loading stage: {exc}"
+             ) from exc
+
+     @property
+     def labels(self) -> list[str]:
+         return [self.id2label[idx] for idx in sorted(self.id2label)]
+
+     @staticmethod
+     def build_input_text(title: str, abstract: str) -> str:
+         clean_title = " ".join(title.split()).strip()
+         clean_abstract = " ".join(abstract.split()).strip()
+         if clean_abstract:
+             return f"title: {clean_title} abstract: {clean_abstract}"
+         return f"title: {clean_title}"
+
+     def predict(self, title: str, abstract: str = "") -> list[dict[str, float | str]]:
+         if not isinstance(title, str):
+             raise ValueError("Input validation error in predict: title must be a string.")
+         if not isinstance(abstract, str):
+             raise ValueError("Input validation error in predict: abstract must be a string.")
+         if not title.strip() and not abstract.strip():
+             raise ValueError(
+                 "Input validation error in predict: please provide at least a title or an abstract."
+             )
+
+         text = self.build_input_text(title=title, abstract=abstract)
+         try:
+             encoded = self.tokenizer(
+                 text,
+                 return_tensors="pt",
+                 truncation=True,
+                 max_length=self.max_length,
+             )
+             encoded = {key: value.to(self.device) for key, value in encoded.items()}
+         except Exception as exc:
+             raise ClassifierError(f"Failed during tokenization stage: {exc}") from exc
+
+         try:
+             with torch.inference_mode():
+                 logits = self.model(**encoded).logits
+             probabilities = torch.softmax(logits, dim=-1).squeeze(0).detach().cpu()
+         except Exception as exc:
+             raise ClassifierError(f"Failed during model inference stage: {exc}") from exc
+
+         results: list[dict[str, float | str]] = []
+         try:
+             for class_id, probability in enumerate(probabilities.tolist()):
+                 results.append(
+                     {
+                         "label": self.id2label[class_id],
+                         "probability": float(probability),
+                     }
+                 )
+         except Exception as exc:
+             raise ClassifierError(f"Failed during prediction formatting stage: {exc}") from exc
+
+         results.sort(key=lambda item: item["probability"], reverse=True)
+         return results
+
+     @staticmethod
+     def select_top_95(
+         predictions: list[dict[str, float | str]],
+     ) -> list[dict[str, float | str]]:
+         return ArticleClassifier.select_top_k_by_probability_mass(
+             predictions=predictions,
+             threshold=0.95,
+         )
+
+     @staticmethod
+     def select_top_k_by_probability_mass(
+         predictions: list[dict[str, float | str]],
+         threshold: float = 0.95,
+     ) -> list[dict[str, float | str]]:
+         if not 0 < threshold <= 1:
+             raise ValueError("Probability mass threshold must be in the interval (0, 1].")
+
+         cumulative_probability = 0.0
+         top_predictions: list[dict[str, float | str]] = []
+
+         for item in predictions:
+             top_predictions.append(item)
+             cumulative_probability += float(item["probability"])
+             if cumulative_probability >= threshold:
+                 break
+
+         return top_predictions
+
+     def predict_top_95(self, title: str, abstract: str = "") -> list[dict[str, float | str]]:
+         predictions = self.predict(title=title, abstract=abstract)
+         return self.select_top_95(predictions)
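The whitespace-normalization contract used when building model input can be checked without loading any weights. This standalone sketch mirrors the `build_input_text` logic from `inference.py`:

```python
def build_input_text(title: str, abstract: str = "") -> str:
    # Collapse runs of whitespace and omit the abstract clause when the
    # abstract is empty, mirroring ArticleClassifier.build_input_text.
    clean_title = " ".join(title.split())
    clean_abstract = " ".join(abstract.split())
    if clean_abstract:
        return f"title: {clean_title} abstract: {clean_abstract}"
    return f"title: {clean_title}"


print(build_input_text("  Two  Spaces  "))                # title: Two Spaces
print(build_input_text("A Title", "An\nabstract  here"))  # title: A Title abstract: An abstract here
```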
requirements.txt CHANGED
@@ -1,3 +1,5 @@
- altair
- pandas
- streamlit
+ numpy>=1.26
+ torch>=2.2,<3.0
+ transformers>=4.41
+ streamlit>=1.33,<2.0
+ safetensors>=0.4