Spaces:
Sleeping
Sleeping
| """Rich-based live terminal table for displaying predictions.""" | |
| from __future__ import annotations | |
| from collections import deque | |
| from rich.live import Live | |
| from rich.table import Table | |
| from rich.text import Text | |
| from models.base import LABEL_EMOJI, LABEL_MEANING, CryPrediction | |
| from models.ensemble import compute_consensus | |
# ── Helpers ──────────────────────────────────────────────────────────────────
# Glyphs for the confidence bar. The filled and empty glyphs must differ,
# otherwise every bar renders identically regardless of confidence.
_BAR_FULL = "█"
_BAR_EMPTY = "░"
# Total bar length in characters.
_BAR_WIDTH = 5
def confidence_bar(value: float) -> str:
    """Render a fixed-width Unicode bar for a 0.0–1.0 confidence value.

    The bar is always ``_BAR_WIDTH`` characters long:
    ``round(value * _BAR_WIDTH)`` filled glyphs followed by empty glyphs.

    Args:
        value: Confidence score. Values outside [0, 1] are clamped, so the
            bar never grows past, or shrinks below, its fixed width.

    Returns:
        The bar string.
    """
    clamped = min(max(value, 0.0), 1.0)
    filled = round(clamped * _BAR_WIDTH)
    return _BAR_FULL * filled + _BAR_EMPTY * (_BAR_WIDTH - filled)
def format_confidence(value: float) -> str:
    """Format a confidence value as a bar followed by a percentage.

    The percentage is rounded to the nearest integer rather than truncated.
    For example, 0.29 is stored as 0.28999... in binary floating point, so
    truncation would show ``28%``. Rounding also matches how the bar itself
    is filled. The number is right-aligned in a 3-character field.

    Args:
        value: Confidence score, nominally in [0, 1].

    Returns:
        A string such as ``"<bar>  29%"``.
    """
    pct = round(value * 100)
    return f"{confidence_bar(value)} {pct:>3}%"
