Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import json | |
| from pathlib import Path | |
| from typing import Any | |
| import streamlit as st | |
| from inference import ArticleClassifier, ClassifierError | |
# Resolve all paths relative to the directory containing this file so the app
# works regardless of the current working directory.
PROJECT_DIR = Path(__file__).resolve().parent
# Optional JSON file whose keys override DEFAULT_APP_CONFIG values.
CONFIG_PATH = PROJECT_DIR / "configs" / "app_config.json"
# Metrics written by the training pipeline; shown in the sidebar when present.
METRICS_PATH = PROJECT_DIR / "artifacts" / "large_model" / "metrics.json"
# Fallback configuration used when CONFIG_PATH is absent; any key present in
# the JSON config file overrides the value here (see load_app_config).
DEFAULT_APP_CONFIG = {
    # Paths are relative to PROJECT_DIR.
    "model_dir": "artifacts/large_model/best_model",
    "labels_path": "data/processed_large/label_mapping.json",
    # Tokenizer truncation length for title+abstract input.
    "max_length": 256,
    # Cumulative probability mass the returned category prefix must reach.
    "coverage_threshold": 0.95,
    "model_name": "distilbert-base-uncased",
    "page_title": "arXiv Topic Classifier",
    "page_icon": "📚",
    # Pre-filled example shown in the "quick check" expander.
    "example_title": "Learning-based Visual Navigation for Mobile Robots",
    "example_abstract": (
        "We present a transformer-based navigation system that uses camera observations "
        "and scene understanding to plan robust trajectories for indoor mobile robots."
    ),
}
def load_app_config() -> dict[str, Any]:
    """Return the app configuration: defaults overlaid with CONFIG_PATH values.

    Missing config file means pure defaults; an existing file only overrides
    the keys it defines.
    """
    config = DEFAULT_APP_CONFIG.copy()
    if CONFIG_PATH.exists():
        with CONFIG_PATH.open("r", encoding="utf-8") as fh:
            config.update(json.load(fh))
    return config
# Configuration resolved once at import time.
APP_CONFIG = load_app_config()
# Filesystem locations and inference settings derived from the config.
MODEL_DIR = PROJECT_DIR / str(APP_CONFIG["model_dir"])
LABELS_PATH = PROJECT_DIR / str(APP_CONFIG["labels_path"])
MAX_LENGTH = int(APP_CONFIG["max_length"])
COVERAGE_THRESHOLD = float(APP_CONFIG["coverage_threshold"])
# NOTE: Streamlit requires set_page_config to run before any other st.* call,
# so it must stay here at module level, ahead of main().
st.set_page_config(
    page_title=str(APP_CONFIG["page_title"]),
    page_icon=str(APP_CONFIG["page_icon"]),
    layout="centered",
)
@st.cache_resource
def load_classifier() -> ArticleClassifier:
    """Build and return the article classifier.

    Fix: without caching, Streamlit re-executes this on every rerun (every
    widget interaction), reloading the transformer model each time.
    ``st.cache_resource`` makes the classifier load once per process and be
    shared across reruns; callers are unaffected. Exceptions raised during
    construction still propagate to the caller (main() catches them).
    """
    return ArticleClassifier(
        model_dir=MODEL_DIR,
        labels_path=LABELS_PATH,
        max_length=MAX_LENGTH,
    )
def load_metrics() -> dict | None:
    """Load evaluation metrics from METRICS_PATH, or return None if absent.

    Fix: removed the redundant function-local ``import json`` that shadowed
    the module-level import (json is already imported at the top of the file).
    """
    if not METRICS_PATH.exists():
        return None
    with METRICS_PATH.open("r", encoding="utf-8") as fh:
        return json.load(fh)
def format_probability(probability: float) -> str:
    """Render a 0-1 probability as a percentage with two decimals, e.g. '12.34%'."""
    percent = probability * 100
    return "{:.2f}%".format(percent)
def format_threshold(threshold: float) -> str:
    """Render a 0-1 threshold as a whole-number percentage, e.g. '95%'."""
    return "{:.0f}%".format(threshold * 100)
def render_prediction_rows(predictions: list[dict[str, float | str]]) -> None:
    """Write a ranked list of predictions, each with a probability bar.

    Each item must provide "label" and "probability" keys; probabilities are
    clamped to [0, 1] before being handed to st.progress.
    """
    rank = 1
    for item in predictions:
        label = str(item["label"])
        probability = float(item["probability"])
        st.write(f"{rank}. `{label}`")
        clamped = min(max(probability, 0.0), 1.0)
        st.progress(clamped, text=format_probability(probability))
        rank += 1
