| |
| """Plot debug_pass{N}.npz files saved by predict.py when MIMI_DEBUG=1.""" |
| import sys |
| from pathlib import Path |
|
|
| import matplotlib |
| matplotlib.use("Agg") |
| import matplotlib.pyplot as plt |
| import numpy as np |
|
|
# Folder this script lives in; the __main__ entry point scans it for
# debug_pass*.npz dumps when no other location is given.
_HERE = Path(__file__).resolve().parent

# (class name, matplotlib colour) pairs. The position of each pair is the
# column index of that class in the saved per-frame probability array
# (assumed to mirror the head ordering in predict.py -- TODO confirm).
_CLASS_PALETTE = [
    ("bos", "#9E9E9E"),
    ("system_end", "#FF9800"),
    ("user_end", "#8BC34A"),
    ("system", "#03A9F4"),
    ("user", "#2196F3"),
]
CLASS_NAMES = [name for name, _colour in _CLASS_PALETTE]
CLASS_COLORS = [colour for _name, colour in _CLASS_PALETTE]
# Column of the "user" class; it is drawn emphasised in the probability panel.
IDX_USER = CLASS_NAMES.index("user")
|
|
|
|
def plot_npz(path: Path) -> None:
    """Render one ``debug_pass*.npz`` dump to a PNG written next to it.

    The figure has four vertically stacked panels on a time axis in seconds:

    1. the ``subj`` waveform and 2. the ``other`` waveform (sampled at ``sr``);
    3. per-frame class probabilities, one line per entry of ``CLASS_NAMES``
       (the user class drawn heavier and shaded), plus a dashed line at
       ``threshold``;
    4. the per-frame ``floor`` array as a post-step plot.

    Args:
        path: ``.npz`` archive with keys ``subj``, ``other``, ``floor``,
            ``threshold``, ``sr``, ``frame_rate`` and either ``probs``
            (indexed ``[frame, class]``) or ``p_user`` (user-class
            probability only; the other class columns are then plotted as 0).

    The image is saved to ``path.with_suffix(".png")`` and its path printed.
    """
    # Materialise every array while the archive is open: np.load() on an
    # .npz returns an NpzFile that keeps the file handle open until closed.
    with np.load(path) as d:
        subj = d["subj"]
        other = d["other"]
        floor = d["floor"].astype(np.float32)
        if "probs" in d.files:
            probs = d["probs"]
        else:
            # Only the user column is known; every other class stays at zero.
            p_user = d["p_user"]
            probs = np.zeros((len(p_user), len(CLASS_NAMES)), dtype=np.float32)
            probs[:, IDX_USER] = p_user
        thr = float(d["threshold"])
        sr = int(d["sr"])
        # float, not int: a frame rate need not be integral (Mimi's is
        # 12.5 Hz -- confirm against what predict.py stores), and int()
        # would truncate it and misalign frame panels with the waveforms.
        hz = float(d["frame_rate"])

    n_steps = len(floor)
    duration = n_steps / hz
    t_frame = np.arange(n_steps) / hz
    # probs is stored separately from floor; give it its own time base so a
    # length mismatch between the two does not make plot() raise.
    t_prob = np.arange(len(probs)) / hz

    # ~20 px per second of audio at dpi=100, capped well below Agg's
    # 2**16-pixel image limit, beyond which savefig() raises ValueError.
    fig_width = min(max(20.0, duration / 5), 600.0)
    fig, axes = plt.subplots(
        4,
        1,
        figsize=(fig_width, 10),
        gridspec_kw={"hspace": 0.05, "height_ratios": [1, 1, 2, 1]},
    )
    try:
        # Panels 1-2: raw waveforms, each on its own time base so channels
        # of different length can still be drawn.
        for ax, wav, label in zip(axes[:2], (subj, other), ("subject", "other")):
            ax.plot(np.arange(len(wav)) / sr, wav, lw=0.3, color="#555", alpha=0.8)
            ax.set_ylabel(label, fontsize=8)
            ax.set_xlim(0, duration)
            ax.tick_params(labelbottom=False)

        # Panel 3: class probabilities; the user class is emphasised.
        ax_prob = axes[2]
        for idx, (name, color) in enumerate(zip(CLASS_NAMES, CLASS_COLORS)):
            is_user = idx == IDX_USER
            ax_prob.plot(
                t_prob,
                probs[:, idx],
                lw=1.2 if is_user else 0.7,
                color=color,
                alpha=0.9 if is_user else 0.7,
                label=name,
            )
        ax_prob.fill_between(
            t_prob, probs[:, IDX_USER], alpha=0.1, color=CLASS_COLORS[IDX_USER]
        )
        ax_prob.axhline(thr, color="red", lw=0.8, ls="--", label=f"threshold={thr}")
        ax_prob.set_ylim(-0.05, 1.05)
        ax_prob.set_ylabel("class probs", fontsize=8)
        ax_prob.legend(fontsize=7, loc="upper right", ncol=3)
        ax_prob.set_xlim(0, duration)
        ax_prob.tick_params(labelbottom=False)

        # Panel 4: floor value held from each frame start until the next.
        ax_floor = axes[3]
        ax_floor.step(t_frame, floor, where="post", lw=0.8, color="#4CAF50")
        ax_floor.fill_between(t_frame, floor, step="post", alpha=0.2, color="#4CAF50")
        ax_floor.set_ylim(-0.1, 1.1)
        ax_floor.set_ylabel("floor bit", fontsize=8)
        ax_floor.set_xlabel("Time (s)", fontsize=9)
        ax_floor.set_xlim(0, duration)

        # Faint vertical line at every whole second, behind the data.
        for ax in axes:
            for sec in np.arange(1, int(duration) + 1):
                ax.axvline(sec, color="#ddd", lw=0.4, zorder=0)

        axes[0].set_title(
            f"mimi_endpointer {path.stem} threshold={thr}", fontsize=9, loc="left", pad=3
        )
        out = path.with_suffix(".png")
        fig.savefig(out, dpi=100, bbox_inches="tight")
    finally:
        # Release the figure even if drawing fails, so batch runs over many
        # dumps do not accumulate open figures.
        plt.close(fig)
    print(f"saved → {out}")
|
|
|
|
def main(argv: "list[str] | None" = None) -> int:
    """Plot debug dumps and return a process exit status.

    Args:
        argv: Paths to ``.npz`` files and/or directories to scan for
            ``debug_pass*.npz``. Defaults to ``sys.argv[1:]``. When no paths
            are given, the directory containing this script is scanned,
            matching the script's original behaviour.

    Returns:
        0 after every found file has been plotted, 1 if no file was found.
    """
    args = sys.argv[1:] if argv is None else list(argv)
    if not args:
        npz_files = sorted(_HERE.glob("debug_pass*.npz"))
    else:
        npz_files = []
        for arg in args:
            p = Path(arg)
            if p.is_dir():
                npz_files.extend(sorted(p.glob("debug_pass*.npz")))
            elif p.is_file():
                npz_files.append(p)
            else:
                print(f"not found: {p}", file=sys.stderr)

    if not npz_files:
        # Diagnostics go to stderr so stdout carries only the "saved" lines.
        print("No debug_pass*.npz files found. Run with MIMI_DEBUG=1 first.", file=sys.stderr)
        return 1
    for f in npz_files:
        plot_npz(f)
    return 0


if __name__ == "__main__":
    sys.exit(main())
|
|