# Pyotr Lisov
# Add article classifier app
# 70b2ea0
from __future__ import annotations
import json
from pathlib import Path
from typing import Any
import streamlit as st
from inference import ArticleClassifier, ClassifierError
# Filesystem anchors: every path below is built from the directory containing this file.
PROJECT_DIR = Path(__file__).resolve().parent
CONFIG_PATH = PROJECT_DIR.joinpath("configs", "app_config.json")
METRICS_PATH = PROJECT_DIR.joinpath("artifacts", "large_model", "metrics.json")

# Built-in application settings. The two path entries are relative to PROJECT_DIR.
DEFAULT_APP_CONFIG = {
    # Model artifacts and inference settings.
    "model_dir": "artifacts/large_model/best_model",
    "labels_path": "data/processed_large/label_mapping.json",
    "max_length": 256,
    "coverage_threshold": 0.95,
    "model_name": "distilbert-base-uncased",
    # Page presentation.
    "page_title": "arXiv Topic Classifier",
    "page_icon": "📚",
    # Sample input shown to the user for a quick manual check.
    "example_title": "Learning-based Visual Navigation for Mobile Robots",
    "example_abstract": (
        "We present a transformer-based navigation system that uses camera observations "
        "and scene understanding to plan robust trajectories for indoor mobile robots."
    ),
}
def load_app_config() -> dict[str, Any]:
    """Return the app settings: built-in defaults overlaid with the JSON config file.

    Reads ``CONFIG_PATH`` as UTF-8 JSON and applies its top-level keys on top of a
    shallow copy of ``DEFAULT_APP_CONFIG``. When the file is absent the defaults
    are returned unchanged.

    Returns:
        A new dict; ``DEFAULT_APP_CONFIG`` itself is never mutated.

    Raises:
        json.JSONDecodeError: If the file exists but is not valid JSON.
        TypeError: If the top-level JSON value is not an object.
    """
    merged_config: dict[str, Any] = DEFAULT_APP_CONFIG.copy()

    # Open directly instead of checking exists() first: the file can disappear
    # between the check and the open. A missing file or a missing/non-directory
    # parent both mean "no overrides".
    try:
        fh = CONFIG_PATH.open("r", encoding="utf-8")
    except (FileNotFoundError, NotADirectoryError):
        return merged_config

    with fh:
        config = json.load(fh)

    # dict.update() would fail obscurely on null/numbers and would silently accept
    # a list of pairs, so reject anything that is not a JSON object up front.
    if not isinstance(config, dict):
        raise TypeError(
            f"App config at {CONFIG_PATH} must contain a JSON object, "
            f"got {type(config).__name__}."
        )

    merged_config.update(config)
    return merged_config
# Effective settings for this process, resolved once at import time.
APP_CONFIG = load_app_config()

# Typed views of the settings used by the rest of the module.
MODEL_DIR = PROJECT_DIR.joinpath(str(APP_CONFIG["model_dir"]))
LABELS_PATH = PROJECT_DIR.joinpath(str(APP_CONFIG["labels_path"]))
MAX_LENGTH = int(APP_CONFIG["max_length"])
COVERAGE_THRESHOLD = float(APP_CONFIG["coverage_threshold"])

# Page setup runs at import time, ahead of every other Streamlit call in this file.
st.set_page_config(
    layout="centered",
    page_icon=str(APP_CONFIG["page_icon"]),
    page_title=str(APP_CONFIG["page_title"]),
)
@st.cache_resource
def load_classifier() -> ArticleClassifier:
    """Build the ``ArticleClassifier`` from the configured model directory and label file.

    Wrapped in ``st.cache_resource``; any exception raised by the constructor
    propagates to the caller.
    """
    classifier = ArticleClassifier(
        max_length=MAX_LENGTH,
        labels_path=LABELS_PATH,
        model_dir=MODEL_DIR,
    )
    return classifier
@st.cache_data
def load_metrics() -> dict | None:
    """Return the parsed evaluation metrics, or ``None`` when no metrics file exists.

    Reads ``METRICS_PATH`` as UTF-8 JSON. The function is wrapped in
    ``st.cache_data``.

    Returns:
        The decoded JSON content, or ``None`` if the file is missing.

    Raises:
        json.JSONDecodeError: If the file exists but is not valid JSON.
    """
    # Uses the module-level ``json`` import; opening directly (rather than
    # checking exists() first) avoids a race if the file vanishes in between.
    try:
        fh = METRICS_PATH.open("r", encoding="utf-8")
    except (FileNotFoundError, NotADirectoryError):
        return None

    with fh:
        return json.load(fh)
def format_probability(probability: float) -> str:
    """Format a fractional probability as a percentage string with two decimals."""
    percent = probability * 100
    return "{:.2f}%".format(percent)
def format_threshold(threshold: float) -> str:
    """Format a fractional threshold as a whole-number percentage string."""
    percent = threshold * 100
    return "{:.0f}%".format(percent)
def render_prediction_rows(predictions: list[dict[str, float | str]]) -> None:
    """Draw one numbered label line and one progress bar per prediction.

    Each entry must provide ``"label"`` and ``"probability"``. The bar value is
    clamped to the range 0.0-1.0, while the bar text shows the unclamped
    probability as a percentage.
    """
    for rank, entry in enumerate(predictions, start=1):
        name = str(entry["label"])
        score = float(entry["probability"])
        st.write(f"{rank}. `{name}`")
        # Keep the bar value inside the 0.0-1.0 range.
        bounded = min(max(score, 0.0), 1.0)
        st.progress(bounded, text=format_probability(score))
