from __future__ import annotations

import numpy as np
import pandas as pd
import streamlit as st
from sklearn.metrics import (
    accuracy_score,
    classification_report,
    confusion_matrix,
    f1_score,
    precision_score,
    recall_score,
)

# External plotting libs
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.express as px

# Ag-Grid for the data explorer (GridUpdateMode is needed for the grid below)
from st_aggrid import AgGrid, GridOptionsBuilder, GridUpdateMode
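# Third-party dependencies implied by the imports above (pip package names,
# assuming the usual distributions): streamlit, pandas, numpy, scikit-learn,
# matplotlib, seaborn, plotly, streamlit-aggrid.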
###############################################################################
# ------------------------------ APP HELPERS --------------------------------
###############################################################################
def _load_data(uploaded_file: st.runtime.uploaded_file_manager.UploadedFile | None) -> pd.DataFrame | None:
    """Load an XLSX or CSV upload into a DataFrame, or return *None* if nothing was uploaded."""
    if uploaded_file is None:
        return None
    file_name = uploaded_file.name.lower()
    try:
        if file_name.endswith((".xlsx", ".xls")):
            return pd.read_excel(uploaded_file)
        if file_name.endswith(".csv"):
            return pd.read_csv(uploaded_file)
    except Exception as exc:  # pragma: no cover
        st.error(f"Could not read the uploaded file: {exc}")
        return None
    st.error("Unsupported file type. Please upload .xlsx or .csv.")
    return None
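# Note: pandas reads modern .xlsx files via openpyxl; legacy .xls files
# additionally require the optional xlrd engine to be installed.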
def _compute_metrics(
    df: pd.DataFrame,
    y_true_col: str,
    y_pred_col: str,
):
    """Return global metrics, per-class report & confusion matrix."""
    # Fill missing values *before* casting to str; the reverse order would
    # silently turn NaN into the literal string "nan".
    y_true = df[y_true_col].fillna("<NA>").astype(str)
    y_pred = df[y_pred_col].fillna("<NA>").astype(str)
    acc = accuracy_score(y_true, y_pred)
    prec = precision_score(y_true, y_pred, average="weighted", zero_division=0)
    rec = recall_score(y_true, y_pred, average="weighted", zero_division=0)
    f1 = f1_score(y_true, y_pred, average="macro", zero_division=0)
    cls_report = classification_report(
        y_true, y_pred, output_dict=True, zero_division=0
    )
    # Use the union of true and predicted labels so classes that occur only in
    # the predictions (e.g. "UNASSIGNED" from the threshold sweeper) are kept.
    labels = sorted(set(y_true) | set(y_pred))
    conf_mat = confusion_matrix(y_true, y_pred, labels=labels)
    return acc, prec, rec, f1, cls_report, conf_mat, labels
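# Averaging choices: precision and recall are weighted by class support, while
# F1 uses a macro average — matching the "Weighted Precision", "Weighted
# Recall" and "Macro-F1" KPI labels rendered in main() below.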
def _plot_confusion(conf_mat: np.ndarray, labels: list[str]):
    """Return a seaborn heat-map figure with readable tick labels."""
    # Dynamic sizing: wider for x-labels, taller for y-labels
    fig_w = max(8, 0.4 * len(labels))   # width grows slowly
    fig_h = max(6, 0.35 * len(labels))  # height a bit shorter
    fig, ax = plt.subplots(figsize=(fig_w, fig_h))
    sns.heatmap(
        conf_mat,
        annot=True,
        fmt="d",
        cmap="Blues",
        xticklabels=labels,
        yticklabels=labels,
        ax=ax,
        cbar_kws={"shrink": 0.85},
    )
    # Rotate & style tick labels for readability
    ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha="right", fontsize=8)
    ax.set_yticklabels(ax.get_yticklabels(), rotation=0, fontsize=8)
    ax.set_xlabel("Predicted Label")
    ax.set_ylabel("True Label")
    ax.set_title("Confusion Matrix")
    fig.tight_layout()
    return fig
###############################################################################
# --------------------------------- MAIN -----------------------------------
###############################################################################
def main() -> None:
    st.set_page_config(
        page_title="ML Prediction Dashboard",
        layout="wide",
        page_icon="📊",
        initial_sidebar_state="expanded",
    )
    st.title("📊 Machine-Learning Prediction Dashboard")
    st.write(
        "Upload a predictions file and instantly explore model performance, "
        "confidence behaviour and individual misclassifications."
    )

    # ------------------------------------------------------------------
    # Sidebar: file upload & column mapping
    # ------------------------------------------------------------------
    with st.sidebar:
        st.header("1️⃣ Upload & Mapping")
        uploaded_file = st.file_uploader(
            "Upload .xlsx or .csv containing predictions", type=["xlsx", "xls", "csv"]
        )
        st.divider()
        st.header("2️⃣ Column Mapping")
        y_true_col = st.text_input("Ground-truth column", value="ground_truth")
        y_pred_col = st.text_input("Predicted-label column", value="CASISTICA_MOTIVAZIONE")
        prob_col = st.text_input(
            "Probability / confidence column", value="PROBABILITA_ASSOCIAZIONE"
        )

    df = _load_data(uploaded_file)
    if df is None:
        st.info("📂 Upload a file to start …")
        st.stop()

    # ------------------------------------------------------------------
    # KPI Metrics
    # ------------------------------------------------------------------
    acc, prec, rec, f1, cls_report, conf_mat, labels = _compute_metrics(
        df, y_true_col, y_pred_col
    )
    kpi_cols = st.columns(6)
    kpi_cols[0].metric("Accuracy", f"{acc:.2%}")
    kpi_cols[1].metric("Weighted Precision", f"{prec:.2%}")
    kpi_cols[2].metric("Weighted Recall", f"{rec:.2%}")
    kpi_cols[3].metric("Macro-F1", f"{f1:.2%}")
    kpi_cols[4].metric("# Records", f"{len(df):,}")
    kpi_cols[5].metric("# Classes", f"{df[y_true_col].nunique()}")
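    # Note: "# Classes" counts distinct ground-truth labels only; labels that
    # occur solely in the predictions are not included in this KPI.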
    st.divider()

    # ------------------------------------------------------------------
    # Confidence distribution + threshold sweeper
    # ------------------------------------------------------------------
    st.subheader("Confidence Distribution")
    if prob_col in df.columns:
        fig_hist = px.histogram(
            df,
            x=prob_col,
            nbins=40,
            marginal="box",
            title="Model confidence histogram",
            labels={prob_col: "Confidence"},
            height=350,
        )
        st.plotly_chart(fig_hist, use_container_width=True)

        st.markdown("#### Threshold Sweeper")
        thresh = st.slider("Probability threshold", 0.0, 1.0, 0.5, 0.01)
        df_tmp = df.copy()
        df_tmp["_adjusted_pred"] = np.where(
            df_tmp[prob_col] >= thresh, df_tmp[y_pred_col].astype(str), "UNASSIGNED"
        )
        acc2, prec2, rec2, f12, *_ = _compute_metrics(df_tmp, y_true_col, "_adjusted_pred")
        st.info(
            f"**Metrics @ ≥ {thresh:.2f}**: "
            f"Accuracy {acc2:.2%} • Precision {prec2:.2%} • "
            f"Recall {rec2:.2%} • Macro-F1 {f12:.2%}"
        )
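        # The sweeper relabels every prediction below the threshold as
        # "UNASSIGNED", so those rows count as errors against the ground
        # truth: raising the threshold typically trades overall accuracy
        # for higher precision on the predictions that remain assigned.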
    else:
        st.warning("Selected probability column does not exist; skipping confidence plots.")
    st.divider()
    # ------------------------------------------------------------------
    # Confusion matrix & class-wise report
    # ------------------------------------------------------------------
    st.subheader("Confusion Matrix")
    fig_cm = _plot_confusion(conf_mat, labels)
    st.pyplot(fig_cm, use_container_width=True)

    st.subheader("Class-wise Metrics")
    cls_df = (
        pd.DataFrame(cls_report)
        .T.reset_index()
        .rename(columns={"index": "class"})
    )
    st.dataframe(cls_df, use_container_width=True)
    st.divider()
    # ------------------------------------------------------------------
    # Data Explorer (AG-Grid) with text wrapping & interactive reordering
    # ------------------------------------------------------------------
    st.subheader("Data Explorer")

    # Filters
    with st.expander("Filters", expanded=False):
        sel_true = st.multiselect(
            "Ground-truth labels", sorted(df[y_true_col].unique()),
            default=sorted(df[y_true_col].unique()),
        )
        sel_pred = st.multiselect(
            "Predicted labels", sorted(df[y_pred_col].unique()),
            default=sorted(df[y_pred_col].unique()),
        )
        if prob_col in df.columns:
            prob_rng = st.slider(
                "Confidence range", 0.0, 1.0, (0.0, 1.0), 0.01, key="prob_range"
            )
        else:
            prob_rng = (0.0, 1.0)

    # Apply filters: build the mask incrementally so the probability filter is
    # only applied when the column actually exists
    mask = df[y_true_col].isin(sel_true) & df[y_pred_col].isin(sel_pred)
    if prob_col in df.columns:
        mask &= (df[prob_col] >= prob_rng[0]) & (df[prob_col] <= prob_rng[1])
    df_view = df[mask].copy()
    st.caption(f"Showing **{len(df_view):,}** rows after filtering.")

    # Build the AgGrid table with wrapping & movable columns
    gb = GridOptionsBuilder.from_dataframe(df_view)
    gb.configure_default_column(
        editable=False,
        filter=True,
        sortable=True,
        resizable=True,
        wrapText=True,
        autoHeight=True,
        # Columns are drag-and-drop movable by default in AG-Grid; the
        # grid-level suppressMovableColumns flag below keeps it that way.
    )
    # Optional: give extra width to the free-text column
    if "NOTE_OPERATORE" in df_view.columns:
        gb.configure_column(
            "NOTE_OPERATORE",
            width=300,
            minWidth=100,
            maxWidth=600,
            wrapText=True,
            autoHeight=True,
        )
    gb.configure_selection("single", use_checkbox=True)
    grid_opts = gb.build()
    grid_opts["suppressMovableColumns"] = False
    # Capture the grid's return value directly: AgGrid returns the response
    # object; it is not written to st.session_state under "grid_response".
    grid_resp = AgGrid(
        df_view,
        gridOptions=grid_opts,
        enable_enterprise_modules=True,
        height=400,
        width="100%",
        allow_unsafe_jscode=True,
        update_mode=GridUpdateMode.SELECTION_CHANGED,
    )

    # Selected-row details; depending on the streamlit-aggrid version,
    # "selected_rows" is a list of dicts or a DataFrame.
    sel = grid_resp["selected_rows"]
    if isinstance(sel, pd.DataFrame):
        sel = sel.to_dict("records")
    sel = sel or []
    if sel:
        row = sel[0]
        st.markdown("### Row Details")
        with st.expander(f"Document #: {row.get('NUMERO_DOCUMENTO', 'N/A')}", expanded=True):
            st.write("**Ground-truth:**", row.get(y_true_col))
            st.write("**Predicted:**", row.get(y_pred_col))
            if prob_col in row:
                st.write("**Confidence:**", row.get(prob_col))
            st.write("**Operator Notes:**")
            st.write(row.get("NOTE_OPERATORE", "-"))
            match_cols = [c for c in df.columns if c.startswith("MATCH") and not c.endswith("VALUE")]
            if match_cols:
                st.write("**Top Suggestions & Similarity**")
                sim_df = pd.DataFrame(
                    {
                        "Suggestion": [row.get(c) for c in match_cols],
                        "Similarity": [
                            row.get(f"{c}_VALUE") if f"{c}_VALUE" in row else np.nan
                            for c in match_cols
                        ],
                    }
                )
                st.table(sim_df)
if __name__ == "__main__":
    main()
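# To run locally (assuming this file is saved as app.py):
#   streamlit run app.py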