# ── Display state ────────────────────────────────────────────────────────────
class CryDisplay:
    """Manages a ``rich.live.Live`` context showing model predictions.

    Typical use: call :meth:`start` once, then :meth:`update` for each
    processed window, and finally :meth:`stop`.
    """

    # Model name of the YAMNet detector. Its prediction feeds the header
    # status line and is excluded from the reason legend.
    _YAMNET_MODEL = "YAMNet-detector"
    # Labels that are skipped when choosing which meaning the legend explains.
    _NON_REASON_LABELS = ("no_cry", "timeout", "error")

    def __init__(self, max_history: int = 5) -> None:
        """Create an idle display (nothing is rendered until ``start``).

        Args:
            max_history: How many per-window tags to keep in the
                "Last detections" row; older tags are discarded.
        """
        self._window_count = 0
        self._rms = 0.0
        self._yamnet_status = ""
        self._source_label = "mic"
        self._predictions: list[CryPrediction] = []
        # Newest tag first (update() uses appendleft); bounded by maxlen.
        self._history: deque[str] = deque(maxlen=max_history)
        self._live: Live | None = None

    # ── Public API ───────────────────────────────────────────────────────
    def start(self) -> Live:
        """Start rendering and return the underlying ``Live`` object.

        If a display is already running, it is returned unchanged. Creating
        a second ``Live`` would overwrite ``self._live`` and orphan the running
        one, which ``stop`` could then never reach.
        """
        if self._live is not None:
            return self._live
        live = Live(self._build_table(), refresh_per_second=4)
        live.start()
        # Only record the Live once it has actually started.
        self._live = live
        return live

    def stop(self) -> None:
        """Stop rendering if a display is running; safe to call repeatedly."""
        if self._live is not None:
            self._live.stop()
            self._live = None

    def update(
        self,
        predictions: list[CryPrediction],
        rms: float,
        source_label: str = "mic",
        is_silent: bool = False,
    ) -> None:
        """Record one window's results and refresh the display if running.

        Args:
            predictions: Predictions to show in the per-model table.
            rms: RMS value shown in the header row.
            source_label: Input source name shown in the title.
            is_silent: If true, the history records ``[silence]`` for this
                window instead of the consensus label.
        """
        self._window_count += 1
        self._rms = rms
        self._source_label = source_label
        self._predictions = predictions
        self._yamnet_status = self._format_yamnet_status(predictions)

        # History tag for this window (newest first).
        if is_silent:
            tag = "[silence]"
        else:
            consensus = compute_consensus(predictions)
            tag = consensus if consensus else "—"
        self._history.appendleft(f"#{self._window_count} {tag}")

        if self._live is not None:
            self._live.update(self._build_table())

    # ── Table builder ────────────────────────────────────────────────────
    def _format_yamnet_status(self, predictions: list[CryPrediction]) -> str:
        """Return the header status text for the first YAMNet prediction.

        Returns ``"YAMNet: n/a"`` when no prediction comes from YAMNet.
        """
        for p in predictions:
            if p.model_name == self._YAMNET_MODEL:
                icon = "✅" if p.label == "cry" else "❌"
                return f"YAMNet: {icon} {p.label.upper()} ({p.confidence:.2f})"
        return "YAMNet: n/a"

    def _build_table(self) -> Table:
        """Build the renderable shown by the live display.

        Rows, top to bottom: a dim header (RMS, window number, YAMNet status);
        a nested per-model table with an optional CONSENSUS row; the recent
        history row; and a legend explaining the current reason label.
        """
        outer = Table(
            title=(
                f"🍼 TotTalk Cry Eval — listening ({self._source_label}) "
                "(1s windows, 16 kHz)"
            ),
            title_style="bold cyan",
            show_header=False,
            show_edge=True,
            pad_edge=True,
            expand=True,
        )
        outer.add_column(ratio=1)

        # Header row
        header = (
            f" RMS: {self._rms:.4f} | Window #{self._window_count} "
            f"| {self._yamnet_status}"
        )
        outer.add_row(Text(header, style="dim"))

        # Per-model predictions table
        pred_table = Table(show_edge=False, expand=True, padding=(0, 1))
        pred_table.add_column("Model", style="bold", min_width=18)
        pred_table.add_column("Label", min_width=14)
        pred_table.add_column("Confidence", min_width=12)
        pred_table.add_column("Latency", justify="right", min_width=10)
        for p in self._predictions:
            if p.error:
                # Errored models show a truncated error message instead of
                # a label, with empty confidence/latency cells.
                pred_table.add_row(
                    p.model_name,
                    Text(f"⚠️ {p.error[:30]}", style="red"),
                    "",
                    "",
                )
            else:
                pred_table.add_row(
                    p.model_name,
                    p.display_label,
                    format_confidence(p.confidence),
                    f"{p.latency_ms:.1f} ms",
                )

        # Consensus row (only when compute_consensus returns something truthy)
        consensus = compute_consensus(self._predictions)
        if consensus:
            pred_table.add_row(
                Text("CONSENSUS", style="bold magenta"),
                Text(consensus, style="bold"),
                "",
                "",
            )
        outer.add_row(pred_table)

        # History
        if self._history:
            hist_str = " ".join(self._history)
            outer.add_row(Text(f" Last detections: {hist_str}", style="dim"))

        # Legend: explain the first usable per-model reason label (see
        # _current_reason_label). This is not necessarily the consensus.
        shown_label = self._current_reason_label()
        if shown_label and shown_label in LABEL_MEANING:
            emoji = LABEL_EMOJI.get(shown_label, "")
            outer.add_row(
                Text(
                    f" {emoji} {shown_label.replace('_', ' ').title()}: "
                    f"{LABEL_MEANING[shown_label]}",
                    style="italic yellow",
                )
            )
        return outer

    def _current_reason_label(self) -> str | None:
        """Return the label whose meaning the legend row should explain.

        This is the label of the first prediction, in list order, that meets
        all of these conditions:

        * it is not from the YAMNet detector;
        * it has no error;
        * its label is not one of ``_NON_REASON_LABELS``.

        Returns ``None`` when no prediction qualifies.
        """
        for p in self._predictions:
            if p.model_name == self._YAMNET_MODEL:
                continue
            if p.error or p.label in self._NON_REASON_LABELS:
                continue
            return p.label
        return None