def _render_sidebar_summary(
    classifier: ArticleClassifier | None,
    classifier_load_error: str | None,
    coverage_label: str,
) -> None:
    """Write the evaluation summary into the currently active Streamlit container."""
    metrics = load_metrics()
    st.subheader("Evaluation Summary")
    st.write(f"Model: `{APP_CONFIG['model_name']}`")

    if classifier is not None:
        st.write(f"Number of classes: `{len(classifier.labels)}`")
        st.write("Classes: " + ", ".join(f"`{label}`" for label in classifier.labels))
    else:
        st.error(classifier_load_error or "Model initialization error: unknown error")

    if metrics is not None:
        # (section in the metrics file, key inside that section, caption shown to the user)
        metric_rows = (
            ("validation", "eval_accuracy", "Validation accuracy"),
            ("validation", "eval_macro_f1", "Validation macro-F1"),
            ("test", "test_accuracy", "Test accuracy"),
            ("test", "test_macro_f1", "Test macro-F1"),
        )
        for section, key, caption in metric_rows:
            value = metrics.get(section, {}).get(key)
            if value is not None:
                st.write(f"{caption}: `{value:.4f}`")

    st.write(
        "Output rule: return categories until cumulative probability reaches "
        f"{coverage_label}"
    )


def _render_prediction_results(
    predictions: list[dict[str, float | str]],
    full_predictions: list[dict[str, float | str]],
    coverage_label: str,
) -> None:
    """Show headline metrics, the selected categories and the expandable full ranking.

    ``predictions`` must be non-empty; its first entry is reported as the top class.
    """
    best_prediction = predictions[0]
    covered_probability = sum(float(item["probability"]) for item in predictions)

    col1, col2, col3 = st.columns(3)
    col1.metric("Top class", str(best_prediction["label"]))
    col2.metric("Top probability", format_probability(float(best_prediction["probability"])))
    # Derive the label from the configured threshold instead of hard-coding 95%,
    # so it stays consistent with the caption below.
    col3.metric(f"Top-{coverage_label} coverage", format_probability(covered_probability))

    st.subheader("Top categories")
    st.caption(
        f"These are the categories returned by the assignment top-{coverage_label} rule."
    )
    render_prediction_rows(predictions)

    with st.expander("Show Full Ranking"):
        render_prediction_rows(full_predictions)


def main() -> None:
    """Render the classifier page: intro, sidebar summary, input form and results.

    The classifier is loaded inside the sidebar context; a load failure is
    reported there and again when the user submits the form. On submit, the
    input is validated, inference runs, and the categories selected by
    ``select_top_k_by_probability_mass`` at ``COVERAGE_THRESHOLD`` are shown,
    followed by the full ranking. Errors are reported with ``st.error`` rather
    than raised.
    """
    coverage_label = format_threshold(COVERAGE_THRESHOLD)

    st.title(str(APP_CONFIG["page_title"]))
    st.write(
        "This demo predicts arXiv paper topics from the title and abstract using a transformer classifier."
    )
    st.caption(
        "For homework evaluation, the app returns the smallest prefix of categories whose cumulative "
        f"probability reaches {coverage_label}."
    )
    st.info(
        "How to test: paste a real or synthetic paper title, optionally add an abstract, and press "
        "`Predict categories`. If the abstract is empty, the model will classify from the title only."
    )

    classifier: ArticleClassifier | None = None
    classifier_load_error: str | None = None

    with st.sidebar:
        # UI boundary: any failure while loading the model is shown to the user
        # instead of aborting the page.
        try:
            classifier = load_classifier()
        except Exception as exc:
            classifier_load_error = f"Model initialization error in load_classifier: {exc}"
        _render_sidebar_summary(classifier, classifier_load_error, coverage_label)

    with st.expander("Example Input For Quick Check"):
        st.markdown(
            f"**Title:** {APP_CONFIG['example_title']}\n\n"
            f"**Abstract:** {APP_CONFIG['example_abstract']}"
        )

    with st.form("prediction_form"):
        title = st.text_input(
            "Article title",
            placeholder="Enter the article title",
        )
        abstract = st.text_area(
            "Abstract",
            placeholder="Enter the abstract (optional, but recommended)",
            height=220,
        )
        predict_button = st.form_submit_button("Predict categories", type="primary")

    if not predict_button:
        return

    if classifier is None:
        st.error(classifier_load_error or "Model initialization error: classifier is unavailable.")
        return

    if not title.strip() and not abstract.strip():
        st.error("Input validation error in app: please enter at least a title or an abstract.")
        return

    with st.spinner("Running inference..."):
        # Handler order is kept as-is: ValueError is reported verbatim before the
        # classifier-specific and catch-all handlers are considered.
        try:
            full_predictions = classifier.predict(title=title, abstract=abstract)
            predictions = classifier.select_top_k_by_probability_mass(
                full_predictions,
                threshold=COVERAGE_THRESHOLD,
            )
        except ValueError as exc:
            st.error(str(exc))
            return
        except ClassifierError as exc:
            st.error(f"Classifier error in prediction flow: {exc}")
            return
        except Exception as exc:
            st.error(f"Unexpected inference error in app.main: {exc}")
            return

    # Guard against an empty selection so the top-class lookup cannot raise IndexError.
    if not predictions:
        st.warning("The classifier returned no categories for this input.")
        return

    _render_prediction_results(predictions, full_predictions, coverage_label)
# Run the page only when this file is executed as the main script, not when it is imported.
if __name__ == "__main__":
    main()