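"""Streamlit dashboard for exploring classification predictions.

Upload an .xlsx/.csv predictions file, map the ground-truth, predicted-label
and confidence columns, then inspect global metrics, a confidence histogram
with a threshold sweeper, the confusion matrix, class-wise scores and
individual rows in an AG-Grid data explorer.
"""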
from __future__ import annotations

import numpy as np
import pandas as pd
import streamlit as st
from sklearn.metrics import (
    accuracy_score,
    classification_report,
    confusion_matrix,
    f1_score,
    precision_score,
    recall_score,
)

# External plotting libs
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.express as px

# Ag-Grid for the data explorer
from st_aggrid import AgGrid, GridOptionsBuilder, GridUpdateMode

###############################################################################
# ------------------------------ APP HELPERS ---------------------------------
###############################################################################
def _load_data(
    uploaded_file: st.runtime.uploaded_file_manager.UploadedFile | None,
) -> pd.DataFrame | None:
    """Load an XLSX or CSV into a DataFrame, or return *None* if nothing was uploaded."""
    if uploaded_file is None:
        return None
    file_name = uploaded_file.name.lower()
    try:
        if file_name.endswith((".xlsx", ".xls")):
            return pd.read_excel(uploaded_file)
        if file_name.endswith(".csv"):
            return pd.read_csv(uploaded_file)
    except Exception as exc:  # pragma: no cover
        st.error(f"Could not read the uploaded file - {exc}")
        return None
    st.error("Unsupported file type. Please upload .xlsx or .csv.")
    return None

def _compute_metrics(
    df: pd.DataFrame,
    y_true_col: str,
    y_pred_col: str,
):
    """Return global metrics, the class-wise report and the confusion matrix."""
    # fillna() must run before astype(str), otherwise NaN becomes the string "nan".
    y_true = df[y_true_col].fillna("<NA>").astype(str)
    y_pred = df[y_pred_col].fillna("<NA>").astype(str)
    acc = accuracy_score(y_true, y_pred)
    prec = precision_score(y_true, y_pred, average="weighted", zero_division=0)
    rec = recall_score(y_true, y_pred, average="weighted", zero_division=0)
    f1 = f1_score(y_true, y_pred, average="macro", zero_division=0)
    cls_report = classification_report(
        y_true, y_pred, output_dict=True, zero_division=0
    )
    # Union of true and predicted labels, so prediction-only classes still
    # get a column in the confusion matrix.
    labels = sorted(set(y_true) | set(y_pred))
    conf_mat = confusion_matrix(y_true, y_pred, labels=labels)
    return acc, prec, rec, f1, cls_report, conf_mat, labels
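
# Quick sanity check for `_compute_metrics` (illustrative only; the toy column
# names "gt"/"pred" are arbitrary, not the app's defaults):
#   toy = pd.DataFrame({"gt": ["a", "b", "a"], "pred": ["a", "b", "b"]})
#   acc, *_ = _compute_metrics(toy, "gt", "pred")  # acc == 2/3
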
def _plot_confusion(conf_mat: np.ndarray, labels: list[str]):
    """Return a seaborn heat-map figure with readable tick labels."""
    # Dynamic sizing - wider for x-labels, taller for y-labels
    fig_w = max(8, 0.4 * len(labels))   # width grows slowly
    fig_h = max(6, 0.35 * len(labels))  # height a bit shorter
    fig, ax = plt.subplots(figsize=(fig_w, fig_h))
    sns.heatmap(
        conf_mat,
        annot=True,
        fmt="d",
        cmap="Blues",
        xticklabels=labels,
        yticklabels=labels,
        ax=ax,
        cbar_kws={"shrink": 0.85},
    )
    # Rotate & style tick labels for readability
    ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha="right", fontsize=8)
    ax.set_yticklabels(ax.get_yticklabels(), rotation=0, fontsize=8)
    ax.set_xlabel("Predicted Label")
    ax.set_ylabel("True Label")
    ax.set_title("Confusion Matrix")
    fig.tight_layout()
    return fig
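
# Example usage with hypothetical labels (the app itself calls this from main()):
#   fig = _plot_confusion(np.array([[5, 1], [2, 7]]), ["cat", "dog"])
#   st.pyplot(fig, use_container_width=True)
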
###############################################################################
# --------------------------------- MAIN -------------------------------------
###############################################################################
def main() -> None:
    st.set_page_config(
        page_title="ML Prediction Dashboard",
        layout="wide",
        page_icon="📊",
        initial_sidebar_state="expanded",
    )
    st.title("📊 Machine-Learning Prediction Dashboard")
    st.write(
        "Upload a predictions file and instantly explore model performance, "
        "confidence behaviour and individual misclassifications."
    )

    # ------------------------------------------------------------------
    # Sidebar - file upload & column mapping
    # ------------------------------------------------------------------
    with st.sidebar:
        st.header("1️⃣ Upload & Mapping")
        uploaded_file = st.file_uploader(
            "Upload .xlsx or .csv containing predictions", type=["xlsx", "xls", "csv"]
        )
        st.divider()
        st.header("2️⃣ Column Mapping")
        y_true_col = st.text_input("Ground-truth column", value="ground_truth")
        y_pred_col = st.text_input("Predicted-label column", value="CASISTICA_MOTIVAZIONE")
        prob_col = st.text_input(
            "Probability / confidence column", value="PROBABILITA_ASSOCIAZIONE"
        )

    df = _load_data(uploaded_file)
    if df is None:
        st.info("Upload a file to start …")
        st.stop()
    # ------------------------------------------------------------------
    # KPI Metrics
    # ------------------------------------------------------------------
    # Guard against mistyped column names before computing anything.
    missing = [c for c in (y_true_col, y_pred_col) if c not in df.columns]
    if missing:
        st.error(f"Column(s) not found in the file: {', '.join(missing)}")
        st.stop()

    acc, prec, rec, f1, cls_report, conf_mat, labels = _compute_metrics(
        df, y_true_col, y_pred_col
    )
    kpi_cols = st.columns(6)
    kpi_cols[0].metric("Accuracy", f"{acc:.2%}")
    kpi_cols[1].metric("Weighted Precision", f"{prec:.2%}")
    kpi_cols[2].metric("Weighted Recall", f"{rec:.2%}")
    kpi_cols[3].metric("Macro-F1", f"{f1:.2%}")
    kpi_cols[4].metric("# Records", f"{len(df):,}")
    kpi_cols[5].metric("# Classes", f"{df[y_true_col].nunique()}")
    st.divider()
    # ------------------------------------------------------------------
    # Confidence distribution + threshold sweeper
    # ------------------------------------------------------------------
    st.subheader("Confidence Distribution")
    if prob_col in df.columns:
        fig_hist = px.histogram(
            df,
            x=prob_col,
            nbins=40,
            marginal="box",
            title="Model confidence histogram",
            labels={prob_col: "Confidence"},
            height=350,
        )
        st.plotly_chart(fig_hist, use_container_width=True)

        st.markdown("#### Threshold Sweeper")
        thresh = st.slider("Probability threshold", 0.0, 1.0, 0.5, 0.01)
        # Re-score after routing low-confidence rows to a catch-all "UNASSIGNED" class.
        df_tmp = df.copy()
        df_tmp["_adjusted_pred"] = np.where(
            df_tmp[prob_col] >= thresh, df_tmp[y_pred_col].astype(str), "UNASSIGNED"
        )
        acc2, prec2, rec2, f12, *_ = _compute_metrics(df_tmp, y_true_col, "_adjusted_pred")
        st.info(
            f"**Metrics @ ≥ {thresh:.2f}** - "
            f"Accuracy {acc2:.2%} • Precision {prec2:.2%} • "
            f"Recall {rec2:.2%} • Macro-F1 {f12:.2%}"
        )
    else:
        st.warning("Selected probability column does not exist - skipping confidence plots.")
    st.divider()
    # ------------------------------------------------------------------
    # Confusion matrix & class-wise report
    # ------------------------------------------------------------------
    st.subheader("Confusion Matrix")
    fig_cm = _plot_confusion(conf_mat, labels)
    st.pyplot(fig_cm, use_container_width=True)

    st.subheader("Class-wise Metrics")
    cls_df = (
        pd.DataFrame(cls_report)
        .T.reset_index()
        .rename(columns={"index": "class"})
    )
    st.dataframe(cls_df, use_container_width=True)
    st.divider()
    # ------------------------------------------------------------------
    # Data Explorer (AG-Grid) - with text wrapping & interactive reordering
    # ------------------------------------------------------------------
    st.subheader("Data Explorer")

    # Filters
    with st.expander("Filters", expanded=False):
        sel_true = st.multiselect(
            "Ground-truth labels",
            sorted(df[y_true_col].unique()),
            default=sorted(df[y_true_col].unique()),
        )
        sel_pred = st.multiselect(
            "Predicted labels",
            sorted(df[y_pred_col].unique()),
            default=sorted(df[y_pred_col].unique()),
        )
        if prob_col in df.columns:
            prob_rng = st.slider(
                "Confidence range", 0.0, 1.0, (0.0, 1.0), 0.01, key="prob_range"
            )
        else:
            prob_rng = (0.0, 1.0)

    # Apply filters
    mask = df[y_true_col].isin(sel_true) & df[y_pred_col].isin(sel_pred)
    if prob_col in df.columns:
        mask &= (df[prob_col] >= prob_rng[0]) & (df[prob_col] <= prob_rng[1])
    df_view = df[mask].copy()
    st.caption(f"Showing **{len(df_view):,}** rows after filtering.")
    # Build the AgGrid table with wrapping & movable columns
    gb = GridOptionsBuilder.from_dataframe(df_view)
    gb.configure_default_column(
        editable=False,
        filter=True,
        sortable=True,
        resizable=True,
        wrapText=True,
        autoHeight=True,
        suppressMovable=False,  # allow drag-and-drop column reordering
    )
    # Optional: give extra width to the free-text column
    if "NOTE_OPERATORE" in df_view.columns:
        gb.configure_column(
            "NOTE_OPERATORE",
            width=300,
            minWidth=100,
            maxWidth=600,
            wrapText=True,
            autoHeight=True,
        )
    gb.configure_selection("single", use_checkbox=True)
    grid_opts = gb.build()
    grid_opts["suppressMovableColumns"] = False

    # Capture the return value - selections are read from it, not from session state.
    grid_resp = AgGrid(
        df_view,
        gridOptions=grid_opts,
        enable_enterprise_modules=True,
        height=400,
        width="100%",
        allow_unsafe_jscode=True,
        update_mode=GridUpdateMode.SELECTION_CHANGED,
    )
    # Selected-row details
    sel = grid_resp["selected_rows"] if grid_resp is not None else None
    # Newer streamlit-aggrid releases return a DataFrame here; normalise to dicts.
    if isinstance(sel, pd.DataFrame):
        sel = sel.to_dict("records")
    if sel:
        row = sel[0]
        st.markdown("### Row Details")
        with st.expander(f"Document #: {row.get('NUMERO_DOCUMENTO', 'N/A')}", expanded=True):
            st.write("**Ground-truth:**", row.get(y_true_col))
            st.write("**Predicted:**", row.get(y_pred_col))
            if prob_col in row:
                st.write("**Confidence:**", row.get(prob_col))
            st.write("**Operator Notes:**")
            st.write(row.get("NOTE_OPERATORE", "-"))
            match_cols = [c for c in df.columns if c.startswith("MATCH") and not c.endswith("VALUE")]
            if match_cols:
                st.write("**Top Suggestions & Similarity**")
                sim_df = pd.DataFrame(
                    {
                        "Suggestion": [row.get(c) for c in match_cols],
                        "Similarity": [row.get(f"{c}_VALUE", np.nan) for c in match_cols],
                    }
                )
                st.table(sim_df)

if __name__ == "__main__":
    main()
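
# To run locally (assuming the usual PyPI package names and that this file is
# saved as app.py - both are assumptions, not taken from the repo):
#   pip install streamlit pandas scikit-learn matplotlib seaborn plotly streamlit-aggrid
#   streamlit run app.py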