def main() -> None:
    """Render the app: intro copy, sidebar model summary, input form, results.

    Fix: the third results metric previously hard-coded the label
    "Top-95% coverage" even though the coverage threshold is configurable
    via APP_CONFIG["coverage_threshold"]; the label now uses the actual
    threshold. Repeated ``metrics.get(...)`` lookups are also hoisted.
    """
    coverage_label = format_threshold(COVERAGE_THRESHOLD)
    st.title(str(APP_CONFIG["page_title"]))
    st.write(
        "This demo predicts arXiv paper topics from the title and abstract using a transformer classifier."
    )
    st.caption(
        "For homework evaluation, the app returns the smallest prefix of categories whose cumulative "
        f"probability reaches {coverage_label}."
    )
    st.info(
        "How to test: paste a real or synthetic paper title, optionally add an abstract, and press "
        "`Predict categories`. If the abstract is empty, the model will classify from the title only."
    )

    # Load the model eagerly so the sidebar can report either the label set or
    # the initialization error before any prediction is attempted.
    classifier: ArticleClassifier | None = None
    classifier_load_error: str | None = None
    with st.sidebar:
        try:
            classifier = load_classifier()
        except Exception as exc:
            # Broad catch is deliberate: any load failure is surfaced in the
            # UI instead of crashing the whole page.
            classifier_load_error = f"Model initialization error in load_classifier: {exc}"
        metrics = load_metrics()
        st.subheader("Evaluation Summary")
        st.write(f"Model: `{APP_CONFIG['model_name']}`")
        if classifier is not None:
            st.write(f"Number of classes: `{len(classifier.labels)}`")
            st.write("Classes: " + ", ".join(f"`{label}`" for label in classifier.labels))
        else:
            st.error(classifier_load_error or "Model initialization error: unknown error")
        if metrics is not None:
            # Hoist the nested dict lookups; each metric is optional.
            validation_metrics = metrics.get("validation", {})
            test_metrics = metrics.get("test", {})
            validation_accuracy = validation_metrics.get("eval_accuracy")
            validation_f1 = validation_metrics.get("eval_macro_f1")
            test_accuracy = test_metrics.get("test_accuracy")
            test_f1 = test_metrics.get("test_macro_f1")
            if validation_accuracy is not None:
                st.write(f"Validation accuracy: `{validation_accuracy:.4f}`")
            if validation_f1 is not None:
                st.write(f"Validation macro-F1: `{validation_f1:.4f}`")
            if test_accuracy is not None:
                st.write(f"Test accuracy: `{test_accuracy:.4f}`")
            if test_f1 is not None:
                st.write(f"Test macro-F1: `{test_f1:.4f}`")
        st.write(
            "Output rule: return categories until cumulative probability reaches "
            f"{coverage_label}"
        )

    with st.expander("Example Input For Quick Check"):
        st.markdown(
            f"**Title:** {APP_CONFIG['example_title']}\n\n"
            f"**Abstract:** {APP_CONFIG['example_abstract']}"
        )

    with st.form("prediction_form"):
        title = st.text_input(
            "Article title",
            placeholder="Enter the article title",
        )
        abstract = st.text_area(
            "Abstract",
            placeholder="Enter the abstract (optional, but recommended)",
            height=220,
        )
        predict_button = st.form_submit_button("Predict categories", type="primary")

    if predict_button:
        # Guard clauses: no model, or no usable input.
        if classifier is None:
            st.error(classifier_load_error or "Model initialization error: classifier is unavailable.")
            return
        if not title.strip() and not abstract.strip():
            st.error("Input validation error in app: please enter at least a title or an abstract.")
            return
        with st.spinner("Running inference..."):
            try:
                full_predictions = classifier.predict(title=title, abstract=abstract)
                predictions = classifier.select_top_k_by_probability_mass(
                    full_predictions,
                    threshold=COVERAGE_THRESHOLD,
                )
            except ValueError as exc:
                # Validation errors from the classifier are user-facing as-is.
                st.error(str(exc))
                return
            except ClassifierError as exc:
                st.error(f"Classifier error in prediction flow: {exc}")
                return
            except Exception as exc:
                st.error(f"Unexpected inference error in app.main: {exc}")
                return
        # NOTE(review): assumes select_top_k_by_probability_mass always returns
        # at least one item — confirm against inference.py.
        best_prediction = predictions[0]
        covered_probability = sum(float(item["probability"]) for item in predictions)
        col1, col2, col3 = st.columns(3)
        col1.metric("Top class", str(best_prediction["label"]))
        col2.metric("Top probability", format_probability(float(best_prediction["probability"])))
        # Fixed: label reflects the configured threshold, not a literal 95%.
        col3.metric(f"Top-{coverage_label} coverage", format_probability(covered_probability))
        st.subheader("Top categories")
        st.caption(
            f"These are the categories returned by the assignment top-{coverage_label} rule."
        )
        render_prediction_rows(predictions)
        with st.expander("Show Full Ranking"):
            render_prediction_rows(full_predictions)
# Script entry point (streamlit run executes this module as __main__).
if __name__ == "__main__":
    main()