from __future__ import annotations

import numpy as np
import pandas as pd
import streamlit as st
from sklearn.metrics import (
    accuracy_score,
    classification_report,
    confusion_matrix,
    f1_score,
    precision_score,
    recall_score,
)

# External plotting libs
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.express as px

# Ag-Grid for the data-explorer
from st_aggrid import AgGrid, GridOptionsBuilder

###############################################################################
# ------------------------------ APP HELPERS ---------------------------------
###############################################################################


def _load_data(
    uploaded_file: st.runtime.uploaded_file_manager.UploadedFile | None,
) -> pd.DataFrame | None:
    """Load XLSX or CSV into a DataFrame, or return *None* if not uploaded."""
    if uploaded_file is None:
        return None
    file_name = uploaded_file.name.lower()
    try:
        if file_name.endswith((".xlsx", ".xls")):
            return pd.read_excel(uploaded_file)
        if file_name.endswith(".csv"):
            return pd.read_csv(uploaded_file)
    except Exception as exc:  # pragma: no cover
        st.error(f"Could not read the uploaded file – {exc}")
        return None
    st.error("Unsupported file type. Please upload .xlsx or .csv.")
    return None


def _compute_metrics(
    df: pd.DataFrame,
    y_true_col: str,
    y_pred_col: str,
):
    """Return global metrics, class report & confusion matrix."""
    # Fill missing values *before* casting to str, otherwise NaN becomes "nan".
    y_true = df[y_true_col].fillna("").astype(str)
    y_pred = df[y_pred_col].fillna("").astype(str)
    acc = accuracy_score(y_true, y_pred)
    prec = precision_score(y_true, y_pred, average="weighted", zero_division=0)
    rec = recall_score(y_true, y_pred, average="weighted", zero_division=0)
    f1 = f1_score(y_true, y_pred, average="macro", zero_division=0)
    cls_report = classification_report(
        y_true, y_pred, output_dict=True, zero_division=0
    )
    # Use the union of true and predicted labels so predictions for classes that
    # never appear as ground truth are not silently dropped from the matrix.
    labels = sorted(set(y_true.unique()) | set(y_pred.unique()))
    conf_mat = confusion_matrix(y_true, y_pred, labels=labels)
    return acc, prec, rec, f1, cls_report, conf_mat, labels


def _plot_confusion(conf_mat: np.ndarray, labels: list[str]):
    """Return a seaborn heat-map figure with readable tick labels."""
    # Dynamic sizing – wider for x-labels, taller for y-labels
    fig_w = max(8, 0.4 * len(labels))   # width grows slowly
    fig_h = max(6, 0.35 * len(labels))  # height a bit shorter
    fig, ax = plt.subplots(figsize=(fig_w, fig_h))
    sns.heatmap(
        conf_mat,
        annot=True,
        fmt="d",
        cmap="Blues",
        xticklabels=labels,
        yticklabels=labels,
        ax=ax,
        cbar_kws={"shrink": 0.85},
    )
    # Rotate & style tick labels for readability
    ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha="right", fontsize=8)
    ax.set_yticklabels(ax.get_yticklabels(), rotation=0, fontsize=8)
    ax.set_xlabel("Predicted Label")
    ax.set_ylabel("True Label")
    ax.set_title("Confusion Matrix")
    fig.tight_layout()
    return fig


###############################################################################
# --------------------------------- MAIN --------------------------------------
###############################################################################


def main() -> None:
    st.set_page_config(
        page_title="ML Prediction Dashboard",
        layout="wide",
        page_icon="📊",
        initial_sidebar_state="expanded",
    )
    st.title("📊 Machine-Learning Prediction Dashboard")
    st.write(
        "Upload a predictions file and instantly explore model performance, "
        "confidence behaviour and individual mis-classifications."
    )

    # ------------------------------------------------------------------
    # Sidebar – file upload & column mapping
    # ------------------------------------------------------------------
    with st.sidebar:
        st.header("1️⃣ Upload & Mapping")
        uploaded_file = st.file_uploader(
            "Upload .xlsx or .csv containing predictions", type=["xlsx", "xls", "csv"]
        )
        st.divider()
        st.header("2️⃣ Column Mapping")
        y_true_col = st.text_input("Ground-truth column", value="ground_truth")
        y_pred_col = st.text_input(
            "Predicted-label column", value="CASISTICA_MOTIVAZIONE"
        )
        prob_col = st.text_input(
            "Probability / confidence column", value="PROBABILITA_ASSOCIAZIONE"
        )

    df = _load_data(uploaded_file)
    if df is None:
        st.info("👈 Upload a file to start …")
        st.stop()

    # ------------------------------------------------------------------
    # KPI Metrics
    # ------------------------------------------------------------------
    acc, prec, rec, f1, cls_report, conf_mat, labels = _compute_metrics(
        df, y_true_col, y_pred_col
    )
    kpi_cols = st.columns(6)
    kpi_cols[0].metric("Accuracy", f"{acc:.2%}")
    kpi_cols[1].metric("Weighted Precision", f"{prec:.2%}")
    kpi_cols[2].metric("Weighted Recall", f"{rec:.2%}")
    kpi_cols[3].metric("Macro-F1", f"{f1:.2%}")
    kpi_cols[4].metric("# Records", f"{len(df):,}")
    kpi_cols[5].metric("# Classes", f"{df[y_true_col].nunique()}")
    st.divider()

    # ------------------------------------------------------------------
    # Confidence distribution + threshold sweeper
    # ------------------------------------------------------------------
    st.subheader("Confidence Distribution")
    if prob_col in df.columns:
        fig_hist = px.histogram(
            df,
            x=prob_col,
            nbins=40,
            marginal="box",
            title="Model confidence histogram",
            labels={prob_col: "Confidence"},
            height=350,
        )
        st.plotly_chart(fig_hist, use_container_width=True)

        st.markdown("#### Threshold Sweeper")
        thresh = st.slider("Probability threshold", 0.0, 1.0, 0.5, 0.01)
        df_tmp = df.copy()
        df_tmp["_adjusted_pred"] = np.where(
            df_tmp[prob_col] >= thresh, df_tmp[y_pred_col].astype(str), "UNASSIGNED"
        )
        acc2, prec2, rec2, f12, *_ = _compute_metrics(
            df_tmp, y_true_col, "_adjusted_pred"
        )
        st.info(
            f"**Metrics @ ≥ {thresh:.2f}** — "
            f"Accuracy {acc2:.2%} • Precision {prec2:.2%} • "
            f"Recall {rec2:.2%} • Macro-F1 {f12:.2%}"
        )
    else:
        st.warning(
            "Selected probability column does not exist – skipping confidence plots."
        )
    st.divider()

    # ------------------------------------------------------------------
    # Confusion matrix & class-wise report
    # ------------------------------------------------------------------
    st.subheader("Confusion Matrix")
    fig_cm = _plot_confusion(conf_mat, labels)
    st.pyplot(fig_cm, use_container_width=True)

    st.subheader("Class-wise Metrics")
    cls_df = (
        pd.DataFrame(cls_report)
        .T.reset_index()
        .rename(columns={"index": "class"})
    )
    st.dataframe(cls_df, use_container_width=True)
    st.divider()

    # ------------------------------------------------------------------
    # Data Explorer (AG-Grid) – with text wrapping & interactive reordering
    # ------------------------------------------------------------------
    st.subheader("Data Explorer")

    # Filters
    with st.expander("Filters", expanded=False):
        sel_true = st.multiselect(
            "Ground-truth labels ➟",
            sorted(df[y_true_col].unique()),
            default=sorted(df[y_true_col].unique()),
        )
        sel_pred = st.multiselect(
            "Predicted labels ➟",
            sorted(df[y_pred_col].unique()),
            default=sorted(df[y_pred_col].unique()),
        )
        if prob_col in df.columns:
            prob_rng = st.slider(
                "Confidence range ➟", 0.0, 1.0, (0.0, 1.0), 0.01, key="prob_range"
            )
        else:
            prob_rng = (0.0, 1.0)

    # Apply filters
    df_view = df[
        df[y_true_col].isin(sel_true)
        & df[y_pred_col].isin(sel_pred)
        & (
            (df[prob_col] >= prob_rng[0]) & (df[prob_col] <= prob_rng[1])
            if prob_col in df.columns
            else True
        )
    ].copy()
    st.caption(f"Showing **{len(df_view):,}** rows after filtering.")

    # Build AgGrid table with wrapping & movable columns
    gb = GridOptionsBuilder.from_dataframe(df_view)
    gb.configure_default_column(
        editable=False,
        filter=True,
        sortable=True,
        resizable=True,
        wrapText=True,
        autoHeight=True,
        suppressMovable=False,  # keep columns drag-and-drop reorderable
    )
    # Optional: give extra width to your free-text column
    if "NOTE_OPERATORE" in df_view.columns:
        gb.configure_column(
            "NOTE_OPERATORE",
            width=300,
            minWidth=100,
            maxWidth=600,
            wrapText=True,
            autoHeight=True,
        )
    gb.configure_selection("single", use_checkbox=True)
    grid_opts = gb.build()
    grid_opts["suppressMovableColumns"] = False

    grid_response = AgGrid(
        df_view,
        gridOptions=grid_opts,
        enable_enterprise_modules=True,
        height=400,
        width="100%",
        allow_unsafe_jscode=True,
        update_mode="SELECTION_CHANGED",
    )

    # Selected-row details – read the selection from the AgGrid return value
    # (it is not stored in st.session_state automatically).
    sel = grid_response["selected_rows"]
    # Newer streamlit-aggrid releases return a DataFrame here, older ones a list of dicts.
    if isinstance(sel, pd.DataFrame):
        sel = sel.to_dict("records")
    sel = sel or []
    if sel:
        row = sel[0]
        st.markdown("### Row Details")
        with st.expander(
            f"Document #: {row.get('NUMERO_DOCUMENTO', 'N/A')}", expanded=True
        ):
            st.write("**Ground-truth:**", row.get(y_true_col))
            st.write("**Predicted:**", row.get(y_pred_col))
            if prob_col in row:
                st.write("**Confidence:**", row.get(prob_col))
            st.write("**Operator Notes:**")
            st.write(row.get("NOTE_OPERATORE", "—"))
            match_cols = [
                c
                for c in df.columns
                if c.startswith("MATCH") and not c.endswith("VALUE")
            ]
            if match_cols:
                st.write("**Top Suggestions & Similarity**")
                sim_df = pd.DataFrame(
                    {
                        "Suggestion": [row.get(c) for c in match_cols],
                        "Similarity": [
                            row.get(f"{c}_VALUE") if f"{c}_VALUE" in row else np.nan
                            for c in match_cols
                        ],
                    }
                )
                st.table(sim_df)


if __name__ == "__main__":
    main()
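
###############################################################################
# ------------------------------ RUNNING THE APP ------------------------------
###############################################################################
# Minimal sketch of how to launch the dashboard. The file name `app.py` is an
# assumption; the package list is derived from the imports above (openpyxl is
# needed by pandas.read_excel for .xlsx files):
#
#   pip install streamlit numpy pandas scikit-learn matplotlib seaborn \
#       plotly streamlit-aggrid openpyxl
#   streamlit run app.py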